From 373c3ccfc84cc62270a9c704103d3307e5ccdfe2 Mon Sep 17 00:00:00 2001 From: Pascal Date: Mon, 2 Mar 2026 08:47:20 +0100 Subject: [PATCH 1/3] vae: drop BF16 activation casts, decode in F32 throughout Snake and col2im_1d kernels compute in F32 internally, so the BF16 casts were round-trip bandwidth waste: 3 dispatches and 16 bytes/elem vs 1 dispatch and 8 bytes/elem. Removes 82 graph nodes per tile (417 -> 335). Weights stay BF16 in GGUF, mul_mat dequantizes on-the-fly. M2 Pro 16GB, 86.8s audio, Q8_0, chunk=1024 overlap=16: 38.89s -> 26.82s (-31%) --- src/vae.h | 40 +++++++++++----------------------------- 1 file changed, 11 insertions(+), 29 deletions(-) diff --git a/src/vae.h b/src/vae.h index d00d416..943376a 100644 --- a/src/vae.h +++ b/src/vae.h @@ -248,27 +248,19 @@ static void vae_ggml_load(VAEGGML * m, const char * path) { vae_load_snake_inv(m->sb, gf, "decoder.snake1.beta"); vae_fuse_wn(m->c2w, gf, "decoder.conv2"); - fprintf(stderr, "[VAE] Loaded: 5 blocks, upsample=1920x, BF16 activations\n"); + fprintf(stderr, "[VAE] Loaded: 5 blocks, upsample=1920x, F32 activations\n"); gf_close(&gf); } // Graph building // Snake activation (fused): y = x + sin^2(a * x) * inv_b // x: [T, C], exp_a: [1, C], inv_b: [1, C] (pre-computed at load) -// Casts to BF16 before snake, back to F32 after. static struct ggml_tensor * vae_snake( struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * exp_a, struct ggml_tensor * inv_b) { - if (x->type == GGML_TYPE_F32) { - x = ggml_cast(ctx, x, GGML_TYPE_BF16); - } - x = ggml_snake(ctx, x, exp_a, inv_b); - if (x->type != GGML_TYPE_F32) { - x = ggml_cast(ctx, x, GGML_TYPE_F32); - } - return x; + return ggml_snake(ctx, x, exp_a, inv_b); } // Conv1d + bias: data [T, IC] -> [T_out, OC] @@ -306,21 +298,10 @@ static struct ggml_tensor * vae_conv_t1d( // w: [IC, K*OC] xt: [IC, T_in] -> col: [K*OC, T_in] struct ggml_tensor * col = ggml_mul_mat(ctx, w, xt); - // Step 3: cast to BF16 before col2im_1d - if (col->type == GGML_TYPE_F32) { - col = ggml_cast(ctx, col, GGML_TYPE_BF16); - } - - // Step 4: col2im - scatter-add columns to signal, fused padding crop - // [K*OC, T_in] -> [T_out, OC] where T_out = (T_in-1)*stride + K - 2*padding + // Step 3: col2im_1d scatter-add (F32 path, no BF16 casts) struct ggml_tensor * y = ggml_col2im_1d(ctx, col, stride, oc, padding); - // Step 5: cast back to F32 - if (y->type != GGML_TYPE_F32) { - y = ggml_cast(ctx, y, GGML_TYPE_F32); - } - - // Step 6: Add bias + // Step 4: Add bias if (b) { struct ggml_tensor * b2d = ggml_reshape_2d(ctx, b, 1, b->ne[0]); y = ggml_add(ctx, y, b2d); @@ -389,6 +370,7 @@ static int vae_ggml_compute( ggml_free(m->graph_ctx); free(m->graph_buf); } + // Graph context (generous fixed allocation) size_t ctx_size = ggml_tensor_overhead() * 1024 + ggml_graph_overhead_custom(8192, false); m->graph_buf = (uint8_t *)malloc(ctx_size); @@ -407,7 +389,7 @@ static int vae_ggml_compute( ggml_build_forward_expand(m->graph, m->graph_output); if (!ggml_backend_sched_alloc_graph(m->sched, m->graph)) { - fprintf(stderr, "[VAE] FATAL: graph alloc failed\n"); + fprintf(stderr, "[VAE] FATAL: graph alloc failed for T=%d\n", T_latent); ggml_free(ctx); free(m->graph_buf); m->graph_ctx = NULL; @@ -415,6 +397,7 @@ static int vae_ggml_compute( m->graph_T = 0; return -1; } + m->graph_ctx = ctx; m->graph_T = T_latent; fprintf(stderr, "[VAE] Graph: %d nodes, T_latent=%d\n", @@ -462,11 +445,10 @@ static int vae_ggml_decode( } // Tiled decode: overlap-discard chunking for bounded VRAM usage. -// Matches Python handler.py tiled_decode / _tiled_decode_gpu: -// stride = chunk_size - 2*overlap -// For each tile: decode latent window with overlap context, trim to core, concatenate. -// Default chunk=256, overlap=64 matches Python handler.py fallback defaults. -// Python auto-tunes chunk by VRAM: >=24GB->512, >=16GB->384, >=12GB->256, <12GB->128. +// stride = chunk_size - 2*overlap +// For each tile: decode latent window with overlap context, trim to core, concatenate. +// Default chunk=256/overlap=64 matches reference code. Larger chunks (e.g. 1024) +// reduce tile count and improve throughput; use --vae-chunk/--vae-overlap to tune. // Returns T_audio (total samples per channel) or -1 on error. static int vae_ggml_decode_tiled( VAEGGML * m, From 9480b60c438ef8f4c3fdb44300c47c70cb38806a Mon Sep 17 00:00:00 2001 From: Pascal Date: Mon, 2 Mar 2026 11:00:31 +0100 Subject: [PATCH 2/3] metal: add kernel_im2col_1d for flat 1D dispatch The generic kernel_im2col dispatches (IC, 1, OW) threadgroups with K threads each. For 1D convolutions with small kernels (k=1 or k=7), this wastes 78-97% of SIMD lanes (7 or 1 active threads per 32-wide SIMD group). Add a dedicated kernel_im2col_1d with flat dispatch identical to snake and col2im_1d: (total/256, 1, 1) threadgroups with 256 threads. The existing im2col dispatch branches on is_2D at runtime; the 2D path and kernel are unchanged. VAE decode benchmark (M2 Pro 16GB, 86.8s audio @ 48kHz stereo): chunk=256 overlap=64 old im2col: 71.2s 17 tiles chunk=1024 overlap=16 old im2col: 38.9s 3 tiles chunk=256 overlap=64 im2col_1d: 31.8s 17 tiles chunk=1024 overlap=16 im2col_1d: 18.3s 3 tiles --- README.md | 25 ++++++++++++++++++++++--- ggml | 2 +- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index a9beca0..25449d0 100644 --- a/README.md +++ b/README.md @@ -317,9 +317,9 @@ python3 debug-dit-cossim.py # DiT: per-layer cossim GGML vs Python (turbo/ ## Patched GGML fork -Uses a patched GGML fork (submodule) with two new ops and a CUDA bugfix for the Oobleck -VAE decoder. All backends: CPU, CUDA, Metal, Vulkan. F32/F16/BF16 data types. -The DiT uses only standard GGML ops and needs no patches. +Uses a patched GGML fork (submodule) with two new ops, a Metal im2col optimization, and +a CUDA bugfix for the Oobleck VAE decoder. All backends: CPU, CUDA, Metal, Vulkan. +F32/F16/BF16 data types. The DiT uses only standard GGML ops and needs no patches. The VAE reconstructs audio from latent space through 5 upsampling blocks (total 1920x), each running a transposed convolution followed by 3 WaveNet-style residual units with @@ -347,6 +347,25 @@ transposed convolutions. We decompose each as `mul_mat + col2im_1d`, routing the GEMM through cuBLAS/BLAS/MPS tensor cores. The col2im_1d gather has a 2-iteration inner loop and is pure bandwidth. BF16 cast nodes around col2im_1d halve the scatter bandwidth. +### Metal: `kernel_im2col_1d` (flat 1D dispatch) + +The generic Metal `kernel_im2col` dispatches (IC, 1, OW) threadgroups with K threads +each. For the VAE's 1D convolutions with small kernels (k=1 or k=7), this wastes 78-97% +of SIMD lanes (7 or 1 active threads per 32-wide SIMD group). The dedicated +`kernel_im2col_1d` uses a flat dispatch identical to snake and col2im_1d: +(total/256, 1, 1) threadgroups with 256 threads, achieving full SIMD utilization. +The dispatch branches on `is_2D` at runtime; the 2D path and kernel are unchanged. +CUDA and Vulkan already use flat dispatch and are not affected. + +VAE decode (M2 Pro 16GB, 86.8s audio @ 48kHz stereo): + +| chunk | overlap | im2col | tiles | time | +|------:|--------:|-----------|------:|-------:| +| 256 | 64 | generic | 17 | 71.2s | +| 1024 | 16 | generic | 3 | 38.9s | +| 256 | 64 | im2col_1d | 17 | 31.8s | +| 1024 | 16 | im2col_1d | 3 | 18.3s | + ### Bugfix: `im2col` gridDim.y overflow (CUDA) Upstream `im2col_kernel` uses OW directly as grid dimension Y, which exceeds the CUDA diff --git a/ggml b/ggml index 55e062a..4895202 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 55e062ab597eccaa3e7ee7c7b230197d83d94bc8 +Subproject commit 48952027390b2134ee03d02bc6eb71e8a1b05ed2 From fa33a09c7513db644201ed56aa4b377fcf1389e3 Mon Sep 17 00:00:00 2001 From: Pascal Date: Mon, 2 Mar 2026 12:19:42 +0100 Subject: [PATCH 3/3] fix: FP16 accumulation overflow on sub-Ampere CUDA (Volta/Turing) Sub-Ampere GPUs (cc < 800) use FP16 tensor core accumulation in GGML's mul_mat (max 65504). Deep transformer layers can overflow to inf, then rms_norm computes inf/inf = NaN, silently corrupting the pipeline: LM produces 0 audio codes, condition encoder feeds NaN to DiT, silent WAV. The fix detects GPU compute capability at init and conditionally clamps hidden states to [-65504, 65504] before rms_norm on affected hardware. On Ampere+ (FP32 accumulation), no clamp op is added, zero overhead. Tested on Jetson Xavier NX (sm_72) with Q8_0 models. src/backend.h Add gpu_cc to BackendPair. Query cc via forward- declared cudaDeviceGetAttribute (no cuda_runtime.h). src/cond-enc.h Clamp lyric/timbre encoder output before rms_norm when cc < 800. Prevents NaN in DiT cross-attention. src/qwen3-lm.h Clamp hidden state after each MLP residual in prefill and decode loops (36 layers). CMakeLists.txt Link CUDA::cudart for cudaDeviceGetAttribute. src/bpe.h Fix -Wconversion warning (int to char cast). Fixing this in GGML's mul_mat or rms_norm would require touching core operators across all architectures for a niche hardware edge case. Closes #4 --- CMakeLists.txt | 6 ++++++ src/backend.h | 17 +++++++++++++++++ src/bpe.h | 2 +- src/cond-enc.h | 13 +++++++++++++ src/qwen3-lm.h | 11 +++++++++++ 5 files changed, 48 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b37c978..afa9cd0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -41,6 +41,12 @@ macro(link_ggml_backends target) target_link_libraries(${target} PRIVATE ggml-${backend}) string(TOUPPER ${backend} BACKEND_UPPER) target_compile_definitions(${target} PRIVATE ACESTEP_HAVE_${BACKEND_UPPER}) + if(backend STREQUAL "cuda") + find_package(CUDAToolkit QUIET) + if(CUDAToolkit_FOUND) + target_link_libraries(${target} PRIVATE CUDA::cudart) + endif() + endif() endif() endforeach() endmacro() diff --git a/src/backend.h b/src/backend.h index 39b7978..0257363 100644 --- a/src/backend.h +++ b/src/backend.h @@ -7,6 +7,11 @@ #include "ggml-backend.h" #include "ggml-cpu.h" +#ifdef ACESTEP_HAVE_CUDA +// Query compute capability without pulling in cuda_runtime.h. +// cudaDeviceGetAttribute takes an int enum value; we pass the raw constants. +extern "C" int cudaDeviceGetAttribute(int *, int, int); +#endif #include #include #include @@ -14,6 +19,7 @@ struct BackendPair { ggml_backend_t backend; ggml_backend_t cpu_backend; + int gpu_cc; // CUDA compute capability (e.g. 720 for sm_72), 0 if not CUDA }; // Initialize backends: load all available (CUDA, Metal, Vulkan...), @@ -37,6 +43,17 @@ static BackendPair backend_init(const char * label) { } fprintf(stderr, "[Load] %s backend: %s (CPU threads: %d)\n", label, ggml_backend_name(bp.backend), n_threads); + + bp.gpu_cc = 0; +#ifdef ACESTEP_HAVE_CUDA + if (!best_is_cpu) { + int major = 0, minor = 0; + cudaDeviceGetAttribute(&major, 75, 0); // cudaDevAttrComputeCapabilityMajor + cudaDeviceGetAttribute(&minor, 76, 0); // cudaDevAttrComputeCapabilityMinor + bp.gpu_cc = major * 100 + minor * 10; + } +#endif + return bp; } diff --git a/src/bpe.h b/src/bpe.h index fe5ebc6..316e4d1 100644 --- a/src/bpe.h +++ b/src/bpe.h @@ -101,7 +101,7 @@ static std::vector gpt2_pre_tokenize(const std::string &text) { // case-insensitive compare for (int k = 0; k < slen; k++) { char c1 = rest[k], c2 = suffix[k]; - if (c1 >= 'A' && c1 <= 'Z') c1 += 32; + if (c1 >= 'A' && c1 <= 'Z') c1 = (char)(c1 + 32); if (c1 != c2) return false; } // next char should NOT be a letter diff --git a/src/cond-enc.h b/src/cond-enc.h index 880cbf7..eabb18c 100644 --- a/src/cond-enc.h +++ b/src/cond-enc.h @@ -70,6 +70,7 @@ struct CondGGML { ggml_backend_t cpu_backend; ggml_backend_sched_t sched; bool use_flash_attn; + bool clamp_fp16; // clamp encoder output on sub-Ampere CUDA (FP16 accumulation overflow) WeightCtx wctx; }; @@ -80,6 +81,12 @@ static void cond_ggml_init_backend(CondGGML * m) { m->cpu_backend = bp.cpu_backend; m->sched = backend_sched_new(bp, 8192); m->use_flash_attn = true; + // Sub-Ampere tensor cores accumulate in FP16 (max 65504). + // Deep encoders can overflow to inf, causing NaN in rms_norm. + m->clamp_fp16 = (bp.gpu_cc > 0 && bp.gpu_cc < 800); + if (m->clamp_fp16) { + fprintf(stderr, "[CondEncoder] FP16 clamp enabled (cc=%d)\n", bp.gpu_cc); + } } // Load from ACEStep DiT GGUF @@ -196,6 +203,9 @@ static void cond_ggml_forward(CondGGML * m, lyric_h, lyric_pos, layer_mask, S_lyric, m->use_flash_attn); } + if (m->clamp_fp16) { + lyric_h = ggml_clamp(ctx, lyric_h, -65504.0f, 65504.0f); + } lyric_h = qwen3_rms_norm(ctx, lyric_h, m->lyric_norm, m->lyric_cfg.rms_norm_eps); ggml_set_name(lyric_h, "lyric_out"); @@ -242,6 +252,9 @@ static void cond_ggml_forward(CondGGML * m, timbre_h, timbre_pos, layer_mask, S_ref, m->use_flash_attn); } + if (m->clamp_fp16) { + timbre_h = ggml_clamp(ctx, timbre_h, -65504.0f, 65504.0f); + } timbre_h = qwen3_rms_norm(ctx, timbre_h, m->timbre_norm, m->timbre_cfg.rms_norm_eps); // Take first frame: [2048, S_ref] -> view [2048, 1] diff --git a/src/qwen3-lm.h b/src/qwen3-lm.h index 29b254f..ce4f746 100644 --- a/src/qwen3-lm.h +++ b/src/qwen3-lm.h @@ -47,6 +47,7 @@ struct Qwen3LM { ggml_backend_t cpu_backend; ggml_backend_sched_t sched; bool use_flash_attn; + bool clamp_fp16; // clamp hidden state on sub-Ampere CUDA (FP16 accumulation overflow) // KV cache: per-set, per-layer [D, max_seq, Nkv] f16 struct ggml_context * kv_ctx; @@ -145,6 +146,10 @@ static void qw3lm_init_backend(Qwen3LM * m) { m->cpu_backend = bp.cpu_backend; m->sched = backend_sched_new(bp, 8192); m->use_flash_attn = true; + m->clamp_fp16 = (bp.gpu_cc > 0 && bp.gpu_cc < 800); + if (m->clamp_fp16) { + fprintf(stderr, "[LM] FP16 clamp enabled (cc=%d)\n", bp.gpu_cc); + } } // Allocate KV cache @@ -414,6 +419,9 @@ static void qw3lm_forward(Qwen3LM * m, const int * token_ids, int n_tokens, norm = qwen3_rms_norm(ctx, hidden, ly->post_attn_layernorm, c.rms_norm_eps); struct ggml_tensor * mlp = qwen3_build_mlp(ctx, ly, norm, n_tokens); hidden = ggml_add(ctx, hidden, mlp); + if (m->clamp_fp16) { + hidden = ggml_clamp(ctx, hidden, -65504.0f, 65504.0f); + } } // Final norm @@ -630,6 +638,9 @@ static void qw3lm_forward_batch(Qwen3LM * m, const int * token_ids, norm = qwen3_rms_norm(ctx, hidden, ly->post_attn_layernorm, c.rms_norm_eps); struct ggml_tensor * mlp = qwen3_build_mlp(ctx, ly, norm, N); hidden = ggml_add(ctx, hidden, mlp); + if (m->clamp_fp16) { + hidden = ggml_clamp(ctx, hidden, -65504.0f, 65504.0f); + } } // Final norm + LM head