2 changes: 2 additions & 0 deletions invokeai/app/api/dependencies.py
@@ -46,6 +46,7 @@
from invokeai.app.services.workflow_records.workflow_records_sqlite import SqliteWorkflowRecordsStorage
from invokeai.app.services.workflow_thumbnails.workflow_thumbnails_disk import WorkflowThumbnailFileStorageDisk
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
AnimaConditioningInfo,
BasicConditioningInfo,
CogView4ConditioningInfo,
ConditioningFieldData,
@@ -140,6 +141,7 @@ def initialize(
SD3ConditioningInfo,
CogView4ConditioningInfo,
ZImageConditioningInfo,
AnimaConditioningInfo,
],
ephemeral=True,
),
710 changes: 710 additions & 0 deletions invokeai/app/invocations/anima_denoise.py

Large diffs are not rendered by default.

119 changes: 119 additions & 0 deletions invokeai/app/invocations/anima_image_to_latents.py
@@ -0,0 +1,119 @@
"""Anima image-to-latents invocation.

Encodes an image to latent space using the Anima VAE (AutoencoderKLWan or FLUX VAE).

For Wan VAE (AutoencoderKLWan):
- Input image is converted to 5D tensor [B, C, T, H, W] with T=1
- After encoding, latents are normalized: (latents - mean) / std
(inverse of the denormalization in anima_latents_to_image.py)

For FLUX VAE (AutoEncoder):
- Encoding is handled internally by the FLUX VAE
"""

from typing import Union

import einops
import torch
from diffusers.models.autoencoders import AutoencoderKLWan

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
ImageField,
Input,
InputField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux

AnimaVAE = Union[AutoencoderKLWan, FluxAutoEncoder]


@invocation(
"anima_i2l",
title="Image to Latents - Anima",
tags=["image", "latents", "vae", "i2l", "anima"],
category="image",
version="1.0.0",
classification=Classification.Prototype,
)
class AnimaImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates latents from an image using the Anima VAE (supports Wan 2.1 and FLUX VAE)."""

image: ImageField = InputField(description="The image to encode.")
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

@staticmethod
def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
)

estimated_working_memory = estimate_vae_working_memory_flux(
operation="encode",
image_tensor=image_tensor,
vae=vae_info.model,
)

with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)):
raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.")

vae_dtype = next(iter(vae.parameters())).dtype
image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

with torch.inference_mode():
if isinstance(vae, FluxAutoEncoder):
# FLUX VAE handles scaling internally
generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
latents = vae.encode(image_tensor, sample=True, generator=generator)
else:
# AutoencoderKLWan expects 5D input [B, C, T, H, W]
if image_tensor.ndim == 4:
image_tensor = image_tensor.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]

encoded = vae.encode(image_tensor, return_dict=False)[0]
latents = encoded.sample().to(dtype=vae_dtype)

# Normalize to denoiser space: (latents - mean) / std
# This is the inverse of the denormalization in anima_latents_to_image.py
latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
latents = (latents - latents_mean) / latents_std

# Remove temporal dimension: [B, C, 1, H, W] -> [B, C, H, W]
if latents.ndim == 5:
latents = latents.squeeze(2)

return latents

@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
image = context.images.get_pil(self.image.image_name)

image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
if image_tensor.dim() == 3:
image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")

vae_info = context.models.load(self.vae.vae)
if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
)

context.util.signal_progress("Running Anima VAE encode")
latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)

latents = latents.to("cpu")
name = context.tensors.save(tensor=latents)
return LatentsOutput.build(latents_name=name, latents=latents, seed=None)
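The normalization here and the denormalization in anima_latents_to_image.py (next file) are exact inverses, so a latent that round-trips through both comes back unchanged. A minimal sketch of that round-trip, using made-up per-channel stats in place of the real vae.config.latents_mean / vae.config.latents_std values from the Wan 2.1 VAE config:

import torch

# Hypothetical per-channel stats; the real values come from the Wan 2.1
# VAE config (vae.config.latents_mean / vae.config.latents_std).
latents_mean = torch.tensor([0.1, -0.2, 0.05, 0.3])
latents_std = torch.tensor([1.1, 0.9, 1.2, 0.8])

mean = latents_mean.view(1, -1, 1, 1, 1)  # broadcast over [B, C, T, H, W]
std = latents_std.view(1, -1, 1, 1, 1)

raw = torch.randn(1, 4, 1, 8, 8)  # raw VAE-space latents, T=1 for a single image

normalized = (raw - mean) / std      # encode path (this file)
recovered = normalized * std + mean  # decode path (anima_latents_to_image.py)

assert torch.allclose(raw, recovered, atol=1e-6)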
108 changes: 108 additions & 0 deletions invokeai/app/invocations/anima_latents_to_image.py
@@ -0,0 +1,108 @@
"""Anima latents-to-image invocation.

Decodes Anima latents using the QwenImage VAE (AutoencoderKLWan) or a
compatible FLUX VAE as a fallback.

Latents from the denoiser are in normalized space (zero-centered). Before
VAE decode, they must be denormalized using the Wan 2.1 per-channel
mean/std: latents = latents * std + mean (matching diffusers WanPipeline).

The VAE expects 5D latents [B, C, T, H, W] — for single images, T=1.
"""

import torch
from diffusers.models.autoencoders import AutoencoderKLWan
from einops import rearrange
from PIL import Image

from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
from invokeai.app.invocations.fields import (
FieldDescriptions,
Input,
InputField,
LatentsField,
WithBoard,
WithMetadata,
)
from invokeai.app.invocations.model import VAEField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.flux.modules.autoencoder import AutoEncoder as FluxAutoEncoder
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.vae_working_memory import estimate_vae_working_memory_flux


@invocation(
"anima_l2i",
title="Latents to Image - Anima",
tags=["latents", "image", "vae", "l2i", "anima"],
category="latents",
version="1.0.2",
classification=Classification.Prototype,
)
class AnimaLatentsToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Generates an image from latents using the Anima VAE.

Supports the Wan 2.1 QwenImage VAE (AutoencoderKLWan) with explicit
latent denormalization, and FLUX VAE as fallback.
"""

latents: LatentsField = InputField(description=FieldDescriptions.latents, input=Input.Connection)
vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)

@torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput:
latents = context.tensors.load(self.latents.latents_name)

vae_info = context.models.load(self.vae.vae)
if not isinstance(vae_info.model, (AutoencoderKLWan, FluxAutoEncoder)):
raise TypeError(
f"Expected AutoencoderKLWan or FluxAutoEncoder for Anima VAE, got {type(vae_info.model).__name__}."
)

estimated_working_memory = estimate_vae_working_memory_flux(
operation="decode",
image_tensor=latents,
vae=vae_info.model,
)

with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
context.util.signal_progress("Running Anima VAE decode")
if not isinstance(vae, (AutoencoderKLWan, FluxAutoEncoder)):
raise TypeError(f"Expected AutoencoderKLWan or FluxAutoEncoder, got {type(vae).__name__}.")

vae_dtype = next(iter(vae.parameters())).dtype
latents = latents.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)

TorchDevice.empty_cache()

with torch.inference_mode():
if isinstance(vae, FluxAutoEncoder):
# FLUX VAE handles scaling internally, expects 4D [B, C, H, W]
img = vae.decode(latents)
else:
# Expects 5D latents [B, C, T, H, W]
if latents.ndim == 4:
latents = latents.unsqueeze(2) # [B, C, H, W] -> [B, C, 1, H, W]

# Denormalize from denoiser space to raw VAE space
# (same as diffusers WanPipeline and ComfyUI Wan21.process_out)
latents_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
latents_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
latents = latents * latents_std + latents_mean

decoded = vae.decode(latents, return_dict=False)[0]

# Output is 5D [B, C, T, H, W] — squeeze temporal dim
if decoded.ndim == 5:
decoded = decoded.squeeze(2)
img = decoded

img = img.clamp(-1, 1)
img = rearrange(img[0], "c h w -> h w c")
img_pil = Image.fromarray((127.5 * (img + 1.0)).byte().cpu().numpy())

TorchDevice.empty_cache()

image_dto = context.images.save(image=img_pil)
return ImageOutput.build(image_dto)
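The closing conversion above maps the decoded sample from [-1, 1] to 8-bit pixels via 127.5 * (x + 1). A self-contained sketch of just that step, with a random stand-in tensor in place of the real decoder output:

import torch
from einops import rearrange
from PIL import Image

# Stand-in decoded sample in [-1, 1], shaped [C, H, W] like img[0] above.
img = torch.rand(3, 64, 64) * 2.0 - 1.0

img = img.clamp(-1, 1)
img = rearrange(img, "c h w -> h w c")
# Map [-1, 1] -> [0, 255]: -1 becomes 0, +1 becomes 255.
arr = (127.5 * (img + 1.0)).byte().cpu().numpy()
img_pil = Image.fromarray(arr)
assert img_pil.size == (64, 64) and img_pil.mode == "RGB"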
151 changes: 151 additions & 0 deletions invokeai/app/invocations/anima_lora_loader.py
@@ -0,0 +1,151 @@
from typing import Optional

from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField
from invokeai.app.invocations.model import LoRAField, ModelIdentifierField, Qwen3EncoderField, TransformerField
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelType


@invocation_output("anima_lora_loader_output")
class AnimaLoRALoaderOutput(BaseInvocationOutput):
"""Anima LoRA Loader Output"""

transformer: Optional[TransformerField] = OutputField(
default=None, description=FieldDescriptions.transformer, title="Anima Transformer"
)
qwen3_encoder: Optional[Qwen3EncoderField] = OutputField(
default=None, description=FieldDescriptions.qwen3_encoder, title="Qwen3 Encoder"
)


@invocation(
"anima_lora_loader",
title="Apply LoRA - Anima",
tags=["lora", "model", "anima"],
category="model",
version="1.0.0",
)
class AnimaLoRALoaderInvocation(BaseInvocation):
"""Apply a LoRA model to an Anima transformer and/or Qwen3 text encoder."""

lora: ModelIdentifierField = InputField(
description=FieldDescriptions.lora_model,
title="LoRA",
ui_model_base=BaseModelType.Anima,
ui_model_type=ModelType.LoRA,
)
weight: float = InputField(default=0.75, description=FieldDescriptions.lora_weight)
transformer: TransformerField | None = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Anima Transformer",
)
qwen3_encoder: Qwen3EncoderField | None = InputField(
default=None,
title="Qwen3 Encoder",
description=FieldDescriptions.qwen3_encoder,
input=Input.Connection,
)

def invoke(self, context: InvocationContext) -> AnimaLoRALoaderOutput:
lora_key = self.lora.key

if not context.models.exists(lora_key):
raise ValueError(f"Unknown lora: {lora_key}!")

if self.transformer and any(lora.lora.key == lora_key for lora in self.transformer.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to transformer.')
if self.qwen3_encoder and any(lora.lora.key == lora_key for lora in self.qwen3_encoder.loras):
raise ValueError(f'LoRA "{lora_key}" already applied to Qwen3 encoder.')

output = AnimaLoRALoaderOutput()

if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)
output.transformer.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)
if self.qwen3_encoder is not None:
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)
output.qwen3_encoder.loras.append(
LoRAField(
lora=self.lora,
weight=self.weight,
)
)

return output


@invocation(
"anima_lora_collection_loader",
title="Apply LoRA Collection - Anima",
tags=["lora", "model", "anima"],
category="model",
version="1.0.0",
)
class AnimaLoRACollectionLoader(BaseInvocation):
"""Applies a collection of LoRAs to an Anima transformer."""

loras: Optional[LoRAField | list[LoRAField]] = InputField(
default=None, description="LoRA models and weights. May be a single LoRA or collection.", title="LoRAs"
)

transformer: Optional[TransformerField] = InputField(
default=None,
description=FieldDescriptions.transformer,
input=Input.Connection,
title="Transformer",
)
qwen3_encoder: Qwen3EncoderField | None = InputField(
default=None,
title="Qwen3 Encoder",
description=FieldDescriptions.qwen3_encoder,
input=Input.Connection,
)

def invoke(self, context: InvocationContext) -> AnimaLoRALoaderOutput:
output = AnimaLoRALoaderOutput()
loras = self.loras if isinstance(self.loras, list) else [self.loras]
added_loras: list[str] = []

if self.transformer is not None:
output.transformer = self.transformer.model_copy(deep=True)

if self.qwen3_encoder is not None:
output.qwen3_encoder = self.qwen3_encoder.model_copy(deep=True)

for lora in loras:
if lora is None:
continue
if lora.lora.key in added_loras:
continue

if not context.models.exists(lora.lora.key):
raise ValueError(f"Unknown lora: {lora.lora.key}!")

if lora.lora.base is not BaseModelType.Anima:
raise ValueError(
f"LoRA '{lora.lora.key}' is for {lora.lora.base.value if lora.lora.base else 'unknown'} models, "
"not Anima models. Ensure you are using an Anima compatible LoRA."
)

added_loras.append(lora.lora.key)

if self.transformer is not None and output.transformer is not None:
output.transformer.loras.append(lora)

if self.qwen3_encoder is not None and output.qwen3_encoder is not None:
output.qwen3_encoder.loras.append(lora)

return output
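The collection loader's loop dedupes by model key, skips empty slots, and rejects non-Anima LoRAs before appending. A toy sketch of that control flow; ToyLoRA and collect are hypothetical stand-ins, not InvokeAI types:

from dataclasses import dataclass

@dataclass
class ToyLoRA:  # hypothetical stand-in for LoRAField and its model identifier
    key: str
    base: str
    weight: float = 0.75

def collect(loras: list[ToyLoRA | None]) -> list[ToyLoRA]:
    added: list[str] = []
    out: list[ToyLoRA] = []
    for lora in loras:
        if lora is None or lora.key in added:
            continue  # skip empty slots and duplicates, as the loader does
        if lora.base != "anima":
            raise ValueError(f"LoRA '{lora.key}' is for {lora.base} models, not Anima models.")
        added.append(lora.key)
        out.append(lora)
    return out

result = collect([ToyLoRA("a", "anima"), ToyLoRA("a", "anima"), None, ToyLoRA("b", "anima")])
assert [lora.key for lora in result] == ["a", "b"]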