diff --git a/README.md b/README.md index 368738d..5126990 100644 --- a/README.md +++ b/README.md @@ -106,6 +106,12 @@ MAX_JOBS=8 \ - `TORCH_CUDA_ARCH_LIST` — set to your GPU's compute capability: `8.0` (A100), `8.6` (A10/RTX 3090), `8.9` (L4/RTX 4090), `9.0` (H100/H200) - `MAX_JOBS` — number of parallel compile jobs; 4–8 is typical, reduce if you run out of RAM during compilation +**Note:** `flash-attn` is not declared in `pyproject.toml`, so a plain `uv sync` will remove it. Use `uv sync --inexact` to install/update dependencies without removing packages that aren't in the lockfile: + +```bash +uv sync --inexact +``` + ## Quick Start Launch the Gradio UI: @@ -199,6 +205,37 @@ audio_out = ae.decode(latents) See [Autoencoder Workflows](docs/workflows/autoencoder.md) for encoding batches, chunked processing, and pre-encoding datasets for LoRA training. +## CLI + +A `stable-audio` cli is included for running generation without writing any Python. + +**Text-to-audio:** +```bash +stable-audio --model small-music -p "lo-fi hip hop beat, 90 BPM" --duration 30 -o beat.wav +``` + +**Audio-to-audio** — restyle an existing recording: +```bash +stable-audio -p "bossa nova bassline" --init-audio input.wav --init-noise-level 0.8 -o out.wav +``` + +**Inpainting** — regenerate a region while keeping the rest: +```bash +stable-audio -p "punchy kick drum fill" --inpaint-audio input.wav --inpaint-start 4 --inpaint-end 8 -o out.wav +``` + +**Continuation** — extend a clip beyond its original length: +```bash +stable-audio -p "dreamy synth outro" --inpaint-audio input.wav --inpaint-start 10 --inpaint-end 30 --duration 30 -o out.wav +``` + +**With a LoRA:** +```bash +stable-audio -p "orchestral strings" --lora-ckpt-path my_lora.safetensors --lora-strength 0.8 -o out.wav +``` + +Run `stable-audio --help` for the full list of flags. + ## Hardware Support Stable Audio 3 scales from a laptop to a GPU server. diff --git a/pyproject.toml b/pyproject.toml index 6422d6f..cddc3ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,9 @@ torchaudio = [ { index = "pytorch-cu126", marker = "sys_platform == 'linux' and platform_machine == 'x86_64'" } ] +[project.scripts] +stable-audio = "stable_audio_3.cli:main" + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/stable_audio_3/cli.py b/stable_audio_3/cli.py new file mode 100644 index 0000000..7b43ac0 --- /dev/null +++ b/stable_audio_3/cli.py @@ -0,0 +1,292 @@ +""" +stable-audio — command-line interface for Stable Audio 3. + +Basic usage:: + + stable-audio --model small-music -p "lo-fi hip hop beat, 90 BPM" --duration 30 -o beat.wav + +""" + +import argparse +import os +import torch +import torchaudio + +from stable_audio_3 import StableAudioModel + + +def _save_output(audio: torch.Tensor, sample_rate: int, output: str, batch_size: int): + """Save generated audio tensor(s) to disk.""" + base, ext = os.path.splitext(output) + if not ext: + ext = ".wav" + for i in range(batch_size): + path = f"{base}_{i}{ext}" if batch_size > 1 else f"{base}{ext}" + torchaudio.save(path, audio[i].cpu(), sample_rate) + print(f"Saved: {path}") + + +def main(): + parser = argparse.ArgumentParser( + prog="stable-audio", + description="Stable Audio 3 — CLI for text-to-audio, audio-to-audio, and inpainting", + ) + + # Model + parser.add_argument( + "--model", + default="medium", + choices=[ + "medium", + "small-music", + "small-sfx", + "medium-base", + "small-music-base", + "small-sfx-base", + ], + help="Model to load (default: medium)", + ) + parser.add_argument( + "--device", + default=None, + help="Device: cuda / mps / cpu (auto-detected if omitted)", + ) + parser.add_argument( + "--no-half", action="store_true", help="Disable half-precision (fp16) on CUDA" + ) + + # Generation + parser.add_argument( + "-p", + "--prompt", + required=True, + nargs="+", + help="Text prompt(s). Pass multiple for per-batch prompts", + ) + parser.add_argument( + "--negative-prompt", nargs="+", default=None, help="Negative prompt(s)" + ) + parser.add_argument( + "--duration", + type=float, + nargs="+", + default=[120.0], + help="Duration in seconds (default: 120). Pass multiple for per-batch durations", + ) + parser.add_argument( + "--steps", type=int, default=8, help="Diffusion steps (default: 8)" + ) + parser.add_argument( + "--cfg-scale", + type=float, + default=1.0, + help="CFG scale (default: 1.0; try 7.0 for base models)", + ) + parser.add_argument( + "--seed", type=int, default=-1, help="Random seed (-1 = random, default: -1)" + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size (default: inferred from number of prompts, or 1)", + ) + parser.add_argument( + "-o", + "--output", + default="output.wav", + help="Output file path (default: output.wav)", + ) + + # Audio-to-Audio + parser.add_argument( + "--init-audio", + default=None, + metavar="PATH", + help="Source audio file for audio-to-audio generation", + ) + parser.add_argument( + "--init-noise-level", + type=float, + default=0.9, + help="Noise level for audio-to-audio (0.0–1.0, default: 0.9)", + ) + + # Inpainting / Continuation + parser.add_argument( + "--inpaint-audio", + default=None, + metavar="PATH", + help="Source audio file for inpainting or continuation", + ) + parser.add_argument( + "--inpaint-start", + type=float, + action="append", + dest="inpaint_starts", + metavar="SECONDS", + help="Start of inpaint region in seconds. Repeat for multiple regions.", + ) + parser.add_argument( + "--inpaint-end", + type=float, + action="append", + dest="inpaint_ends", + metavar="SECONDS", + help="End of inpaint region in seconds. Repeat for multiple regions.", + ) + + # Chunked decode + decode_group = parser.add_mutually_exclusive_group() + decode_group.add_argument( + "--chunked-decode", + action="store_true", + default=None, + help="Force chunked decoding on", + ) + decode_group.add_argument( + "--no-chunked-decode", + action="store_true", + default=None, + help="Force chunked decoding off", + ) + + # LoRA + parser.add_argument( + "--lora-ckpt-path", + action="append", + dest="loras", + metavar="PATH", + help="LoRA checkpoint path. Repeat to stack multiple LoRAs.", + ) + parser.add_argument( + "--lora-strength", + type=float, + default=None, + help="LoRA strength (applied to all LoRAs)", + ) + parser.add_argument( + "--lora-index", + type=int, + default=None, + help="Target a specific LoRA index when setting strength", + ) + + args = parser.parse_args() + + # --- Validate inpaint args --- + if (args.inpaint_starts is None) != (args.inpaint_ends is None): + parser.error("--inpaint-start and --inpaint-end must both be provided together") + if args.inpaint_starts and len(args.inpaint_starts) != len(args.inpaint_ends): + parser.error( + "--inpaint-start and --inpaint-end must be specified the same number of times" + ) + if args.inpaint_starts and not args.inpaint_audio: + parser.error("--inpaint-start/--inpaint-end require --inpaint-audio") + if args.inpaint_audio and not args.inpaint_starts: + parser.error("--inpaint-audio requires --inpaint-start and --inpaint-end") + + # --- Resolve batch size --- + n_prompts = len(args.prompt) + if args.batch_size is None: + batch_size = n_prompts + elif n_prompts > 1 and args.batch_size != n_prompts: + parser.error( + f"--batch-size {args.batch_size} does not match the number of prompts " + f"({n_prompts}); omit --batch-size to have it inferred automatically" + ) + else: + batch_size = args.batch_size + + # --- Validate list-flag lengths against batch size --- + if ( + args.negative_prompt + and len(args.negative_prompt) > 1 + and len(args.negative_prompt) != batch_size + ): + parser.error( + f"Got {len(args.negative_prompt)} --negative-prompt values but batch size is {batch_size}" + ) + if len(args.duration) > 1 and len(args.duration) != batch_size: + parser.error( + f"Got {len(args.duration)} --duration values but batch size is {batch_size}" + ) + + # --- Build scalar / list args --- + prompt = args.prompt[0] if len(args.prompt) == 1 else args.prompt + negative_prompt = None + if args.negative_prompt: + negative_prompt = ( + args.negative_prompt[0] + if len(args.negative_prompt) == 1 + else args.negative_prompt + ) + duration = args.duration[0] if len(args.duration) == 1 else args.duration + + # --- chunked_decode flag --- + chunked_decode = None + if args.chunked_decode: + chunked_decode = True + elif args.no_chunked_decode: + chunked_decode = False + + # --- Load model --- + print(f"Loading model '{args.model}'…") + model = StableAudioModel.from_pretrained( + args.model, device=args.device, model_half=not args.no_half + ) + + # --- LoRA --- + if args.loras: + print(f"Loading LoRA(s): {args.loras}") + model.load_lora(args.loras) + if args.lora_strength is not None: + model.set_lora_strength(args.lora_strength, lora_index=args.lora_index) + + # --- Load audio inputs --- + # torchaudio.load returns (waveform, sample_rate); model.generate expects (sample_rate, waveform) + init_audio = None + if args.init_audio: + waveform, sr = torchaudio.load(args.init_audio) + init_audio = (sr, waveform) + + inpaint_audio = None + if args.inpaint_audio: + waveform, sr = torchaudio.load(args.inpaint_audio) + inpaint_audio = (sr, waveform) + + inpaint_start = None + inpaint_end = None + if args.inpaint_starts: + inpaint_start = ( + args.inpaint_starts[0] + if len(args.inpaint_starts) == 1 + else args.inpaint_starts + ) + inpaint_end = ( + args.inpaint_ends[0] if len(args.inpaint_ends) == 1 else args.inpaint_ends + ) + + # --- Generate --- + print("Generating…") + audio = model.generate( + prompt=prompt, + negative_prompt=negative_prompt, + duration=duration, + steps=args.steps, + cfg_scale=args.cfg_scale, + seed=args.seed, + batch_size=batch_size, + init_audio=init_audio, + init_noise_level=args.init_noise_level, + inpaint_audio=inpaint_audio, + inpaint_mask_start_seconds=inpaint_start, + inpaint_mask_end_seconds=inpaint_end, + chunked_decode=chunked_decode, + ) + + _save_output(audio, model.model.sample_rate, args.output, batch_size) + + +if __name__ == "__main__": + main() diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..d865077 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,462 @@ +""" +Tests for the stable-audio CLI (stable_audio_3/cli.py). + +These are unit tests: the model and audio I/O are mocked so they run without +downloading weights or touching the GPU. Model behaviour is covered separately +in test_inference.py; here we verify that every CLI flag is wired correctly. +""" + +import sys +from unittest.mock import MagicMock, patch + +import pytest +import torch + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +SAMPLE_RATE = 44100 +CHANNELS = 2 +FAKE_AUDIO_PATH = "some/audio.wav" +_FAKE_WAVEFORM = torch.zeros(CHANNELS, SAMPLE_RATE * 5) +_FAKE_LOAD_RESULT = (_FAKE_WAVEFORM, SAMPLE_RATE) + + +def _fake_audio(batch: int = 1, duration: float = 5.0) -> torch.Tensor: + return torch.zeros(batch, CHANNELS, int(SAMPLE_RATE * duration)) + + +def _make_model_mock(batch: int = 1, duration: float = 5.0): + model = MagicMock() + model.model.sample_rate = SAMPLE_RATE + model.generate.return_value = _fake_audio(batch, duration) + return model + + +def _run(argv: list[str]): + """Invoke cli.main() with the given argument list.""" + from stable_audio_3.cli import main + + with patch.object(sys, "argv", ["stable-audio"] + argv): + main() + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def mock_torchaudio_save(): + with patch("stable_audio_3.cli.torchaudio.save") as mock: + yield mock + + +@pytest.fixture(autouse=True) +def mock_torchaudio_load(): + with patch("stable_audio_3.cli.torchaudio.load", return_value=_FAKE_LOAD_RESULT): + yield + + +@pytest.fixture() +def mock_model(): + model = _make_model_mock() + with patch( + "stable_audio_3.cli.StableAudioModel.from_pretrained", return_value=model + ): + yield model + + +# --------------------------------------------------------------------------- +# Model loading +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "model_name", + [ + "medium", + "small-music", + "small-sfx", + "medium-base", + "small-music-base", + "small-sfx-base", + ], +) +def test_model_selection(mock_model, model_name): + _run(["--model", model_name, "-p", "test"]) + from stable_audio_3.cli import StableAudioModel + + StableAudioModel.from_pretrained.assert_called_once_with( + model_name, device=None, model_half=True + ) + + +def test_device_flag(mock_model): + _run(["--device", "cpu", "-p", "test"]) + from stable_audio_3.cli import StableAudioModel + + _, kwargs = StableAudioModel.from_pretrained.call_args + assert kwargs["device"] == "cpu" + + +def test_no_half_flag(mock_model): + _run(["--no-half", "-p", "test"]) + from stable_audio_3.cli import StableAudioModel + + _, kwargs = StableAudioModel.from_pretrained.call_args + assert kwargs["model_half"] is False + + +# --------------------------------------------------------------------------- +# Text-to-audio +# --------------------------------------------------------------------------- + + +def test_text_to_audio_defaults(mock_model): + _run(["-p", "ocean waves"]) + mock_model.generate.assert_called_once() + kwargs = mock_model.generate.call_args.kwargs + assert kwargs["prompt"] == "ocean waves" + assert kwargs["duration"] == 120.0 + assert kwargs["steps"] == 8 + assert kwargs["cfg_scale"] == 1.0 + assert kwargs["seed"] == -1 + assert kwargs["batch_size"] == 1 + assert kwargs["negative_prompt"] is None + assert kwargs["init_audio"] is None + assert kwargs["inpaint_audio"] is None + assert kwargs["chunked_decode"] is None + + +def test_generation_flags(mock_model): + _run( + [ + "-p", + "drums", + "--duration", + "20", + "--steps", + "50", + "--cfg-scale", + "7", + "--seed", + "42", + ] + ) + kwargs = mock_model.generate.call_args.kwargs + assert kwargs["duration"] == 20.0 + assert kwargs["steps"] == 50 + assert kwargs["cfg_scale"] == 7.0 + assert kwargs["seed"] == 42 + + +def test_negative_prompt(mock_model): + _run(["-p", "jazz", "--negative-prompt", "bad quality"]) + kwargs = mock_model.generate.call_args.kwargs + assert kwargs["negative_prompt"] == "bad quality" + + +# --------------------------------------------------------------------------- +# Output file saving +# --------------------------------------------------------------------------- + + +def test_output_single(mock_model, mock_torchaudio_save, tmp_path): + out = str(tmp_path / "out.wav") + _run(["-p", "test", "-o", out]) + assert mock_torchaudio_save.call_count == 1 + saved_path, saved_tensor, saved_sr = mock_torchaudio_save.call_args.args + assert saved_path == out + assert saved_sr == SAMPLE_RATE + assert torch.equal(saved_tensor, mock_model.generate.return_value[0].cpu()) + + +def test_output_batch_naming(mock_torchaudio_save, tmp_path): + model = _make_model_mock(batch=3) + with patch( + "stable_audio_3.cli.StableAudioModel.from_pretrained", return_value=model + ): + out = str(tmp_path / "out.wav") + _run(["-p", "a", "b", "c", "--batch-size", "3", "-o", out]) + + saved_paths = [c.args[0] for c in mock_torchaudio_save.call_args_list] + base = str(tmp_path / "out") + assert saved_paths == [f"{base}_0.wav", f"{base}_1.wav", f"{base}_2.wav"] + + +# --------------------------------------------------------------------------- +# Batch with per-batch prompts and durations +# --------------------------------------------------------------------------- + + +def test_batch_per_batch_prompts_infers_batch_size(): + model = _make_model_mock(batch=3) + with patch( + "stable_audio_3.cli.StableAudioModel.from_pretrained", return_value=model + ): + _run(["-p", "p1", "p2", "p3"]) # no --batch-size; should be auto-inferred as 3 + kwargs = model.generate.call_args.kwargs + assert kwargs["prompt"] == ["p1", "p2", "p3"] + assert kwargs["batch_size"] == 3 + + +def test_batch_explicit_batch_size_matches_prompts(): + model = _make_model_mock(batch=3) + with patch( + "stable_audio_3.cli.StableAudioModel.from_pretrained", return_value=model + ): + _run(["-p", "p1", "p2", "p3", "--batch-size", "3"]) + assert model.generate.call_args.kwargs["batch_size"] == 3 + + +def test_batch_prompt_count_mismatch_fails(): + with pytest.raises(SystemExit): + _run(["-p", "p1", "p2", "p3", "--batch-size", "2"]) + + +def test_batch_per_batch_durations(mock_torchaudio_save): + model = _make_model_mock(batch=2) + with patch( + "stable_audio_3.cli.StableAudioModel.from_pretrained", return_value=model + ): + _run(["-p", "p1", "p2", "--duration", "20", "30"]) + kwargs = model.generate.call_args.kwargs + assert kwargs["duration"] == [20.0, 30.0] + + +def test_batch_duration_count_mismatch_fails(): + with pytest.raises(SystemExit): + _run(["-p", "p1", "p2", "--duration", "20", "30", "40"]) + + +def test_batch_per_batch_negative_prompts(mock_torchaudio_save): + model = _make_model_mock(batch=2) + with patch( + "stable_audio_3.cli.StableAudioModel.from_pretrained", return_value=model + ): + _run(["-p", "p1", "p2", "--negative-prompt", "n1", "n2"]) + kwargs = model.generate.call_args.kwargs + assert kwargs["negative_prompt"] == ["n1", "n2"] + + +def test_batch_negative_prompt_count_mismatch_fails(): + with pytest.raises(SystemExit): + _run(["-p", "p1", "p2", "--negative-prompt", "n1", "n2", "n3"]) + + +# --------------------------------------------------------------------------- +# Audio-to-audio +# --------------------------------------------------------------------------- + + +def test_audio_to_audio(mock_model): + _run( + [ + "-p", + "bossa nova", + "--init-audio", + FAKE_AUDIO_PATH, + "--init-noise-level", + "0.7", + ] + ) + kwargs = mock_model.generate.call_args.kwargs + sr, waveform = kwargs["init_audio"] + assert sr == SAMPLE_RATE + assert torch.equal(waveform, _FAKE_WAVEFORM) + assert kwargs["init_noise_level"] == 0.7 + assert kwargs["inpaint_audio"] is None + + +def test_audio_to_audio_default_noise_level(mock_model): + _run(["-p", "test", "--init-audio", FAKE_AUDIO_PATH]) + assert mock_model.generate.call_args.kwargs["init_noise_level"] == 0.9 + + +# --------------------------------------------------------------------------- +# Inpainting +# --------------------------------------------------------------------------- + + +def test_inpaint_single_region(mock_model): + _run( + [ + "-p", + "kick drum", + "--inpaint-audio", + FAKE_AUDIO_PATH, + "--inpaint-start", + "2.0", + "--inpaint-end", + "5.0", + ] + ) + kwargs = mock_model.generate.call_args.kwargs + sr, waveform = kwargs["inpaint_audio"] + assert sr == SAMPLE_RATE + assert torch.equal(waveform, _FAKE_WAVEFORM) + assert kwargs["init_audio"] is None + assert kwargs["inpaint_mask_start_seconds"] == 2.0 + assert kwargs["inpaint_mask_end_seconds"] == 5.0 + + +def test_inpaint_multiple_regions(mock_model): + _run( + [ + "-p", + "fill", + "--inpaint-audio", + FAKE_AUDIO_PATH, + "--inpaint-start", + "1.0", + "--inpaint-start", + "8.0", + "--inpaint-end", + "4.0", + "--inpaint-end", + "12.0", + ] + ) + kwargs = mock_model.generate.call_args.kwargs + assert kwargs["inpaint_mask_start_seconds"] == [1.0, 8.0] + assert kwargs["inpaint_mask_end_seconds"] == [4.0, 12.0] + + +def test_inpaint_continuation(mock_model): + """Continuation: inpaint_start == length of source audio, duration > source length.""" + _run( + [ + "-p", + "continue", + "--inpaint-audio", + FAKE_AUDIO_PATH, + "--inpaint-start", + "5.0", + "--inpaint-end", + "15.0", + "--duration", + "15", + ] + ) + kwargs = mock_model.generate.call_args.kwargs + assert kwargs["inpaint_mask_start_seconds"] == 5.0 + assert kwargs["inpaint_mask_end_seconds"] == 15.0 + assert kwargs["duration"] == 15.0 + + +def test_inpaint_region_without_audio_fails(): + with pytest.raises(SystemExit): + _run(["-p", "test", "--inpaint-start", "2.0", "--inpaint-end", "5.0"]) + + +def test_inpaint_audio_without_region_fails(): + with pytest.raises(SystemExit): + _run(["-p", "test", "--inpaint-audio", FAKE_AUDIO_PATH]) + + +def test_inpaint_start_without_end_fails(): + with pytest.raises(SystemExit): + _run( + ["-p", "test", "--inpaint-audio", FAKE_AUDIO_PATH, "--inpaint-start", "2.0"] + ) + + +def test_inpaint_mismatched_region_count_fails(): + with pytest.raises(SystemExit): + _run( + [ + "-p", + "test", + "--inpaint-audio", + FAKE_AUDIO_PATH, + "--inpaint-start", + "1.0", + "--inpaint-start", + "5.0", + "--inpaint-end", + "3.0", + ] + ) + + +# --------------------------------------------------------------------------- +# Chunked decode +# --------------------------------------------------------------------------- + + +def test_chunked_decode_on(mock_model): + _run(["-p", "test", "--chunked-decode"]) + assert mock_model.generate.call_args.kwargs["chunked_decode"] is True + + +def test_chunked_decode_off(mock_model): + _run(["-p", "test", "--no-chunked-decode"]) + assert mock_model.generate.call_args.kwargs["chunked_decode"] is False + + +def test_chunked_decode_default(mock_model): + _run(["-p", "test"]) + assert mock_model.generate.call_args.kwargs["chunked_decode"] is None + + +def test_chunked_decode_flags_mutually_exclusive(): + with pytest.raises(SystemExit): + _run(["-p", "test", "--chunked-decode", "--no-chunked-decode"]) + + +# --------------------------------------------------------------------------- +# LoRA +# --------------------------------------------------------------------------- + + +def test_lora_single(mock_model): + _run(["-p", "test", "--lora-ckpt-path", "lora.safetensors"]) + mock_model.load_lora.assert_called_once_with(["lora.safetensors"]) + + +def test_lora_stacked(mock_model): + _run( + [ + "-p", + "test", + "--lora-ckpt-path", + "a.safetensors", + "--lora-ckpt-path", + "b.safetensors", + ] + ) + mock_model.load_lora.assert_called_once_with(["a.safetensors", "b.safetensors"]) + + +def test_lora_strength(mock_model): + _run( + ["-p", "test", "--lora-ckpt-path", "lora.safetensors", "--lora-strength", "0.5"] + ) + mock_model.set_lora_strength.assert_called_once_with(0.5, lora_index=None) + + +def test_lora_strength_with_index(mock_model): + _run( + [ + "-p", + "test", + "--lora-ckpt-path", + "a.safetensors", + "--lora-ckpt-path", + "b.safetensors", + "--lora-strength", + "0.3", + "--lora-index", + "1", + ] + ) + mock_model.set_lora_strength.assert_called_once_with(0.3, lora_index=1) + + +def test_no_lora_no_load(mock_model): + _run(["-p", "test"]) + mock_model.load_lora.assert_not_called() + mock_model.set_lora_strength.assert_not_called()