Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ macro(link_ggml_backends target)
target_link_libraries(${target} PRIVATE ggml-${backend})
string(TOUPPER ${backend} BACKEND_UPPER)
target_compile_definitions(${target} PRIVATE ACESTEP_HAVE_${BACKEND_UPPER})
if(backend STREQUAL "cuda")
find_package(CUDAToolkit QUIET)
if(CUDAToolkit_FOUND)
target_link_libraries(${target} PRIVATE CUDA::cudart)
endif()
endif()
endif()
endforeach()
endmacro()
Expand Down
25 changes: 22 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,9 @@ python3 debug-dit-cossim.py # DiT: per-layer cossim GGML vs Python (turbo/

## Patched GGML fork

Uses a patched GGML fork (submodule) with two new ops and a CUDA bugfix for the Oobleck
VAE decoder. All backends: CPU, CUDA, Metal, Vulkan. F32/F16/BF16 data types.
The DiT uses only standard GGML ops and needs no patches.
Uses a patched GGML fork (submodule) with two new ops, a Metal im2col optimization, and
a CUDA bugfix for the Oobleck VAE decoder. All backends: CPU, CUDA, Metal, Vulkan.
F32/F16/BF16 data types. The DiT uses only standard GGML ops and needs no patches.

The VAE reconstructs audio from latent space through 5 upsampling blocks (total 1920x),
each running a transposed convolution followed by 3 WaveNet-style residual units with
Expand Down Expand Up @@ -347,6 +347,25 @@ transposed convolutions. We decompose each as `mul_mat + col2im_1d`, routing the
GEMM through cuBLAS/BLAS/MPS tensor cores. The col2im_1d gather has a 2-iteration inner
loop and is pure bandwidth. BF16 cast nodes around col2im_1d halve the scatter bandwidth.

### Metal: `kernel_im2col_1d` (flat 1D dispatch)

The generic Metal `kernel_im2col` dispatches (IC, 1, OW) threadgroups with K threads
each. For the VAE's 1D convolutions with small kernels (k=1 or k=7), this wastes 78-97%
of SIMD lanes (7 or 1 active threads per 32-wide SIMD group). The dedicated
`kernel_im2col_1d` uses a flat dispatch identical to snake and col2im_1d:
(total/256, 1, 1) threadgroups with 256 threads, achieving full SIMD utilization.
The dispatch branches on `is_2D` at runtime; the 2D path and kernel are unchanged.
CUDA and Vulkan already use flat dispatch and are not affected.

VAE decode (M2 Pro 16GB, 86.8s audio @ 48kHz stereo):

| chunk | overlap | im2col | tiles | time |
|------:|--------:|-----------|------:|-------:|
| 256 | 64 | generic | 17 | 71.2s |
| 1024 | 16 | generic | 3 | 38.9s |
| 256 | 64 | im2col_1d | 17 | 31.8s |
| 1024 | 16 | im2col_1d | 3 | 18.3s |

### Bugfix: `im2col` gridDim.y overflow (CUDA)

Upstream `im2col_kernel` uses OW directly as grid dimension Y, which exceeds the CUDA
Expand Down
17 changes: 17 additions & 0 deletions src/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@

#include "ggml-backend.h"
#include "ggml-cpu.h"
#ifdef ACESTEP_HAVE_CUDA
// Query compute capability without pulling in cuda_runtime.h.
// cudaDeviceGetAttribute takes an int enum value; we pass the raw constants.
extern "C" int cudaDeviceGetAttribute(int *, int, int);
#endif
#include <cstdio>
#include <cstring>
#include <thread>

// Pair of ggml backends selected by backend_init(): the best available
// compute backend plus a CPU backend (presumably kept as a scheduler
// fallback -- confirm against backend_sched_new, not visible here).
struct BackendPair {
ggml_backend_t backend; // best backend found; may itself be the CPU backend when no GPU is present (see best_is_cpu in backend_init)
ggml_backend_t cpu_backend; // CPU backend, always initialized alongside `backend`
int gpu_cc; // CUDA compute capability (e.g. 720 for sm_72), 0 if not CUDA
};

// Initialize backends: load all available (CUDA, Metal, Vulkan...),
Expand All @@ -37,6 +43,17 @@ static BackendPair backend_init(const char * label) {
}
fprintf(stderr, "[Load] %s backend: %s (CPU threads: %d)\n",
label, ggml_backend_name(bp.backend), n_threads);

bp.gpu_cc = 0;
#ifdef ACESTEP_HAVE_CUDA
if (!best_is_cpu) {
int major = 0, minor = 0;
cudaDeviceGetAttribute(&major, 75, 0); // cudaDevAttrComputeCapabilityMajor
cudaDeviceGetAttribute(&minor, 76, 0); // cudaDevAttrComputeCapabilityMinor
bp.gpu_cc = major * 100 + minor * 10;
}
#endif

return bp;
}

Expand Down
2 changes: 1 addition & 1 deletion src/bpe.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ static std::vector<std::string> gpt2_pre_tokenize(const std::string &text) {
// case-insensitive compare
for (int k = 0; k < slen; k++) {
char c1 = rest[k], c2 = suffix[k];
if (c1 >= 'A' && c1 <= 'Z') c1 += 32;
if (c1 >= 'A' && c1 <= 'Z') c1 = (char)(c1 + 32);
if (c1 != c2) return false;
}
// next char should NOT be a letter
Expand Down
13 changes: 13 additions & 0 deletions src/cond-enc.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ struct CondGGML {
ggml_backend_t cpu_backend;
ggml_backend_sched_t sched;
bool use_flash_attn;
bool clamp_fp16; // clamp encoder output on sub-Ampere CUDA (FP16 accumulation overflow)
WeightCtx wctx;
};

Expand All @@ -80,6 +81,12 @@ static void cond_ggml_init_backend(CondGGML * m) {
m->cpu_backend = bp.cpu_backend;
m->sched = backend_sched_new(bp, 8192);
m->use_flash_attn = true;
// Sub-Ampere tensor cores accumulate in FP16 (max 65504).
// Deep encoders can overflow to inf, causing NaN in rms_norm.
m->clamp_fp16 = (bp.gpu_cc > 0 && bp.gpu_cc < 800);
if (m->clamp_fp16) {
fprintf(stderr, "[CondEncoder] FP16 clamp enabled (cc=%d)\n", bp.gpu_cc);
}
}

// Load from ACEStep DiT GGUF
Expand Down Expand Up @@ -196,6 +203,9 @@ static void cond_ggml_forward(CondGGML * m,
lyric_h, lyric_pos, layer_mask, S_lyric,
m->use_flash_attn);
}
if (m->clamp_fp16) {
lyric_h = ggml_clamp(ctx, lyric_h, -65504.0f, 65504.0f);
}
lyric_h = qwen3_rms_norm(ctx, lyric_h, m->lyric_norm, m->lyric_cfg.rms_norm_eps);

ggml_set_name(lyric_h, "lyric_out");
Expand Down Expand Up @@ -242,6 +252,9 @@ static void cond_ggml_forward(CondGGML * m,
timbre_h, timbre_pos, layer_mask, S_ref,
m->use_flash_attn);
}
if (m->clamp_fp16) {
timbre_h = ggml_clamp(ctx, timbre_h, -65504.0f, 65504.0f);
}
timbre_h = qwen3_rms_norm(ctx, timbre_h, m->timbre_norm, m->timbre_cfg.rms_norm_eps);

// Take first frame: [2048, S_ref] -> view [2048, 1]
Expand Down
11 changes: 11 additions & 0 deletions src/qwen3-lm.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct Qwen3LM {
ggml_backend_t cpu_backend;
ggml_backend_sched_t sched;
bool use_flash_attn;
bool clamp_fp16; // clamp hidden state on sub-Ampere CUDA (FP16 accumulation overflow)

// KV cache: per-set, per-layer [D, max_seq, Nkv] f16
struct ggml_context * kv_ctx;
Expand Down Expand Up @@ -145,6 +146,10 @@ static void qw3lm_init_backend(Qwen3LM * m) {
m->cpu_backend = bp.cpu_backend;
m->sched = backend_sched_new(bp, 8192);
m->use_flash_attn = true;
m->clamp_fp16 = (bp.gpu_cc > 0 && bp.gpu_cc < 800);
if (m->clamp_fp16) {
fprintf(stderr, "[LM] FP16 clamp enabled (cc=%d)\n", bp.gpu_cc);
}
}

// Allocate KV cache
Expand Down Expand Up @@ -414,6 +419,9 @@ static void qw3lm_forward(Qwen3LM * m, const int * token_ids, int n_tokens,
norm = qwen3_rms_norm(ctx, hidden, ly->post_attn_layernorm, c.rms_norm_eps);
struct ggml_tensor * mlp = qwen3_build_mlp(ctx, ly, norm, n_tokens);
hidden = ggml_add(ctx, hidden, mlp);
if (m->clamp_fp16) {
hidden = ggml_clamp(ctx, hidden, -65504.0f, 65504.0f);
}
}

// Final norm
Expand Down Expand Up @@ -630,6 +638,9 @@ static void qw3lm_forward_batch(Qwen3LM * m, const int * token_ids,
norm = qwen3_rms_norm(ctx, hidden, ly->post_attn_layernorm, c.rms_norm_eps);
struct ggml_tensor * mlp = qwen3_build_mlp(ctx, ly, norm, N);
hidden = ggml_add(ctx, hidden, mlp);
if (m->clamp_fp16) {
hidden = ggml_clamp(ctx, hidden, -65504.0f, 65504.0f);
}
}

// Final norm + LM head
Expand Down
40 changes: 11 additions & 29 deletions src/vae.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,27 +248,19 @@ static void vae_ggml_load(VAEGGML * m, const char * path) {
vae_load_snake_inv(m->sb, gf, "decoder.snake1.beta");
vae_fuse_wn(m->c2w, gf, "decoder.conv2");

fprintf(stderr, "[VAE] Loaded: 5 blocks, upsample=1920x, BF16 activations\n");
fprintf(stderr, "[VAE] Loaded: 5 blocks, upsample=1920x, F32 activations\n");
gf_close(&gf);
}

// Graph building
// Snake activation (fused): y = x + sin^2(a * x) * inv_b
// x: [T, C], exp_a: [1, C], inv_b: [1, C] (pre-computed at load)
// Casts to BF16 before snake, back to F32 after.
static struct ggml_tensor * vae_snake(
struct ggml_context * ctx,
struct ggml_tensor * x,
struct ggml_tensor * exp_a,
struct ggml_tensor * inv_b) {
if (x->type == GGML_TYPE_F32) {
x = ggml_cast(ctx, x, GGML_TYPE_BF16);
}
x = ggml_snake(ctx, x, exp_a, inv_b);
if (x->type != GGML_TYPE_F32) {
x = ggml_cast(ctx, x, GGML_TYPE_F32);
}
return x;
return ggml_snake(ctx, x, exp_a, inv_b);
}

// Conv1d + bias: data [T, IC] -> [T_out, OC]
Expand Down Expand Up @@ -306,21 +298,10 @@ static struct ggml_tensor * vae_conv_t1d(
// w: [IC, K*OC] xt: [IC, T_in] -> col: [K*OC, T_in]
struct ggml_tensor * col = ggml_mul_mat(ctx, w, xt);

// Step 3: cast to BF16 before col2im_1d
if (col->type == GGML_TYPE_F32) {
col = ggml_cast(ctx, col, GGML_TYPE_BF16);
}

// Step 4: col2im - scatter-add columns to signal, fused padding crop
// [K*OC, T_in] -> [T_out, OC] where T_out = (T_in-1)*stride + K - 2*padding
// Step 3: col2im_1d scatter-add (F32 path, no BF16 casts)
struct ggml_tensor * y = ggml_col2im_1d(ctx, col, stride, oc, padding);

// Step 5: cast back to F32
if (y->type != GGML_TYPE_F32) {
y = ggml_cast(ctx, y, GGML_TYPE_F32);
}

// Step 6: Add bias
// Step 4: Add bias
if (b) {
struct ggml_tensor * b2d = ggml_reshape_2d(ctx, b, 1, b->ne[0]);
y = ggml_add(ctx, y, b2d);
Expand Down Expand Up @@ -389,6 +370,7 @@ static int vae_ggml_compute(
ggml_free(m->graph_ctx);
free(m->graph_buf);
}

// Graph context (generous fixed allocation)
size_t ctx_size = ggml_tensor_overhead() * 1024 + ggml_graph_overhead_custom(8192, false);
m->graph_buf = (uint8_t *)malloc(ctx_size);
Expand All @@ -407,14 +389,15 @@ static int vae_ggml_compute(
ggml_build_forward_expand(m->graph, m->graph_output);

if (!ggml_backend_sched_alloc_graph(m->sched, m->graph)) {
fprintf(stderr, "[VAE] FATAL: graph alloc failed\n");
fprintf(stderr, "[VAE] FATAL: graph alloc failed for T=%d\n", T_latent);
ggml_free(ctx);
free(m->graph_buf);
m->graph_ctx = NULL;
m->graph_buf = NULL;
m->graph_T = 0;
return -1;
}

m->graph_ctx = ctx;
m->graph_T = T_latent;
fprintf(stderr, "[VAE] Graph: %d nodes, T_latent=%d\n",
Expand Down Expand Up @@ -462,11 +445,10 @@ static int vae_ggml_decode(
}

// Tiled decode: overlap-discard chunking for bounded VRAM usage.
// Matches Python handler.py tiled_decode / _tiled_decode_gpu:
// stride = chunk_size - 2*overlap
// For each tile: decode latent window with overlap context, trim to core, concatenate.
// Default chunk=256, overlap=64 matches Python handler.py fallback defaults.
// Python auto-tunes chunk by VRAM: >=24GB->512, >=16GB->384, >=12GB->256, <12GB->128.
// stride = chunk_size - 2*overlap
// For each tile: decode latent window with overlap context, trim to core, concatenate.
// Default chunk=256/overlap=64 matches reference code. Larger chunks (e.g. 1024)
// reduce tile count and improve throughput; use --vae-chunk/--vae-overlap to tune.
// Returns T_audio (total samples per channel) or -1 on error.
static int vae_ggml_decode_tiled(
VAEGGML * m,
Expand Down