Merged
100 changes: 37 additions & 63 deletions README.md
Original file line number Diff line number Diff line change
@@ -31,10 +31,6 @@ cmake --build . --config Release -j$(nproc)

Builds two binaries: `ace-qwen3` (LLM) and `dit-vae` (DiT + VAE).

**CI (GitHub Actions)**
- **Build**: on every push/PR, builds on Ubuntu (BLAS) and macOS (Metal); a smoke test runs each binary with `--help`.
- **Test generation**: on release or manual trigger only; runs the same checks as the local `tests/run-generation-tests.sh`. Validate locally first (build + `./models.sh`, then `tests/run-generation-tests.sh`), then use CI to confirm. See `.github/workflows/`.

## Models

Pre-quantized GGUFs on [Hugging Face](https://huggingface.co/Serveurperso/ACE-Step-1.5-GGUF).
@@ -143,16 +139,10 @@ cd examples
./partial.sh # caption + lyrics + duration
./full.sh # all metadata provided
./dit-only.sh # skip LLM, DiT from noise
./cover.sh # cover mode: decode precomputed audio_codes (no LLM)
./cover-reference.sh # cover + reference_audio for timbre (WAV/MP3; needs reference.wav or .mp3)
./test-reference.sh # reference_audio (WAV or MP3) + audio_cover_strength
./lora.sh # DiT + LoRA adapter
```

Each example has a `-sft` variant (SFT model, 50 steps, CFG 7.0)
alongside the turbo default (8 steps, no CFG). For **reference timbre**, set `reference_audio` to a **WAV or MP3** path; dit-vae loads it (MP3 is decoded in memory via the header-only minimp3, no temp files) and encodes it with the VAE encoder (this requires a full VAE GGUF that includes the encoder weights).

**LoRA adapters**: use `--lora <path>` and optional `--lora-scale <float>` with dit-vae to run the DiT with PEFT-style Ace-Step LoRAs.
alongside the turbo default (8 steps, no CFG).

## Generation modes

@@ -180,11 +170,10 @@ Run `dit-vae` to decode existing codes. See `examples/dit-only.json`.

## Request JSON reference

All fields with defaults. Only `caption` is required. Built-in modes (text2music, cover, repaint) and audio inputs follow the [ACE-Step 1.5 Tutorial](https://github.com/ace-step/ACE-Step-1.5/blob/main/docs/en/Tutorial.md); see [docs/MODES.md](docs/MODES.md) for what is implemented.
All fields with defaults. Only `caption` is required.

```json
{
"task_type": "text2music",
"caption": "",
"lyrics": "",
"instrumental": false,
@@ -199,12 +188,7 @@
"lm_top_p": 0.9,
"lm_top_k": 0,
"lm_negative_prompt": "",
"reference_audio": "",
"src_audio": "",
"audio_codes": "",
"audio_cover_strength": 1.0,
"repainting_start": 0.0,
"repainting_end": 0.0,
"inference_steps": 8,
"guidance_scale": 7.0,
"shift": 3.0
@@ -214,12 +198,7 @@
Key fields: `seed` -1 means random (resolved once, then +1 per batch
element). `audio_codes` is generated by ace-qwen3 and consumed by
dit-vae (comma-separated FSQ token IDs). When present, the LLM is
skipped entirely (cover-style generation). `reference_audio`: path to a **WAV or MP3** file for global timbre/style (MP3 decoded in memory; encoded via the built-in VAE encoder; requires a VAE GGUF with encoder weights). `src_audio`: path to a **WAV or MP3** cover source; dit-vae encodes it to codes internally (VAE + FSQ nearest-codeword), no Python required (see docs/MODES.md).

**Reference and cover strength (not the same as guidance_scale):**
- **`audio_cover_strength`** (0.0–1.0): Controls how strongly the **cover/source** (from `audio_codes` or `src_audio`) influences the DiT context. The context is blended with silence: `(1 - audio_cover_strength)*silence + audio_cover_strength*decoded`. Use 1.0 for full cover influence, lower values to soften it. Only applies when cover context is present.
- **`reference_audio`**: Timbre from the reference file is applied at full strength; there is no separate strength parameter for reference timbre.
- **`guidance_scale`**: This is **DiT classifier-free guidance** (conditioned vs unconditioned prediction), not reference or cover strength. Turbo models ignore it (forced to 1.0).
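The silence/cover blend above is a plain linear interpolation. A minimal sketch in Python (hypothetical helper names; the project applies this inside DiT context preparation, not with this code):

```python
def blend_cover_context(silence, decoded, strength):
    # (1 - audio_cover_strength) * silence + audio_cover_strength * decoded, per element
    return [(1.0 - strength) * s + strength * d for s, d in zip(silence, decoded)]

silence = [0.0, 0.0, 0.0]
decoded = [0.2, -0.4, 0.8]
print(blend_cover_context(silence, decoded, 0.5))  # half-strength cover influence
```

At `strength=1.0` the context is exactly the decoded cover; at `0.0` it degenerates to silence, matching "only applies when cover context is present".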
skipped entirely.

Turbo preset: `inference_steps=8, shift=3.0` (no guidance_scale, turbo models don't use CFG).
SFT preset: `inference_steps=50, guidance_scale=4.0, shift=6.0`.
@@ -241,6 +220,7 @@ Output naming: input.json -> input0.json, input1.json, ... (last digit = batch index)
Debug:
--max-seq <N> KV cache size (default: 8192)
--no-fsm Disable FSM constrained decoding
--no-fa Disable flash attention
--dump-logits <path> Dump prefill logits (binary f32)
--dump-tokens <path> Dump prompt token IDs (CSV)
```
@@ -262,10 +242,6 @@ Required:
--dit <gguf> DiT GGUF file
--vae <gguf> VAE GGUF file

LoRA:
--lora <path> LoRA adapter (adapter_model.safetensors)
--lora-scale <float> LoRA scale, e.g. alpha/rank (default: 1.0)

Batch:
--batch <N> DiT variations per request (default: 1, max 9)

@@ -276,6 +252,7 @@ VAE tiling (memory control):
--vae-overlap <N> Overlap frames per side (default: 64)

Debug:
--no-fa Disable flash attention
--dump <dir> Dump intermediate tensors
```

@@ -320,10 +297,7 @@ conditional and N unconditional sequences are packed into a single forward pass
`logits = uncond + scale * (cond - uncond)`. The KV cache is a single 4D tensor
`[D, max_seq, Nkv, n_sets]` shared across all batch elements and CFG paths. Shared
prompts are prefilled once and cloned to other KV sets via copy, avoiding redundant
prefills. Embedding lookup bypasses ggml_get_rows entirely: rows are read directly
from the mmap'd GGUF file on CPU, dequantized, and uploaded as F32 input tensors.
Decode uses a dedicated single-backend graph allocator (gallocr) with no scheduler
dispatch overhead, while prefill uses the multi-backend scheduler for flexibility.
prefills.
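A minimal sketch of the CFG combination step quoted above, on plain Python lists (the real implementation operates on packed GPU logit tensors, not lists):

```python
def cfg_combine(cond: list[float], uncond: list[float], scale: float) -> list[float]:
    # logits = uncond + scale * (cond - uncond); scale == 1.0 reduces to cond
    return [u + scale * (c - u) for c, u in zip(cond, uncond)]

cond   = [2.0, 0.5, -1.0]   # logits from the conditioned sequence
uncond = [1.0, 1.0,  1.0]   # logits from the unconditioned (negative-prompt) sequence
print(cfg_combine(cond, uncond, 2.0))  # [3.0, 0.0, -3.0]
```

With `scale > 1` the combined logits are pushed away from the unconditioned prediction, which is the whole point of classifier-free guidance.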

## Accuracy

@@ -343,42 +317,42 @@ python3 debug-dit-cossim.py # DiT: per-layer cossim GGML vs Python (turbo/

## Patched GGML fork

Uses a patched GGML fork (submodule) with ops added for the Oobleck VAE decoder.
Uses a patched GGML fork (submodule) with two new ops and a CUDA bugfix for the Oobleck
VAE decoder. All backends: CPU, CUDA, Metal, Vulkan. F32/F16/BF16 data types.
The DiT uses only standard GGML ops and needs no patches.

The VAE reconstructs audio from latent space through 5 upsampling blocks (total 1920x),
each running a transposed convolution followed by 3 WaveNet-style residual units with
dilated convolutions and Snake activations. A single tile builds a graph of 36 snake
activations, 5 transposed convolutions, and 32 regular convolutions. At the final blocks,
sequence lengths reach 491520 timesteps, which stresses GGML ops designed for short NLP sequences.
The DiT (flow matching diffusion transformer) uses only standard GGML ops and needs no patches.
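For intuition, the 1920x factor composes multiplicatively across the 5 blocks. The per-block strides below are illustrative assumptions; only the 1920x total and the 491520-timestep figure come from this document:

```python
# Hypothetical per-block upsampling strides; only their product (1920x) is documented.
strides = [4, 4, 4, 5, 6]

def output_len(latent_frames: int) -> int:
    t = latent_frames
    for s in strides:
        t *= s  # each transposed convolution multiplies the sequence length by its stride
    return t

assert 4 * 4 * 4 * 5 * 6 == 1920
print(output_len(256))  # 491520 timesteps for a 256-frame latent tile
```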

Patches on top of upstream GGML, oldest first:

| Commit | Scope | Description |
|--------|-------|-------------|
| `8c70db84` | CUDA | `conv_transpose_1d`: replace O(T_in) brute-force loop with bounded range |
| `b65bf458` | CUDA | `im2col`: grid-stride loop on OW to fix gridDim.y overflow when T > 65535 |
| `e0e36f3c` | Metal | `conv_transpose_1d`: same bounded loop fix as CUDA |
| `2b9080bd` | CPU, CUDA, Metal | New `GGML_OP_COL2IM_1D`: scatter-add for GEMM-based conv_transpose_1d decomposition |
| `02c8041f` | CPU, CUDA, Metal | New `GGML_OP_SNAKE`: fused activation y = x + sin^2(a*x) / b (replaces 5 element-wise ops) |
| `3f60b19c` | Metal | Fix snake kernel to use current C wrapper API |
| `cb5d7067` | Vulkan | Guard `VK_EXT_layer_settings` for legacy Vulkan SDK (fixes MI50/gfx906) |
| `1f0f4214` | Vulkan | `col2im_1d`: add Vulkan backend |
| `efbf3df6` | Vulkan | `snake`: add Vulkan backend |
| `6608cd11` | Vulkan | Fix rvalue ref for `col2im_1d` and `snake` push constants |
| `06101d38` | Vulkan | Fix double-division dispatch for `col2im_1d` and `snake` |
| `91416cee` | CPU, CUDA, Metal, Vulkan | `col2im_1d`: fuse padding crop via p0 parameter (saves 5 allocs + 5 memcpy per VAE tile) |
| `20675b09` | Vulkan | `col2im_1d`, `snake`: 2D dispatch (fixes workgroup overflow on MI50) |

**Why col2im_1d**: upstream `ggml_conv_transpose_1d` uses a naive CUDA kernel (one scalar
FMA loop per output element, no shared memory, no tensor cores). The VAE spends 40% of its
FLOP budget on transposed convolutions. We decompose it as `mul_mat + col2im_1d`, routing
the heavy GEMM through cuBLAS/BLAS/MPS tensor cores. The col2im_1d gather has a 2-iteration
inner loop and is pure bandwidth.

**Why snake**: the Oobleck VAE uses Snake1d activation (x + sin^2(a*x) / b) 36 times per
tile. Without a fused op, each activation requires 5 separate GGML kernels (mul, sin, sqr,
mul, add), causing 5x the memory traffic. The fused kernel reads x once, writes y once.
sequence lengths reach 491520 timesteps, which stresses GGML ops designed for short NLP
sequences.

### `GGML_OP_SNAKE` (fused Snake activation)

Computes y = x + sin^2(a * x) * inv_b in a single kernel.
The Oobleck VAE calls this 36 times per tile. Without a fused op, each activation
requires 5 separate GGML kernels (mul, sin, sqr, mul, add), causing 5x the memory
traffic. The fused kernel reads x once and writes y once. BF16 cast nodes before/after
each snake call halve memory bandwidth at the cost of negligible precision loss
(cossim > 0.999 vs F32 baseline).
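A scalar reference of the fused op versus the 5-kernel chain it replaces (pure Python sketch; the alpha/inv_beta values are arbitrary):

```python
import math

def snake_fused(x: float, alpha: float, inv_beta: float) -> float:
    # y = x + sin^2(alpha * x) * inv_beta, computed in one pass over x
    s = math.sin(alpha * x)
    return x + s * s * inv_beta

def snake_composed(x: float, alpha: float, inv_beta: float) -> float:
    # the same value via the 5-op chain a fused kernel replaces: mul, sin, sqr, mul, add
    t = alpha * x      # mul
    t = math.sin(t)    # sin
    t = t * t          # sqr
    t = t * inv_beta   # mul
    return x + t       # add

for x in (-2.0, 0.0, 1.5):
    assert math.isclose(snake_fused(x, 0.7, 2.0), snake_composed(x, 0.7, 2.0))
```

The fused form touches `x` once and writes `y` once; the composed form materializes an intermediate tensor per kernel, which is where the 5x memory traffic comes from.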

### `GGML_OP_COL2IM_1D` (scatter-add for GEMM-based conv_transpose_1d)

Gather-based reconstruction of a 1D signal from GEMM columns [K*OC, T_in] to
[T_out, OC], with fused padding crop via the p0 parameter.
Upstream `ggml_conv_transpose_1d` uses a naive kernel (one scalar FMA loop per output
element, no shared memory, no tensor cores). The VAE spends 40% of its FLOP budget on
transposed convolutions. We decompose each as `mul_mat + col2im_1d`, routing the heavy
GEMM through cuBLAS/BLAS/MPS tensor cores. The col2im_1d gather has a 2-iteration inner
loop and is pure bandwidth. BF16 cast nodes around col2im_1d halve the scatter bandwidth.
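A toy single-channel sketch of the decomposition, checking the GEMM-columns + scatter-add path against a naive transposed convolution (names and shapes are illustrative; the real op is multi-channel and the GEMM goes through cuBLAS/BLAS/MPS):

```python
def conv_transpose_1d_naive(x, w, stride):
    # naive form: each input sample scatters a scaled copy of the kernel
    K = len(w)
    out = [0.0] * ((len(x) - 1) * stride + K)
    for i, xi in enumerate(x):
        for k, wk in enumerate(w):
            out[i * stride + k] += xi * wk
    return out

def conv_transpose_1d_gemm(x, w, stride):
    K = len(w)
    # "GEMM" stage: outer product producing columns [K, T_in]
    # (with channels this is a real matmul, which is the heavy part)
    cols = [[xi * wk for wk in w] for xi in x]
    # col2im_1d stage: pure-bandwidth scatter-add of columns onto the time axis
    out = [0.0] * ((len(x) - 1) * stride + K)
    for i, col in enumerate(cols):
        for k, v in enumerate(col):
            out[i * stride + k] += v
    return out

x, w = [1.0, 2.0, 3.0], [0.5, -1.0, 0.25]
assert conv_transpose_1d_naive(x, w, 2) == conv_transpose_1d_gemm(x, w, 2)
```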

### Bugfix: `im2col` gridDim.y overflow (CUDA)

Upstream `im2col_kernel` uses OW directly as grid dimension Y, which exceeds the CUDA
65535 gridDim limit on long sequences. The VAE calls `ggml_conv_1d` (im2col path) 32
times per tile at output widths up to 491520. Fixed with a grid-stride loop on OW and
`MIN(OW, MAX_GRIDDIM_Z)` clamping.
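The fix can be illustrated by simulating which output indices a clamped grid still covers when each block runs a grid-stride loop (a Python model of the launch, not the CUDA kernel itself):

```python
MAX_GRID_DIM = 65535  # CUDA limit on gridDim.y

def grid_stride_coverage(ow: int) -> bool:
    grid = min(ow, MAX_GRID_DIM)       # the clamped launch dimension
    covered = [False] * ow
    for block_y in range(grid):        # one iteration per launched block
        for i in range(block_y, ow, grid):  # grid-stride loop inside the kernel
            covered[i] = True
    return all(covered)

assert grid_stride_coverage(491520)    # the VAE's worst-case output width
assert grid_stride_coverage(100)       # small widths still covered, no gaps
```

Each block handles indices `block_y, block_y + grid, block_y + 2*grid, ...`, so every output index is visited exactly once regardless of how far OW exceeds the hardware limit.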

## Acknowledgements

1 change: 1 addition & 0 deletions _codeql_detected_source_root
8 changes: 8 additions & 0 deletions buildcuda.sh
@@ -0,0 +1,8 @@
#!/bin/bash

rm -rf build
mkdir build
cd build

cmake .. -DGGML_CUDA=ON -DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc
cmake --build . --config Release -j "$(nproc)"
8 changes: 6 additions & 2 deletions src/cond-enc.h
@@ -69,6 +69,7 @@ struct CondGGML {
ggml_backend_t backend;
ggml_backend_t cpu_backend;
ggml_backend_sched_t sched;
bool use_flash_attn;
WeightCtx wctx;
};

@@ -78,6 +79,7 @@ static void cond_ggml_init_backend(CondGGML * m) {
m->backend = bp.backend;
m->cpu_backend = bp.cpu_backend;
m->sched = backend_sched_new(bp, 8192);
m->use_flash_attn = true;
}

// Load from ACEStep DiT GGUF
@@ -191,7 +193,8 @@ static void cond_ggml_forward(CondGGML * m,
for (int i = 0; i < m->lyric_cfg.n_layers; i++) {
struct ggml_tensor * layer_mask = (i % 2 == 0) ? lyric_slide_mask : NULL;
lyric_h = qwen3_build_layer(ctx, m->lyric_cfg, &m->lyric_layers[i],
lyric_h, lyric_pos, layer_mask, S_lyric);
lyric_h, lyric_pos, layer_mask, S_lyric,
m->use_flash_attn);
}
lyric_h = qwen3_rms_norm(ctx, lyric_h, m->lyric_norm, m->lyric_cfg.rms_norm_eps);

@@ -236,7 +239,8 @@ static void cond_ggml_forward(CondGGML * m,
for (int i = 0; i < m->timbre_cfg.n_layers; i++) {
struct ggml_tensor * layer_mask = (i % 2 == 0) ? timbre_slide_mask : NULL;
timbre_h = qwen3_build_layer(ctx, m->timbre_cfg, &m->timbre_layers[i],
timbre_h, timbre_pos, layer_mask, S_ref);
timbre_h, timbre_pos, layer_mask, S_ref,
m->use_flash_attn);
}
timbre_h = qwen3_rms_norm(ctx, timbre_h, m->timbre_norm, m->timbre_cfg.rms_norm_eps);

5 changes: 4 additions & 1 deletion src/fsq-detok.h
@@ -64,6 +64,7 @@ struct DetokGGML {
ggml_backend_t backend;
ggml_backend_t cpu_backend;
ggml_backend_sched_t sched;
bool use_flash_attn;
WeightCtx wctx;
};

@@ -73,6 +74,7 @@ static bool detok_ggml_load(DetokGGML * m, const char * gguf_path,
m->cfg = detok_config();
m->backend = backend;
m->cpu_backend = cpu_backend;
m->use_flash_attn = true;

GGUFModel gf;
if (!gf_load(&gf, gguf_path)) {
@@ -169,7 +171,8 @@ static int detok_ggml_decode(DetokGGML * m, const int * codes, int T_5Hz,

// 2L encoder + norm (non-causal, no mask needed at S=5)
hidden = qwen3_build_layers(ctx, m->cfg, m->layers, m->norm,
hidden, positions, NULL, P);
hidden, positions, NULL, P,
m->use_flash_attn);

// proj_out: [2048, 5] -> [64, 5]
struct ggml_tensor * output = ggml_mul_mat(ctx, m->proj_out_w, hidden);