[MISC] Speed up rigid constraint solver by fusing tile between Cholesky factor and solve.#2837
Conversation
Replaces 16x16 tile Cholesky in func_cholesky_factor_direct_tiled and func_cholesky_and_solve_fused_tiled with a 32x32 register-tile version (genesis/utils/_tile32.py, ported from _tile16.py by mechanical 16 -> 32 expansion). Kernels now run at block_dim=32 (full warp) instead of 16 (sub-warp), eliminating the half-warp idle penalty that the cholesky_mjw_vs_gs_2026may21 doc identified as a major driver of the 2.52x compute-only gap vs MJWarp. For dex_hand (n_dofs=62, tiled_n_dofs=64) the outer tile-block grid drops from 4x4=16 blocks at block_dim=16 to 2x2=4 blocks at block_dim=32 — same total FFMA work, twice the warp-execution throughput, fewer inter-tile sync boundaries. Other 32-multiple tile sizes (g1_fall at 32, franka at 32, etc.) collapse to 1x1=1 block. Diverges from T18-T20 pack-2 (which got 1.38x on factor kernel but lost -8.45 % overall on dex_hand): one env per warp here (no pack-2 inter-tile sync penalty, no skip-unchanged conflict). The standalone cholesky_solve_tiled kernel (block_dim=64, 2 warps per block) is unchanged. The non-tiled batch path (func_cholesky_factor_direct_batch) is unchanged. Includes tests/test_tile32_cholesky.py for direct numerical-equivalence validation against numpy on a 32x32 SPD matrix.
…oops Two changes on top of F3-S1: 1. _tile32.cholesky_: split the column-update dot reduction from 2 chains (dot0/dot1, ~16 FMAs deep at k=31) to 4 chains (dot0/dot1/dot2/dot3, ~8 FMAs deep at k=31). Cuts the back-to-back FMA dependency chain length another 2x for the 32-deep tile; _tile16's 2-way split was sufficient at 16-deep but at 32-deep we want more ILP. 2. solver.py func_cholesky_factor_direct_tiled and _fused_tiled: outer tile-block loops wrapped in qd.static(range(N_BLOCKS)) where N_BLOCKS is qd.static(tiled_n_dofs // 32). For dex_hand (tiled_n_dofs=64) this constant-folds the entire kb/ib/jb tile-block structure to N_BLOCKS=2. k0/i0/j0 also qd.static so the compiler can constant-propagate them into load3d/store3d/_resolve_vec3d call sites. S1 alone landed +2.62 % FPS on dex_hand (5-run, ±0.18 %), just shy of the +3 % decision gate; S2 aims to push past it via tighter compile-time folding and longer ILP.
…t-split S2 (8d6ab56) blew up compile time from 107s -> 928s (8.7x) and ran into pytest timeout on 4 of 5 dex_hand bench runs. The remaining run measured -3.20% FPS vs main; this is dominated by the compile-time-induced register spill / suboptimal codegen, not a real regression of the algorithm. Diagnosis: qd.static-wrapping the kb/ib/jb tile-block loops on top of the already-32-way-unrolled inner _ger_sub / _resolve_vec3d / cholesky_/ solve_triangular_ register-cascade ops generated >60k AST nodes for the factor + fused funcs combined. Quadrants AST -> PTX path apparently doesn't scale that far without spilling. S2b keeps the small surgical change that's safe (4-way dot-split inside _tile32.cholesky_, ~8-deep FMA chains vs 16-deep at k=31) and drops the problematic static unroll. solver.py reverts to S1 (5659357) state.
S2b bench landed at +2.13 % +/- 0.37 % vs WandB main 23012 = -0.47 % vs S1 (within combined CI 0.55 % so within noise, but trending negative). 4-way dot-split provides no gain over 2-way at the 32-deep POTRF chain, likely because the GPU scheduler already hides the FMA latency at 2-way and the extra accumulator regs / cross-chain adds eat the gain. Reverting to clean S1 state: tile32 + block_dim=32 + 2-way dot-split (same as _tile16). Branch head is now functionally equivalent to the original F3-S1 commit (5659357) on solver.py; _tile32.py is the clean port. This is the proposed final state for a PR. Headline: +2.62 % FPS on dex_hand vs WandB main 79a0e9b (5-run, +/-0.18 %), from killing the sub-warp execution penalty of the 16x16 register-tile Cholesky.
… of r0..r31 named fields Same algorithm as S1 (F3-final), but the per-thread tile row is now a single vector<32, dtype> field 'r' instead of 32 named scalar fields r0..r31. This eliminates every 32-way 'if k == N: self.rN = val' if-cascade that hand-rolled the runtime register-index dispatch in the named-field variant. With qd.static-folded indices (the common path -- inside cholesky_'s nested qd.static loops, eye_, _load3d, _store3d, _ger_sub, and from _resolve_vec3d call sites in solver.py) the 'self.r[k]' access lowers to the same direct register reference as 'self.rN' did via getattr in the named-field _r helper. With runtime indices (in _get_col / _set_col / _trsm) it lowers to a 32-way switch -- the same lowering the hand-rolled cascade produced. Generated PTX should be byte-identical or near-byte-identical to F3-S1. What changes is the *source* / *AST* node count: the named-field variant emitted 64 lines (32 if-stmts + 32 assigns) per cascade site, walked by the AST transformer even though all but one branch was dead-folded. The vector-storage variant emits a single 'self.r[<k>] = val' line per site. Across 7 cascade sites + cholesky_'s 2 cascades-per-outer-k-iteration, that collapses the tile-method AST from O(2k nodes) to O(few hundred). The _tile32.py file itself shrunk from 1077 to 505 lines. Goal: drop the +20s S1 vs main compile-time cost back toward zero without touching the runtime path.
…unc bodies
The first F4-A bench failed all 5 dex_hand runs with:
UserWarning: [PURE.VIOLATION] WARNING: Accessing global variable _TILE
<class 'int'> _TILE is in global vars, therefore violates pure
Triggered inside _trsm (and likely every other qd.func that referenced _TILE).
The original _tile32.py used a literal 32 inside qd.func bodies for exactly
this reason; module-level _TILE was OK to reference outside qd.func (at the
class factory level for vector(_TILE, dtype) and result.SIZE = _TILE).
Replace the in-func _TILE references with literal 32; keep the out-of-func
ones (which are evaluated at Python class-build time, not inside the AST
transformer).
F4-A vec32 storage compiled 10s faster (-50% of S1's overhead) but cost -19% runtime FPS on dex_hand — the vec32 didn't register-promote on cuda 7.x, fell back to local memory. F4-B uses 4 separate qd.types.vector(8, dtype) sub-banks 'b0..b3', each small enough to reliably register-promote (matching the 12x12 / 144-element per-thread matrices that quadrants's per-thread linalg ops register-promote). All hot indexing in this module is static (qd.static unrolls in cholesky_, _ger_sub, eye_, _load3d, _store3d), so the 4-way sub-bank dispatch (kb = k // 8, ko = k % 8) folds at trace time to direct sub-vector + intra-vector scalar access. The trace-time _static_read helper resolves the static-bank dispatch in pure Python (not @qd.func), producing a single field-access AST node per call site rather than a 4-way cascade. Write sites use an explicit 'if kb == N: self.bN[ko] = val' 4-way cascade — folded the same way. Only _get_col / _set_col / _trsm carry the runtime 4-way cascade, where it collapses to a switch over 4 banks (vs the original 32-way switch). _trsm itself is unchanged in structure since the cascade now lives in _get_col / _set_col. Source size: 601 lines (vs S1 1077, vec32 505). Cascade lines reduced ~10x vs S1.
…thon helper) The previous F4-B failed with 'Quadrants Expression object is not subscriptable' because the _static_read helper was a pure-Python function that tried to do 'self.b0[off]' outside the @qd.func AST context — quadrants Expression objects don't implement Python __getitem__, only the qd AST transformer can emit Subscript nodes against them. Fix: drop the helper, inline the 4-way 'if kb == 0: self.b0[ko] ... elif kb == 3: self.b3[ko]' dispatch directly inside cholesky_'s qd.static-unrolled outer/inner loops. kb = k // 8 and ko = k % 8 are python ints (k is a python int from qd.static), so the if-cascade folds at trace time — Python evaluates the const predicate during qd.static unroll, only the matching branch enters the AST. Same lowering as the original named-field approach, just with vec8 bank storage instead of 32 scalar fields. Also dropped the now-unused self_k_post re-read: after the diag write, only the tid==k lane mutates its register; the tid > k lanes still hold the original loaded col-k value, so reusing the self_k SSA above is correct.
…s not elif) The previous attempt failed at trace time with 'Name self_k is not defined' — quadrants treats variables introduced inside an if/elif chain as locally-scoped to that branch, not propagated to the outer scope, even when every branch assigns to the same name. Matches the pattern used in _tile16.cholesky_ where 'diag_val = qd.cast(0.0, dtype)' is pre-declared before the if-cascade that assigns to it. Fix: pre-declare self_k = qd.cast(0.0, dtype) and my_col = qd.cast(0.0, dtype) before each respective 4-way cascade, and switch the cascade to separate 'if kb == N:' statements (matching the original tile16 pattern) rather than elif/else. At trace time with kb being a python int from qd.static, each 'if kb == N' folds to True/False; only the matching branch emits AST. Same single-branch lowering as the original named-field cascade, just with self.bN[ko] writes instead of self.rN writes.
Both vec-storage attempts regressed runtime by ~18% on dex_hand: S1 (named r0..r31): 23614 FPS (+2.62%), compile 107.1s F4-A (vec32): 18662 FPS (-18.90%), compile 97.4s F4-B (vec8 x 4): 18915 FPS (-17.80%), compile 141.5s The vec32 single-field layout fell back to local memory on cuda 7.x (gpu register file doesn't promote 32-element per-thread vectors). The vec8 x 4 sub-bank layout also regressed — runtime hit suggests vec8 fields inside a qd.dataclass also don't reliably register-promote (vs vec8 locals in quadrants's per-thread linalg ops which do). F4-B was additionally SLOWER to compile than S1 because the 4-way bank-dispatch cascades emit more AST than the 32-way named-field cascades after qd.static folding (each cascade emits 4 if-statements regardless of whether only one branch is live, plus the pre-declared self_k = qd.cast(0.0, dtype) lines, plus the bank-vec subscript expressions are heavier than direct field references). Conclusion: vec storage is not viable for compile-time-only optimization of _tile32 without runtime regression on cuda 7.x. S1 remains the best known state. Compile-time-only F4 effort is a null result; documented in perso_hugh/doc/cholesky_tile32_2026may22.md. Keeping the F4-A/B history in git so the next investigation has the failure data on hand (e.g. when quadrants gets reliable register promotion for vec fields, F4-B's 4-way cascade approach should immediately compile faster than S1).
Drop-in replacement for S1's named-r0..r31 storage using the new
``qd.field_array(N, dtype)`` annotation on @qd.dataclass (quadrants branch
hp/qd-field-array). For python-int / qd.static-resolved indices the AST
transformer rewrites ``self.r[k]`` to ``self._r{k}``, so the generated LLVM
IR / PTX is byte-identical to S1's named-field form (verified on
chol_kernel + chol_trsm_kernel: every byte identical apart from the per-
session-nonce comment in the PTX trailer).
Changes vs S1:
- 32 named ``r0: dtype`` decls -> one ``r: qd.field_array(_TILE, dtype)``
annotation. Synthetic field names ``_r0.._r31`` remain.
- Inline 32-way write cascades in ``cholesky_`` (diagonal + off-diagonal),
``_ger_sub``, ``_load``, ``_store``, ``_load3d``, ``_store3d``, and
``eye_`` -> single ``self.r[k] = val`` lines (proposal-1 static rewrite).
- ``self._r(k)`` python-helper calls -> ``self.r[k]`` (also proposal-1).
- Dropped the ``_REGS`` lookup tuple + ``_r`` python helper method.
- ``_get_col`` / ``_set_col`` kept as cascade @qd.func (they handle the
runtime-k case used by ``_trsm``) but now reference the synthetic
``_rN`` names directly to bypass the field_array path.
Numerical correctness verified (chol_kernel: max|L-L_ref|=9.5e-7,
chol_trsm_kernel: max|X-X_ref|=6e-8; same as S1).
dex_hand benchmark (n_envs=4096, 3 cold-cache repeats; quadrants offline
cache cleared between every run):
S1 (named): compile mean=56.40s (56.04-56.79) runtime_fps=26521
FA (this) : compile mean=48.71s (48.18-49.56) runtime_fps=26583
-7.69s cold compile (-13.6%) with runtime FPS unchanged. Closes ~38% of
the original +20.3s S1-vs-F3 compile gap; in the middle of the doc's
5-10s estimate.
Source: 1068 -> 580 lines (-46%). PTX still byte-identical to S1.
Requires quadrants >= the hp/qd-field-array branch (commit 40e1b1275).
Opt-in via GS_FUSED_FACTOR_SOLVE_INIT=1 (gated on enable_tiled_cholesky_hessian) which routes the per-step Newton warm-start factor+solve through the existing func_cholesky_and_solve_fused_tiled instead of the separate factor + separate solve kernels. Saves the global-memory L roundtrip and one kernel launch. Microbench at dex_hand scale (4096 envs x 64 dofs x f32): the cholesky kernel work drops from 0.146 + 0.117 = 0.263 ms (factor + solve) to 0.256 ms (fused incl. L writeback to nt_H) -- a 2.7% saving on the cholesky work. End-to-end on dex_hand (5 cold-cache repeats each, --warmup 12 --record 3): baseline (separate): 26587 fps (std 64) fused warm-start: 26759 fps (std 58) delta: +171 fps (+0.64%) Body-iter time unchanged (9.62 ms vs baseline 10.04 ms, within noise). Three plumbing points: - RigidSimStaticConfig gets enable_fused_factor_solve_init (default False), set in rigid_solver.py from the GS_FUSED_FACTOR_SOLVE_INIT env var. - func_solve_init's Newton path skips the separate cholesky factor (the fused kernel does both) and dispatches to func_cholesky_and_solve_fused_tiled with write_L_to_nt_H=True. - func_cholesky_and_solve_fused_tiled takes a write_L_to_nt_H template arg (default False). When True it writes L (lower triangle) back into nt_H at the end, preserving the monolith body's incremental rank-1 Cholesky post-condition (nt_H holds L). The decomposed body's invocation keeps the default False so its iter-2+ patching path still finds H in nt_H. Why the writeback is required: the monolith body's func_solve_iter calls func_hessian_and_cholesky_factor_incremental_batch which reads nt_H as L and applies rank-1 updates to it. Without the writeback, nt_H would still hold H after the fused kernel, the incremental update would corrupt the factor, and convergence would take 5x more iters (per-call body time observed at 49 ms vs baseline 10 ms in an intermediate experiment without the writeback). Doc: perso_hugh/doc/dex_hand_tile_runtime_2026may23.md.
The step-3 lowertri writeback used `if i_d2 <= i_d1` to predicate-skip upper-tri stores. On dex_hand the warp-level branch divergence cost more wall-clock than the wasted writes saved: for the first 32 iters of the flat n_dofs*n_dofs walk only one of 32 lanes passes the predicate, so the warp is mostly idle there. Dropping the predicate (the new `fullsquare` variant) lets the warp issue stores at full rate and ride out the kernel tail faster. The extra 50% upper-tri writes are harmless because every nt_H reader in genesis (Hessian build, direct factor, incremental factor, solve, H-patcher) touches only the lower triangle. dex_hand cold-cache A/B (5 repeats, profiler off): baseline mean 26 590 fps std 54 lowertri mean 26 690 fps std 54 +0.38% fullsquare mean 26 798 fps std 113 +0.79% <-- new default Microbench data + 5 variants (fullsquare, lowertri, rowperlane, interleave, store3d, off) all stay available behind GS_FUSED_WB_VARIANT for future A/B runs. See perso_hugh/doc/dex_hand_fused_writeback_opt_2026may23.md for the full investigation.
….07% FPS on dex_hand) efc_force[i_c, i_b] was invariant across the inner i_d loop but the legacy code re-read it n_dofs=60 times per (env, constraint) from global memory. Hoist into a per-lane qd.Vector of size 4 (covers n_con <= 128) before the i_d loop; tail falls back to global re-read when n_con > 128. Cap chosen at 4 (not the full ceil(len_constraints/_K) ~ 25 on dex_hand): an earlier full-cache variant sized by len_constraints/32 cut the targeted kernel ~44% but regressed end-to-end FPS ~0.4% via neighbouring kernels slowing 2-4% from register-occupancy pressure. Cap=4 keeps register footprint at 128 regs/warp (~0.2% of file) so neighbouring kernels stay flat. Measured (8-round interleaved A/B, dex_hand 4096x, GS_ENABLE_NDARRAY=0): A (origin/main) mean = 23,533 FPS B (this branch) mean = 23,787 FPS B - A = +254 FPS (+1.07%) Per-call qfrc kernel (5-step profile): origin/main: 109.9 us this branch: 61.3 us (-44%) See doc/dex_hand_p5_p6_next_targets_2026may23.md and doc/p5_qfrc_register_cache_2026may23.md for the full investigation.
…-AI#2827 baseline Genesis-Embodied-AI#2827 widened the Cholesky kernel to Tile32x32 for n_dofs >= 17, tightening the per-warp register budget. Cap=4 (which won +1.07 % on the pre-Genesis-Embodied-AI#2827 Tile16x16 baseline) now regresses dex_hand by -1.42 % on cluster (RTX PRO 6000) even though the qfrc kernel itself sped up -77 %. Re-tuned on post-Genesis-Embodied-AI#2827 main (cluster A/B, 6 rounds, sd 41-46 FPS/run): cap=4: -1.42 % (~18 SEMs significant negative) cap=2: +0.69 % (~6 SEMs significant positive) <- WINNER cap=1: +0.63 % (~5 SEMs significant positive) Cap=2 still covers dex_hand active n_con ~ 55 fully (capacity 64); larger scenes hit the tail's global re-read path (unchanged from baseline). See perso_hugh/doc/p5_qfrc_register_cache_2026may23.md for full A/B and diff-profile data.
Combines the fused warm-start factor+solve (write_L_to_nt_H + fullsquare writeback variant) from hp/dex-hand-writeback-opt with the cholesky_tile_size dispatch (Tile16/Tile32) already on origin/main and the qfrc shmem cache from hp/p5-qfrc-shmem-cache (merged earlier). Conflict resolution: - genesis/utils/_tile32.py: kept origin/main's version (battle-tested PR Genesis-Embodied-AI#2827 implementation, API-compatible with the fused kernel). - _cholesky_and_solve_fused_tiled_impl: kept origin/main's TileCls- parametrised version, ported in the write_L_to_nt_H template arg and the WB_VARIANT-gated writeback epilogue (fullsquare/lowertri/ rowperlane/store3d). The writeback stride uses T (not hardcoded 32) so the Tile16 dispatch is also covered. - func_cholesky_and_solve_fused_tiled dispatcher gains the write_L_to_nt_H pass-through. - func_hessian_and_cholesky_factor_direct skips the standalone factor when enable_fused_factor_solve_init is on. - func_update_gradient_tiled routes through the fused kernel with write_L_to_nt_H=True under the same gate. - RigidSimStaticConfig keeps both cholesky_tile_size and enable_fused_factor_solve_init fields. - rigid_solver.py forwards both fields to the static config. Sanity smoke (single-run dex_hand, 4096 envs, warmup 12 / record 3): GS_FUSED_FACTOR_SOLVE_INIT=0 -> 28 351 fps GS_FUSED_FACTOR_SOLVE_INIT=1 -> 28 611 fps (+0.92 %) Co-authored-by: Cursor <cursoragent@cursor.com>
Per the rewrap-comments-120c skill: ran find_underwrapped.py --diff against origin/main and reflowed every reported run. Targets: - solver.py: WB_VARIANT module-top docs, fused-impl writeback intent block, fused-warm-start factor-skip comment, qfrc-cache P5 docstring + inline phase comments, warm-start dispatch comment in func_update_gradient_tiled. - rigid_solver.py: warm-start enable comment. - array_class.py: enable_fused_factor_solve_init docstring. - tests/test_tile32_cholesky.py: module docstring + test docstring + lower-tri-comparison inline comment. Also reworded the func_hessian_and_cholesky_factor_direct skip-comment to drop the over-long backtick'd kernel name so the line fits within 120c without an awkward continuation. No semantic changes. Co-authored-by: Cursor <cursoragent@cursor.com>
- Drop the GS_FUSED_WB_VARIANT and GS_FUSED_FACTOR_SOLVE_INIT env vars. fullsquare is now the only writeback shape (was the default). Fused warm-start is enabled whenever the tiled cholesky Hessian path is. - Strip experimental result numbers, internal doc references, and experiment labels from the qfrc-cache docstring; tighten phrasing. - Strip the writeback intent comment in _cholesky_and_solve_fused_tiled _impl and the warm-start dispatch comments in rigid_solver.py / func_update_gradient_tiled / func_hessian_and_cholesky_factor_direct to keep just what's needed to understand the call. - Drop the per-variant writeback branches in the fused impl now that only one variant survives. - Move tests/test_tile32_cholesky.py out of the genesis tree (the file exercises a register-tile primitive directly; it's tracked in Hugh's personal sandbox going forward). ruff check + ruff format pass via pre-commit. dex_hand smoke (single run, 4096 envs) at 28 606 fps -- same as before the cleanup. Co-authored-by: Cursor <cursoragent@cursor.com>
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 0cd5aba9ba
ℹ️ About Codex in GitHub
Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".
The sparse path runs the per-env Hessian build + Cholesky factor together inside func_hessian_and_cholesky_factor_direct_batch, leaving nt_H = L. Routing the warm-start through the fused kernel after that re-factors L as if it were H, producing NaN accelerations. Gate enable_fused_factor_solve_init on `not self._options.sparse_solve` so the sparse path stays on the existing factor-then-solve sequence. Repros locally: tests/test_rigid_physics_sparse.py::test_sparse_solve_no_nan[gpu] fails before this commit, passes after. dex_hand smoke (4096 envs) still routes through the fused path (28 461 fps, unchanged). Co-authored-by: Cursor <cursoragent@cursor.com>
|
@codex review |
|
Codex Review: Didn't find any major issues. Already looking forward to the next diff. ℹ️ About Codex in GitHubCodex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you
If Codex has suggestions, it will comment; otherwise it will react with 👍. When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback". |
|
🙌 |

No description provided.