diff --git a/docs/workflows/inference.md b/docs/workflows/inference.md index 6f01be5..4988f28 100644 --- a/docs/workflows/inference.md +++ b/docs/workflows/inference.md @@ -218,10 +218,6 @@ model.set_lora_strength(0.0) # Disable without unloading # With multiple LoRAs, target by index: model.set_lora_strength(1.0, lora_index=0) model.set_lora_strength(0.3, lora_index=1) - -# Target only the Diffusion Transformer backbone or conditioner independently: -model.set_lora_strength(1.0, target="dit") -model.set_lora_strength(0.0, target="conditioner") ``` For full details on LoRA training see [LoRA Training](lora.md). diff --git a/pyproject.toml b/pyproject.toml index 8b2c35a..6a48191 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,4 +66,5 @@ exclude = [ "stable_audio_3/interface", "stable_audio_3/data", "stable_audio_3/training", + "optimized", ] diff --git a/scripts/optimized/.gitkeep b/scripts/optimized/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/stable_audio_3/model.py b/stable_audio_3/model.py index d4bd8a9..ec44811 100644 --- a/stable_audio_3/model.py +++ b/stable_audio_3/model.py @@ -28,7 +28,7 @@ def __init__(self, model, model_config, device, model_half): torch.backends.cudnn.benchmark = False @staticmethod - def from_pretrained(model_name_or_path, device=None, model_half=True): + def from_pretrained(model_name, device=None, model_half=True): # Load the model and any necessary components here if device is None and torch.cuda.is_available(): device = "cuda" @@ -38,18 +38,18 @@ def from_pretrained(model_name_or_path, device=None, model_half=True): device = "cpu" if not torch.cuda.is_available(): - if model_name_or_path in ("medium", "medium-base"): + if model_name in ("medium", "medium-base"): print( - f"Warning: You are loading the {model_name_or_path} model without a GPU. This model is not designed to run on cpu" + f"Warning: You are loading the {model_name} model without a GPU. This model is not designed to run on cpu" ) model_half = False - if model_name_or_path not in all_models: + if model_name not in all_models: raise ValueError( - f"Unknown model '{model_name_or_path}'. Valid models: {list(all_models)}" + f"Unknown model '{model_name}'. Valid models: {list(all_models)}" ) - model_cfg = all_models[model_name_or_path] + model_cfg = all_models[model_name] local_config, local_ckpt = model_cfg.resolve() with open(local_config) as f: model_config = json.load(f) diff --git a/stable_audio_3/models/dit.py b/stable_audio_3/models/dit.py index c9dd05e..5d4d774 100644 --- a/stable_audio_3/models/dit.py +++ b/stable_audio_3/models/dit.py @@ -319,11 +319,11 @@ def apg_project(self, v0, v1, padding_mask=None): If provided, only valid positions contribute to the projection. """ dtype = v0.dtype - v0, v1 = v0.double(), v1.double() + v0, v1 = v0.float(), v1.float() if padding_mask is not None: # Expand mask to match tensor shape: (B, T) -> (B, 1, T) - mask = padding_mask.unsqueeze(1).double() + mask = padding_mask.unsqueeze(1).float() # Zero out padding positions for projection computation v0_masked = v0 * mask v1_masked = v1 * mask