pyproject.toml (3 changes: 1 addition & 2 deletions)

@@ -27,7 +27,6 @@ index-out-of-bounds = "ignore" # mypy is more permissive with tuple indexing
unresolved-attribute = "ignore" # mypy is more permissive with module attributes
redundant-cast = "ignore" # mypy doesn't warn about redundant casts
unsupported-operator = "ignore" # mypy supports | syntax with from __future__ import annotations
invalid-argument-type = "ignore" # mypy is more permissive with argument types
invalid-return-type = "ignore" # mypy is more permissive with return types
invalid-parameter-default = "ignore" # mypy is more permissive with parameter defaults
no-matching-overload = "ignore" # mypy is more permissive with overloads
@@ -197,7 +196,7 @@ dev = [
"pre-commit",
"twine",
"pyc-wheel",
"ruff",
"ruff>=0.15.3", # introduction of D420 rule
"numpydoc>=1.9.0",
"numpydoc-validation",
"pytest",
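The net effect of this hunk pair: the blanket `invalid-argument-type = "ignore"` entry disappears, so the rule reports at its default severity and individual call sites suppress it instead (as the c_translate.py hunks below do), and ruff is pinned to at least 0.15.3 for the D420 rule. A minimal sketch of such a call-site suppression, with a made-up function and values, assuming a checker that accepts rule names in `type: ignore` brackets:

```python
from transformers.tokenization_utils_base import BatchEncoding

def input_ids_of(batch: dict) -> list:
    # annotated to accept a plain dict, so passing a BatchEncoding can trip
    # the invalid-argument-type rule once the global ignore is gone
    return batch["input_ids"]

encoding = BatchEncoding({"input_ids": [[101, 2023, 102]]})
ids = input_ids_of(encoding)  # type: ignore[invalid-argument-type]
```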
src/pruna/algorithms/c_translate.py (6 changes: 3 additions & 3 deletions)

@@ -98,7 +98,7 @@ def get_hyperparameters(self) -> list:
"weight_bits",
sequence=[8, 16],
default_value=16,
meta=dict(desc="Sets the number of bits to use for weight quantization."),
meta={"desc": "Sets the number of bits to use for weight quantization."},
),
]

@@ -392,7 +392,7 @@ def __call__(
The generated sequence.
"""
if type(x) is dict or isinstance(x, transformers.tokenization_utils_base.BatchEncoding):
x_tensor = x["input_ids"]
x_tensor = x["input_ids"] # type: ignore[invalid-argument-type]
else:
x_tensor = x
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] # type: ignore[not-subscriptable]
@@ -468,7 +468,7 @@ def __call__(
if "max_length" in kwargs:
max_decoding_length = kwargs["max_length"]
if type(x) is dict or isinstance(x, transformers.tokenization_utils_base.BatchEncoding):
x_tensor = x["input_ids"]
x_tensor = x["input_ids"] # type: ignore[invalid-argument-type]
else:
x_tensor = x
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] # type: ignore[not-subscriptable]
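Both `__call__` hunks silence the same complaint on `x["input_ids"]`. As a hedged aside, narrowing the union with `isinstance` is one way such a suppression can often be avoided; the helper below is illustrative only and does not mirror the actual `__call__` signature:

```python
from typing import Union

import torch
from transformers.tokenization_utils_base import BatchEncoding

def extract_input_ids(x: Union[dict, BatchEncoding, torch.Tensor]) -> torch.Tensor:
    if isinstance(x, (dict, BatchEncoding)):
        # the checker can narrow x to a mapping here, so no ignore comment is needed
        return x["input_ids"]
    return x
```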
src/pruna/algorithms/deepcache.py (6 changes: 3 additions & 3 deletions)

@@ -80,9 +80,9 @@ def get_hyperparameters(self) -> list:
"interval",
sequence=[1, 2, 3, 4, 5],
default_value=2,
meta=dict(
desc="Interval at which to cache - 1 disables caching. Higher is faster but might affect quality."
),
meta={
"desc": "Interval at which to cache - 1 disables caching. Higher is faster but might affect quality."
},
),
]

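The only change in this file is the `meta=dict(...)` to dict-literal conversion that recurs throughout the PR; shown once in isolation (which ruff rule motivates it is not stated in the diff):

```python
# the two spellings build the same mapping; the literal avoids a dict() call
meta_old = dict(desc="Interval at which to cache - 1 disables caching.")
meta_new = {"desc": "Interval at which to cache - 1 disables caching."}
assert meta_old == meta_new
```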
src/pruna/algorithms/denoise.py (2 changes: 1 addition & 1 deletion)

@@ -97,7 +97,7 @@ def get_hyperparameters(self) -> list:
upper=1.0,
default_value=0.02,
log=False,
meta=dict(desc="Strength of the denoising/refinement. Lower values mean less change/more refinement."),
meta={"desc": "Strength of the denoising/refinement. Lower values mean less change/more refinement."},
),
]

src/pruna/algorithms/fastercache.py (6 changes: 3 additions & 3 deletions)

@@ -74,10 +74,10 @@ def get_hyperparameters(self) -> list:
"interval",
sequence=[1, 2, 3, 4, 5],
default_value=2,
meta=dict(
desc="Interval at which to cache spatial attention blocks - 1 disables caching."
meta={
"desc": "Interval at which to cache spatial attention blocks - 1 disables caching."
"Higher is faster but might degrade quality."
),
},
),
]

src/pruna/algorithms/flash_attn3.py (8 changes: 6 additions & 2 deletions)

@@ -222,7 +222,11 @@ def _flash_attention_3(
)
else:
out, _, *_ = torch.ops.flash_attn_pruna._flash_attn_forward(
q=query, k=key, v=value, softmax_scale=scale, causal=is_causal
q=query, # type: ignore
k=key, # type: ignore
v=value, # type: ignore
softmax_scale=scale, # type: ignore
causal=is_causal, # type: ignore
)
return out

@@ -286,7 +290,7 @@ def __torch_function__(self, func, types, args=(), kwargs=None): # noqa: D105
def _flash_attention3(query, key, value, *, is_causal=False, softmax_scale=None, kernel=None):
# convert (B, H, S, D) → (B, S, H, D)
q, k, v = [x.transpose(1, 2).contiguous() for x in (query, key, value)]
out, _ = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale)
out, _ = torch.ops.flash_attn_pruna._flash_attn_forward(q, k, v, causal=is_causal, softmax_scale=softmax_scale) # type: ignore
# back to (B, H, S, D) for the rest of the pipeline
return out.transpose(1, 2)

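A quick, self-contained check of the layout round trip commented in `_flash_attention3` above; `torch.ops.flash_attn_pruna._flash_attn_forward` is a project-registered custom op, so it is not invoked here:

```python
import torch

B, H, S, D = 2, 8, 16, 64
query = torch.randn(B, H, S, D)          # layout used by the rest of the pipeline

q = query.transpose(1, 2).contiguous()   # (B, S, H, D), layout the kernel expects
restored = q.transpose(1, 2)             # back to (B, H, S, D) afterwards

assert restored.shape == (B, H, S, D)
assert torch.equal(restored, query)
```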
src/pruna/algorithms/fora.py (7 changes: 4 additions & 3 deletions)

@@ -70,19 +70,20 @@ def get_hyperparameters(self) -> list:
"interval",
sequence=range(1, 6),
default_value=2,
meta=dict(desc="Interval at which the outputs are computed. Higher is faster, but reduces quality."),
meta={"desc": "Interval at which the outputs are computed. Higher is faster, but reduces quality."},
),
OrdinalHyperparameter(
"start_step",
sequence=range(11),
default_value=2,
meta=dict(desc="How many steps to wait before starting to cache."),
meta={"desc": "How many steps to wait before starting to cache."},
),
OrdinalHyperparameter(
"backbone_calls_per_step",
sequence=range(1, 4),
default_value=1,
meta=dict(desc="Number of backbone forward passes per diffusion step (e.g., 2 for CFG)."),
meta={"desc": "Number of backbone forward passes per diffusion step (e.g., 2 for CFG)."}

),
]

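A hedged sketch of how the three hyperparameters above could interact in a step-caching scheme; this is an illustration built from the parameter descriptions, not FORA's implementation:

```python
def use_cached_output(call_index: int, interval: int, start_step: int,
                      backbone_calls_per_step: int) -> bool:
    """Return True when a cached backbone output would be reused for this call."""
    step = call_index // backbone_calls_per_step   # e.g. 2 calls per step under CFG
    if step < start_step or interval <= 1:
        return False                               # still warming up, or caching disabled
    return (step - start_step) % interval != 0     # recompute every `interval`-th step
```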
src/pruna/algorithms/global_utils/recovery/adapters/lora.py (16 changes: 8 additions & 8 deletions)

@@ -73,30 +73,30 @@ def get_hyperparameters(cls, task_name: str, **override_defaults: Any) -> list:
"r",
sequence=[4, 8, 16, 32, 64, 128],
default_value=default_hyperparameters["r"],
meta=dict(desc="Rank of the LoRA layers."),
meta={"desc": "Rank of the LoRA layers."},
),
OrdinalHyperparameter(
"alpha_r_ratio",
sequence=[0.5, 1.0, 2.0],
default_value=default_hyperparameters["alpha_r_ratio"],
meta=dict(desc="Alpha/Rank ratio of the LoRA layers."),
meta={"desc": "Alpha/Rank ratio of the LoRA layers."},
),
CategoricalHyperparameter(
"target_modules",
choices=[None, "all-linear"],
default_value=default_hyperparameters["target_modules"],
meta=dict(desc="Target modules for the LoRA layers."),
meta={"desc": "Target modules for the LoRA layers."},
),
Constant(
"dropout",
default_hyperparameters["dropout"],
meta=dict(desc="Dropout rate of the LoRA layers during training."),
meta={"desc": "Dropout rate of the LoRA layers during training."},
),
CategoricalHyperparameter(
"variant",
choices=["lora", "pissa"],
default_value=default_hyperparameters["variant"],
meta=dict(desc="Variant of the LoRA adapter."),
meta={"desc": "Variant of the LoRA adapter."},
),
]

@@ -116,13 +116,13 @@ def get_hyperparameters(cls, task_name: str, **override_defaults: Any) -> list:
"r",
sequence=[4, 8, 16, 32, 64, 128],
default_value=default_hyperparameters["r"],
meta=dict(desc="Rank of the LoRA layers."),
meta={"desc": "Rank of the LoRA layers."},
),
OrdinalHyperparameter(
"alpha_r_ratio",
sequence=[0.5, 1.0, 2.0],
default_value=default_hyperparameters["alpha_r_ratio"],
meta=dict(desc="Alpha/Rank ratio of the LoRA layers."),
meta={"desc": "Alpha/Rank ratio of the LoRA layers."},
),
Constant(
"target_modules", default_hyperparameters["target_modules"]
@@ -132,7 +132,7 @@
"variant",
choices=["lora", "pissa"],
default_value=default_hyperparameters["variant"],
meta=dict(desc="Variant of the LoRA adapter."),
meta={"desc": "Variant of the LoRA adapter."},
),
]
else:
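For orientation, a hedged illustration of how hyperparameters like these typically map onto a PEFT `LoraConfig`; the actual wiring inside pruna is not part of this diff, and deriving alpha as `alpha_r_ratio * r` is an assumption based on the parameter name:

```python
from peft import LoraConfig

r = 16
alpha_r_ratio = 2.0
config = LoraConfig(
    r=r,
    lora_alpha=int(alpha_r_ratio * r),  # alpha expressed relative to the rank
    lora_dropout=0.05,                  # the "dropout" constant above
    target_modules="all-linear",        # one of the "target_modules" choices
)
```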
(next changed file; path not shown in this capture)

@@ -95,67 +95,67 @@ def get_hyperparameters(cls, **override_defaults) -> List:
lower=0,
upper=4096,
default_value=numeric_defaults["training_batch_size"],
meta=dict(desc="Number of steps from each diffusion process to use for distillation."),
meta={"desc": "Number of steps from each diffusion process to use for distillation."},
),
UniformIntegerHyperparameter(
"gradient_accumulation_steps",
lower=1,
upper=1024,
default_value=numeric_defaults["gradient_accumulation_steps"],
meta=dict(desc="Number of captions processed to estimate each gradient step."),
meta={"desc": "Number of captions processed to estimate each gradient step."},
),
UniformIntegerHyperparameter(
"num_epochs",
lower=0,
upper=4096,
default_value=numeric_defaults["num_epochs"],
meta=dict(desc="Number of epochs for distillation."),
meta={"desc": "Number of epochs for distillation."},
),
UniformFloatHyperparameter(
"validate_every_n_epoch",
lower=0.0,
upper=4096.0,
default_value=numeric_defaults["validate_every_n_epoch"],
meta=dict(
desc="Number of epochs between each round of validation and model checkpointing. "
meta={
"desc": "Number of epochs between each round of validation and model checkpointing. "
"If the value is between 0 and 1, validation will be performed multiple times per epoch, "
"e.g. 1/8 will result in 8 validations per epoch."
),
},
),
UniformFloatHyperparameter(
"learning_rate",
lower=0.0,
upper=1.0,
default_value=numeric_defaults["learning_rate"],
meta=dict(desc="Learning rate for distillation."),
meta={"desc": "Learning rate for distillation."},
),
Constant("weight_decay", numeric_defaults["weight_decay"]),
# report_to: for consistency with text-to-text-lora but wandb and tensorboard are not supported yet
Constant("report_to", string_defaults["report_to"]),
Boolean(
"use_cpu_offloading",
default=False,
meta=dict(desc="Whether to use CPU offloading for distillation."),
meta={"desc": "Whether to use CPU offloading for distillation."},
),
CategoricalHyperparameter(
"optimizer",
choices=["AdamW8bit", "AdamW", "Adam"],
default_value=string_defaults["optimizer"],
meta=dict(desc="Which optimizer to use for distillation."),
meta={"desc": "Which optimizer to use for distillation."},
),
UniformFloatHyperparameter(
"lr_decay",
lower=0.0,
upper=1.0,
default_value=numeric_defaults["lr_decay"],
meta=dict(desc="Learning rate decay, applied at each epoch."),
meta={"desc": "Learning rate decay, applied at each epoch."},
),
UniformIntegerHyperparameter(
"warmup_steps",
lower=0,
upper=2**14,
default_value=numeric_defaults["warmup_steps"],
meta=dict(desc="Number of warmup steps for the learning rate scheduler."),
meta={"desc": "Number of warmup steps for the learning rate scheduler."},
),
]

@@ -405,7 +405,7 @@ def distillation_forward(*args, **kwargs):
output["sample"] if ("return_dict" in kwargs and kwargs["return_dict"]) else output[0]
)
loss = self.loss(latent_output, latent_targets[self.num_previous_steps])
if is_training:
if is_training and active_steps is not None:
accumulation_normalized_loss = loss / (len(active_steps) * self.gradient_accumulation_steps)
self.manual_backward(accumulation_normalized_loss)
diffusion_step_losses.append(loss)
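The guard added in the last hunk matters because `active_steps` can apparently be `None`; the normalization itself divides each loss by the number of contributions so that accumulated gradients match the gradient of the mean loss. A small self-contained check of that identity (arbitrary values, not the distiller's code):

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
coeffs = (0.5, 1.5, 2.0, 4.0)
n = len(coeffs)

for c in coeffs:
    ((w * c) / n).backward()                  # accumulate normalized pieces
accumulated = w.grad.clone()

w.grad = None
(sum(w * c for c in coeffs) / n).backward()   # mean loss in a single backward
assert torch.allclose(accumulated, w.grad)
```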
(next changed file; path not shown in this capture)

@@ -86,53 +86,53 @@ def get_hyperparameters(cls, **override_defaults) -> List:
lower=0,
upper=4096,
default_value=numeric_defaults["training_batch_size"],
meta=dict(desc="Batch size for finetuning."),
meta={"desc": "Batch size for finetuning."},
),
UniformIntegerHyperparameter(
"gradient_accumulation_steps",
lower=1,
upper=1024,
default_value=numeric_defaults["gradient_accumulation_steps"],
meta=dict(desc="Number of gradient accumulation steps for finetuning."),
meta={"desc": "Number of gradient accumulation steps for finetuning."},
),
UniformIntegerHyperparameter(
"num_epochs",
lower=0,
upper=4096,
default_value=numeric_defaults["num_epochs"],
meta=dict(desc="Number of epochs for finetuning."),
meta={"desc": "Number of epochs for finetuning."},
),
UniformFloatHyperparameter(
"validate_every_n_epoch",
lower=0.0,
upper=4096.0,
default_value=numeric_defaults["validate_every_n_epoch"],
meta=dict(
desc="Number of epochs between each round of validation and model checkpointing. "
meta={
"desc": "Number of epochs between each round of validation and model checkpointing. "
"If the value is between 0 and 1, validation will be performed multiple times per epoch, "
"e.g. 1/8 will result in 8 validations per epoch."
),
},
),
UniformFloatHyperparameter(
"learning_rate",
lower=0.0,
upper=1.0,
default_value=numeric_defaults["learning_rate"],
meta=dict(desc="Learning rate for finetuning."),
meta={"desc": "Learning rate for finetuning."},
),
Constant("weight_decay", numeric_defaults["weight_decay"]),
# report_to: for consistency with text-to-text-lora but wandb and tensorboard are not supported yet
Constant("report_to", string_defaults["report_to"]),
Boolean(
"use_cpu_offloading",
default=False,
meta=dict(desc="Whether to use CPU offloading for finetuning."),
meta={"desc": "Whether to use CPU offloading for finetuning."},
), # necessary for Flux in float16 on L40S GPU (48gb VRAM)
CategoricalHyperparameter(
"optimizer",
choices=["AdamW8bit", "AdamW", "Adam"],
default_value=string_defaults["optimizer"],
meta=dict(desc="Which optimizer to use for finetuning."),
meta={"desc": "Which optimizer to use for finetuning."},
),
]

@@ -540,7 +540,6 @@ def configure_optimizers(self) -> torch.optim.Optimizer:
"""
lr = self.learning_rate
wd = self.weight_decay
kwargs = {"eps": 1e-7} if self.trainer.precision in [16, "16", "16-true"] else {}

if self.optimizer_name == "AdamW8bit":
optimizer_cls = AdamW8bit
@@ -553,4 +552,7 @@
optimizer_cls = getattr(torch.optim, self.optimizer_name)
finetune_params = get_trainable_parameters(self.pipeline)

return optimizer_cls(finetune_params, lr=lr, weight_decay=wd, **kwargs)
if self.trainer.precision in [16, "16", "16-true"]:
return optimizer_cls(finetune_params, lr=lr, weight_decay=wd, eps=1e-7)
else:
return optimizer_cls(finetune_params, lr=lr, weight_decay=wd)
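The removed `kwargs` dict is folded into two explicit returns, presumably so the checker sees concrete call signatures. The larger eps for 16-bit runs is plausible on numerical grounds, though the diff does not say so: Adam's default eps of 1e-8 sits below float16's smallest subnormal and effectively vanishes if optimizer state is kept in half precision.

```python
import torch

# 1e-8 underflows in float16, while 1e-7 survives (rounded to ~1.19e-7)
print(torch.tensor(1e-8, dtype=torch.float16))  # tensor(0., dtype=torch.float16)
print(torch.tensor(1e-7, dtype=torch.float16))  # tensor(1.1921e-07, dtype=torch.float16)
```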