vae: drop BF16 activation casts, decode in F32 throughout #7

Merged

lmangani merged 3 commits into audiohacking:master from ServeurpersoCom:master
Mar 2, 2026
Conversation


@lmangani lmangani commented Mar 2, 2026

Snake and col2im_1d kernels compute in F32 internally, so the BF16 casts were round-trip bandwidth waste: 3 dispatches and 16 bytes/elem vs 1 dispatch and 8 bytes/elem.

Removes 82 graph nodes per tile (417 -> 335).
Weights stay BF16 in GGUF, mul_mat dequantizes on-the-fly.

M2 Pro 16GB, 86.8s audio, Q8_0, chunk=1024 overlap=16:
38.89s -> 26.82s (-31%)

The generic kernel_im2col dispatches (IC, 1, OW) threadgroups with K
threads each. For 1D convolutions with small kernels (k=1 or k=7),
this wastes 78-97% of SIMD lanes (7 or 1 active threads per 32-wide
SIMD group).

Add a dedicated kernel_im2col_1d with flat dispatch identical to
snake and col2im_1d: (total/256, 1, 1) threadgroups with 256 threads.
The existing im2col dispatch branches on is_2D at runtime; the 2D
path and kernel are unchanged.

VAE decode benchmark (M2 Pro 16GB, 86.8s audio @ 48kHz stereo):

  chunk=256  overlap=64  old im2col: 71.2s  17 tiles
  chunk=1024 overlap=16  old im2col: 38.9s   3 tiles
  chunk=256  overlap=64  im2col_1d:  31.8s  17 tiles
  chunk=1024 overlap=16  im2col_1d:  18.3s   3 tiles
@lmangani lmangani marked this pull request as ready for review March 2, 2026 10:07
Sub-Ampere GPUs (cc < 800) use FP16 tensor core accumulation in GGML's
mul_mat (max 65504). Deep transformer layers can overflow to inf, after
which rms_norm computes inf/inf = NaN, silently corrupting the pipeline:
the LM produces zero audio codes, the condition encoder feeds NaN to the
DiT, and the output is a silent WAV.

The fix detects GPU compute capability at init and conditionally clamps
hidden states to [-65504, 65504] before rms_norm on affected hardware.
On Ampere+ (FP32 accumulation), no clamp op is added, zero overhead.

Tested on Jetson Xavier NX (sm_72) with Q8_0 models.

  src/backend.h     Add gpu_cc to BackendPair. Query cc via forward-
                    declared cudaDeviceGetAttribute (no cuda_runtime.h).
  src/cond-enc.h    Clamp lyric/timbre encoder output before rms_norm
                    when cc < 800. Prevents NaN in DiT cross-attention.
  src/qwen3-lm.h    Clamp hidden state after each MLP residual in
                    prefill and decode loops (36 layers).
  CMakeLists.txt    Link CUDA::cudart for cudaDeviceGetAttribute.
  src/bpe.h         Fix -Wconversion warning (int to char cast).

Fixing this in GGML's mul_mat or rms_norm would require touching core
operators across all architectures for a niche hardware edge case.

Closes #4
@lmangani lmangani merged commit 85b1e29 into audiohacking:master Mar 2, 2026
2 checks passed