Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ add_library(acestep-core STATIC
)
link_ggml_backends(acestep-core)

# dit-vae: full pipeline (text-enc + cond + dit + vae + wav)
add_executable(dit-vae tools/dit-vae.cpp)
# dit-vae: full pipeline (text-enc + cond + dit + vae + wav) + LoRA support
add_executable(dit-vae tools/dit-vae.cpp src/dit-lora.cpp)
target_link_libraries(dit-vae PRIVATE acestep-core)
link_ggml_backends(dit-vae)

Expand Down
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -142,11 +142,40 @@ cd examples
./partial.sh # caption + lyrics + duration
./full.sh # all metadata provided
./dit-only.sh # skip LLM, DiT from noise
./lora.sh # LoRA adapter (PEFT safetensors)
```

Each example has a `-sft` variant (SFT model, 50 steps, CFG 7.0)
alongside the turbo default (8 steps, no CFG).

## LoRA

`dit-vae` supports PEFT LoRA adapters in `adapter_model.safetensors` format.
Pass `--lora <path>` and optionally `--lora-scale <float>` (default 1.0; a typical value is alpha/rank):

```bash
./build/dit-vae \
--request request0.json \
--text-encoder models/Qwen3-Embedding-0.6B-Q8_0.gguf \
--dit models/acestep-v15-turbo-Q8_0.gguf \
--vae models/vae-BF16.gguf \
--lora lora/adapter_model.safetensors \
--lora-scale 1.0
```

Use `custom_tag` in the request JSON to append a trigger word to the caption:

```json
{
"caption": "Nu-disco track with funky bassline",
"custom_tag": "crydamoure",
"inference_steps": 8,
"shift": 3
}
```

See `examples/lora.sh` and `examples/lora.json` for a complete example.

## Generation modes

The LLM fills what's missing in the JSON and generates audio codes.
Expand Down
17 changes: 17 additions & 0 deletions examples/lora.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
"task_type": "text2music",
"caption": "An energetic nu-disco track built on a foundation of a tight, funky slap bassline and a crisp, four-on-the-floor drum machine beat. The song opens with a distinctive, filtered wah-wah guitar riff that serves as a recurring motif. The arrangement is layered with shimmering synth pads, punchy synth stabs, and subtle arpeggiated synth textures that add movement. The track progresses through dynamic sections, including a brief atmospheric breakdown before rebuilding the main groove.",
"genre": "Nu-disco",
"lyrics": "[Instrumental]",
"bpm": 115,
"keyscale": "C# major",
"timesignature": "4",
"duration": 256,
"language": "unknown",
"instrumental": true,
"custom_tag": "crydamoure",
"inference_steps": 8,
"guidance_scale": 1,
"shift": 3,
"seed": -1
}
27 changes: 27 additions & 0 deletions examples/lora.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/bin/bash
# LoRA example: generate with a PEFT LoRA adapter (e.g. duckdbot/acestep-lora-cryda).
# Requires adapter_model.safetensors in lora/ — download it once from the
# adapter's Hugging Face repository (see the README "LoRA" section).
# -e: abort on any failing command; -u: treat unset vars as errors;
# -o pipefail: a failure anywhere in a pipeline fails the whole pipeline.
set -euo pipefail
cd "$(dirname "$0")"

ADAPTER="lora/adapter_model.safetensors"
if [ ! -f "$ADAPTER" ]; then
    # Diagnostics belong on stderr so stdout stays clean for tooling.
    echo "LoRA adapter not found at $ADAPTER" >&2
    exit 1
fi

# Step 1 — LLM: fill lyrics + codes (writes lora0.json next to lora.json).
../build/ace-qwen3 \
    --request lora.json \
    --model ../models/acestep-5Hz-lm-4B-Q8_0.gguf

# Step 2 — DiT+VAE with LoRA (scale = alpha/rank; 1.0 is typical).
../build/dit-vae \
    --request lora0.json \
    --text-encoder ../models/Qwen3-Embedding-0.6B-Q8_0.gguf \
    --dit ../models/acestep-v15-turbo-Q8_0.gguf \
    --vae ../models/vae-BF16.gguf \
    --lora "$ADAPTER" \
    --lora-scale 1.0

echo "Done. Check lora00.wav"
93 changes: 78 additions & 15 deletions src/dit-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ static struct ggml_tensor * dit_ggml_linear(
return ggml_mul_mat(ctx, weight, input);
}

// Linear with optional LoRA adapter: out = W@x + scale * (B@(A@x)).
// lora_a / lora_b may be NULL (no adapter loaded for this projection),
// in which case this degenerates to a plain ggml_mul_mat.
static struct ggml_tensor * dit_ggml_linear_lora(
        struct ggml_context * ctx,
        struct ggml_tensor * weight,
        struct ggml_tensor * lora_a,   // [in, r]
        struct ggml_tensor * lora_b,   // [r, out]
        float lora_scale,
        struct ggml_tensor * input) {
    struct ggml_tensor * base = ggml_mul_mat(ctx, weight, input);
    // Skip the low-rank path entirely when no adapter is present or the
    // scale is exactly zero — avoids adding dead nodes to the graph.
    if (!lora_a || !lora_b || lora_scale == 0.0f) {
        return base;
    }
    struct ggml_tensor * down = ggml_mul_mat(ctx, lora_a, input); // A @ x
    struct ggml_tensor * up   = ggml_mul_mat(ctx, lora_b, down);  // B @ (A @ x)
    return ggml_add(ctx, base, ggml_scale(ctx, up, lora_scale));
}

// Helper: Linear layer with bias
static struct ggml_tensor * dit_ggml_linear_bias(
struct ggml_context * ctx,
Expand Down Expand Up @@ -161,20 +178,36 @@ static struct ggml_tensor * dit_ggml_build_self_attn(
struct ggml_tensor * q, * k, * v;
int q_dim = Nh * D;
int kv_dim = Nkv * D;
float lora_scale = m->lora_scale;
if (ly->sa_qkv) {
struct ggml_tensor * qkv = dit_ggml_linear(ctx, ly->sa_qkv, norm_sa);
q = ggml_cont(ctx, ggml_view_3d(ctx, qkv, q_dim, S, N, qkv->nb[1], qkv->nb[2], 0));
k = ggml_cont(ctx, ggml_view_3d(ctx, qkv, kv_dim, S, N, qkv->nb[1], qkv->nb[2], (size_t)q_dim * qkv->nb[0]));
v = ggml_cont(ctx, ggml_view_3d(ctx, qkv, kv_dim, S, N, qkv->nb[1], qkv->nb[2], (size_t)(q_dim + kv_dim) * qkv->nb[0]));
// LoRA on fused path: add scale * (B @ (A @ x)) per projection when adapters are loaded
if (lora_scale != 0.0f) {
if (ly->lora_sa_q_a && ly->lora_sa_q_b)
q = ggml_add(ctx, q, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_sa_q_b, ggml_mul_mat(ctx, ly->lora_sa_q_a, norm_sa)), lora_scale));
if (ly->lora_sa_k_a && ly->lora_sa_k_b)
k = ggml_add(ctx, k, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_sa_k_b, ggml_mul_mat(ctx, ly->lora_sa_k_a, norm_sa)), lora_scale));
if (ly->lora_sa_v_a && ly->lora_sa_v_b)
v = ggml_add(ctx, v, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_sa_v_b, ggml_mul_mat(ctx, ly->lora_sa_v_a, norm_sa)), lora_scale));
}
} else if (ly->sa_qk) {
struct ggml_tensor * qk = dit_ggml_linear(ctx, ly->sa_qk, norm_sa);
q = ggml_cont(ctx, ggml_view_3d(ctx, qk, q_dim, S, N, qk->nb[1], qk->nb[2], 0));
k = ggml_cont(ctx, ggml_view_3d(ctx, qk, kv_dim, S, N, qk->nb[1], qk->nb[2], (size_t)q_dim * qk->nb[0]));
v = dit_ggml_linear(ctx, ly->sa_v_proj, norm_sa);
if (lora_scale != 0.0f) {
if (ly->lora_sa_q_a && ly->lora_sa_q_b)
q = ggml_add(ctx, q, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_sa_q_b, ggml_mul_mat(ctx, ly->lora_sa_q_a, norm_sa)), lora_scale));
if (ly->lora_sa_k_a && ly->lora_sa_k_b)
k = ggml_add(ctx, k, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_sa_k_b, ggml_mul_mat(ctx, ly->lora_sa_k_a, norm_sa)), lora_scale));
}
v = dit_ggml_linear_lora(ctx, ly->sa_v_proj, ly->lora_sa_v_a, ly->lora_sa_v_b, lora_scale, norm_sa);
} else {
q = dit_ggml_linear(ctx, ly->sa_q_proj, norm_sa);
k = dit_ggml_linear(ctx, ly->sa_k_proj, norm_sa);
v = dit_ggml_linear(ctx, ly->sa_v_proj, norm_sa);
q = dit_ggml_linear_lora(ctx, ly->sa_q_proj, ly->lora_sa_q_a, ly->lora_sa_q_b, lora_scale, norm_sa);
k = dit_ggml_linear_lora(ctx, ly->sa_k_proj, ly->lora_sa_k_a, ly->lora_sa_k_b, lora_scale, norm_sa);
v = dit_ggml_linear_lora(ctx, ly->sa_v_proj, ly->lora_sa_v_a, ly->lora_sa_v_b, lora_scale, norm_sa);
}

// 2) Reshape to heads: [Nh*D, S, N] -> [D, Nh, S, N]
Expand Down Expand Up @@ -236,7 +269,7 @@ static struct ggml_tensor * dit_ggml_build_self_attn(
}

// 8) O projection: [Nh*D, S, N] -> [H, S, N]
struct ggml_tensor * out = dit_ggml_linear(ctx, ly->sa_o_proj, attn);
struct ggml_tensor * out = dit_ggml_linear_lora(ctx, ly->sa_o_proj, ly->lora_sa_o_a, ly->lora_sa_o_b, m->lora_scale, attn);
return out;
}

Expand All @@ -250,20 +283,34 @@ static struct ggml_tensor * dit_ggml_build_mlp(
struct ggml_tensor * norm_ffn,
int S) {

DiTGGMLConfig & c = m->cfg;
int I = c.intermediate_size;
int N = (int)norm_ffn->ne[2];
float lora_scale = m->lora_scale;
struct ggml_tensor * ff;
if (ly->gate_up) {
// Fused: single matmul [H, 2*I] x [H, S, N] -> [2*I, S, N], then swiglu splits ne[0]
struct ggml_tensor * gu = dit_ggml_linear(ctx, ly->gate_up, norm_ffn);
ff = ggml_swiglu(ctx, gu);
if (lora_scale != 0.0f && ((ly->lora_gate_a && ly->lora_gate_b) || (ly->lora_up_a && ly->lora_up_b))) {
struct ggml_tensor * gate = ggml_cont(ctx, ggml_view_3d(ctx, gu, I, S, N, gu->nb[1], gu->nb[2], 0));
struct ggml_tensor * up = ggml_cont(ctx, ggml_view_3d(ctx, gu, I, S, N, gu->nb[1], gu->nb[2], (size_t)I * gu->nb[0]));
if (ly->lora_gate_a && ly->lora_gate_b)
gate = ggml_add(ctx, gate, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_gate_b, ggml_mul_mat(ctx, ly->lora_gate_a, norm_ffn)), lora_scale));
if (ly->lora_up_a && ly->lora_up_b)
up = ggml_add(ctx, up, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_up_b, ggml_mul_mat(ctx, ly->lora_up_a, norm_ffn)), lora_scale));
ff = ggml_swiglu_split(ctx, gate, up);
} else {
ff = ggml_swiglu(ctx, gu);
}
} else {
// Separate: two matmuls + split swiglu
struct ggml_tensor * gate = dit_ggml_linear(ctx, ly->gate_proj, norm_ffn);
struct ggml_tensor * up = dit_ggml_linear(ctx, ly->up_proj, norm_ffn);
// Separate: two matmuls + split swiglu (with optional LoRA)
struct ggml_tensor * gate = dit_ggml_linear_lora(ctx, ly->gate_proj, ly->lora_gate_a, ly->lora_gate_b, lora_scale, norm_ffn);
struct ggml_tensor * up = dit_ggml_linear_lora(ctx, ly->up_proj, ly->lora_up_a, ly->lora_up_b, lora_scale, norm_ffn);
ff = ggml_swiglu_split(ctx, gate, up);
}

// Down projection: [I, S] -> [H, S]
return dit_ggml_linear(ctx, ly->down_proj, ff);
return dit_ggml_linear_lora(ctx, ly->down_proj, ly->lora_down_a, ly->lora_down_b, lora_scale, ff);
}

// Build cross-attention sub-graph for a single layer.
Expand All @@ -289,6 +336,7 @@ static struct ggml_tensor * dit_ggml_build_cross_attn(
// Q from hidden, KV from encoder (full fused, Q+KV partial, separate)
int q_dim = Nh * D;
int kv_dim = Nkv * D;
float lora_scale = m->lora_scale;
struct ggml_tensor * q, * k, * v;
if (ly->ca_qkv) {
// Full QKV fused: split Q from hidden, KV from enc via weight views
Expand All @@ -300,16 +348,31 @@ static struct ggml_tensor * dit_ggml_build_cross_attn(
struct ggml_tensor * kv = ggml_mul_mat(ctx, w_kv, enc);
k = ggml_cont(ctx, ggml_view_3d(ctx, kv, kv_dim, enc_S, N, kv->nb[1], kv->nb[2], 0));
v = ggml_cont(ctx, ggml_view_3d(ctx, kv, kv_dim, enc_S, N, kv->nb[1], kv->nb[2], (size_t)kv_dim * kv->nb[0]));
// LoRA on fused path: add scale * (B @ (A @ x)) for Q (from norm_ca), K/V (from enc)
if (lora_scale != 0.0f) {
if (ly->lora_ca_q_a && ly->lora_ca_q_b)
q = ggml_add(ctx, q, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_ca_q_b, ggml_mul_mat(ctx, ly->lora_ca_q_a, norm_ca)), lora_scale));
if (ly->lora_ca_k_a && ly->lora_ca_k_b)
k = ggml_add(ctx, k, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_ca_k_b, ggml_mul_mat(ctx, ly->lora_ca_k_a, enc)), lora_scale));
if (ly->lora_ca_v_a && ly->lora_ca_v_b)
v = ggml_add(ctx, v, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_ca_v_b, ggml_mul_mat(ctx, ly->lora_ca_v_a, enc)), lora_scale));
}
} else if (ly->ca_kv) {
// Q separate, K+V fused
q = dit_ggml_linear(ctx, ly->ca_q_proj, norm_ca);
q = dit_ggml_linear_lora(ctx, ly->ca_q_proj, ly->lora_ca_q_a, ly->lora_ca_q_b, lora_scale, norm_ca);
struct ggml_tensor * kv = ggml_mul_mat(ctx, ly->ca_kv, enc);
k = ggml_cont(ctx, ggml_view_3d(ctx, kv, kv_dim, enc_S, N, kv->nb[1], kv->nb[2], 0));
v = ggml_cont(ctx, ggml_view_3d(ctx, kv, kv_dim, enc_S, N, kv->nb[1], kv->nb[2], (size_t)kv_dim * kv->nb[0]));
if (lora_scale != 0.0f) {
if (ly->lora_ca_k_a && ly->lora_ca_k_b)
k = ggml_add(ctx, k, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_ca_k_b, ggml_mul_mat(ctx, ly->lora_ca_k_a, enc)), lora_scale));
if (ly->lora_ca_v_a && ly->lora_ca_v_b)
v = ggml_add(ctx, v, ggml_scale(ctx, ggml_mul_mat(ctx, ly->lora_ca_v_b, ggml_mul_mat(ctx, ly->lora_ca_v_a, enc)), lora_scale));
}
} else {
q = dit_ggml_linear(ctx, ly->ca_q_proj, norm_ca);
k = dit_ggml_linear(ctx, ly->ca_k_proj, enc);
v = dit_ggml_linear(ctx, ly->ca_v_proj, enc);
q = dit_ggml_linear_lora(ctx, ly->ca_q_proj, ly->lora_ca_q_a, ly->lora_ca_q_b, m->lora_scale, norm_ca);
k = dit_ggml_linear_lora(ctx, ly->ca_k_proj, ly->lora_ca_k_a, ly->lora_ca_k_b, m->lora_scale, enc);
v = dit_ggml_linear_lora(ctx, ly->ca_v_proj, ly->lora_ca_v_a, ly->lora_ca_v_b, m->lora_scale, enc);
}

// reshape to [D, heads, seq, N] then permute to [D, seq, heads, N]
Expand Down Expand Up @@ -339,7 +402,7 @@ static struct ggml_tensor * dit_ggml_build_cross_attn(
attn = ggml_reshape_3d(ctx, attn, Nh * D, S, N);

// O projection
return dit_ggml_linear(ctx, ly->ca_o_proj, attn);
return dit_ggml_linear_lora(ctx, ly->ca_o_proj, ly->lora_ca_o_a, ly->lora_ca_o_b, m->lora_scale, attn);
}

// Build one full DiT layer (AdaLN + self-attn + cross-attn + FFN + gated residuals)
Expand Down
Loading