diff --git a/books/compute/src/cuda/kernels/trimat_forward.md b/books/compute/src/cuda/kernels/trimat_forward.md
index 0fbc7aa..e2c1d80 100644
--- a/books/compute/src/cuda/kernels/trimat_forward.md
+++ b/books/compute/src/cuda/kernels/trimat_forward.md
@@ -1,11 +1,261 @@
-# Kernels for triangular matmul forward pass
+
-
+# Kernels for Triangular Matrix Multiplication (Trimat Forward Pass)
{{ #aipr_header }}
+## Introduction
+
+This pocket reference provides efficient GPU implementations of **triangular
+matrix multiplication**, as used in **causal self-attention** in autoregressive
+transformer models.
+For **causal (autoregressive) attention**, we only need the **lower triangle**
+of the attention matrix. That is, each token should only attend to current and
+previous tokens.
+
+Computing the full matrix is wasteful when only the lower triangle is needed.
+Triangular matrix multiplication is a specialized form of matrix multiplication,
+where instead of computing the full output matrix, only the **lower triangle**
+is computed. This leads to substantial computational savings.
+
+This guide explains a series of CUDA kernel implementations for the **Trimat
+Forward Pass**, based on the [llm.c](https://github.com/karpathy/llm.c/tree/master/dev/cuda)
+GitHub repository.
+These kernels avoid unnecessary computation and offer potential speedups over
+cuBLAS. They are introduced in increasing order of optimization:
+
+- [Kernel 1: `matmul_tri_naive`](#kernel-1-naive-implementation-matmul_tri_naive):
+ A simple nested loop implementation with no memory optimization.
+- [Kernel 2: `matmul_tri_registers`](#kernel-2-register-tiling-matmul_tri_registers):
+ Uses **register tiling** to reduce redundant memory loads.
+- [Kernel 3: `matmul_tri3`](#kernel-3-vectorized-loads-matmul_tri3): Adds **vectorized
+ memory access** using `float4` to improve memory coalescing.
+- [Kernel 4: `matmul_tri4`](#kernel-4-shared-memory-tiling-matmul_tri4):
+ Leverages **shared memory** tiling for inter-thread data reuse and further
+ performance gains.
+
+The next section, [Input, Output, and
+Computation](#input-output-and-computation), describes the tensor shapes, the
+configuration used in the examples, and the exact computation performed during
+the Trimat Forward Pass.
+
+## Input, Output, and Computation
+
+This section describes the structure of the input/output tensors and the
+computation performed by the trimat kernels.
+
+### Input Tensor
+
+The input tensor packs queries and keys (and values, though unused here) in
+the shape:
+
+$$
+(B, T, 3, NH, HS)
+$$
+
+where:
+
+- \\(B\\): Batch size
+- \\(T\\): Sequence length
+- \\(3\\): Stacked Query, Key, and Value vectors
+- \\(NH\\): Number of attention heads
+- \\(HS\\): Head size, where \\(HS = C / NH\\) and \\(C\\) is the total
+ channel size
+
+Only the \\(Q\\) and \\(K\\) portions of the input are used in this
+computation.
+
+### Output Tensor
+
+The output tensor has shape:
+
+$$
+(B, NH, T, T)
+$$
+
+where:
+
+- \\(B\\): Batch size
+- \\(NH\\): Number of attention heads
+- \\(T\\): Sequence length (used for both dimensions of the attention
+ matrix)
+
+Each output slice \\([b, nh]\\) contains the attention scores for batch \\(b\\)
+and head \\(nh\\).
+Values above the diagonal (i.e., when a token would attend to a future token)
+are ignored or masked (e.g., set to NaN).
+
+### Configuration Used
+
+The configurations used in the examples are:
+
+- \\(B = 8\\): Batch size
+- \\(T = 1024\\): Sequence length
+- \\(C = 768\\): Total channels
+- \\(NH = 12\\): Number of heads
+- \\(HS = 64\\): Head size, where \\(HS = C / NH\\)
+
+### Computation Goal
+
+The goal is to compute the scaled dot-product attention score between queries
+and keys:
+
+$$
+\text{out}[b][h][i][j] = \frac{Q[b][i][h] \cdot K[b][j][h]}{\sqrt{\text{HS}}}
+\quad \text{for } j \leq i
+$$
+
+That is, for each batch \\((b)\\), head \\((h)\\), and timestep pair \\((i, j)\\)
+such that \\(j \leq i\\), we compute the dot product between query vector
+\\(Q\[b\]\[i\]\[h\]\\) and key vector \\(K\[b\]\[j\]\[h\]\\).
+The upper triangle \\((j > i)\\) is skipped or masked due to the causal
+attention constraint.
+
+### Mathematical Illustration
+
+To illustrate what this computation is accomplishing mathematically, consider
+the following example:
+
+Let \\(X\\) and \\(Y\\) be two 3×3 matrices. In a full matrix multiplication,
+we would compute:
+
+$$
+Z = X \cdot Y =
+\begin{bmatrix}
+\sum_{i=1}^3 x_{1,i} y_{i,1} & \sum_{i=1}^3 x_{1,i} y_{i,2} &
+\sum_{i=1}^3 x_{1,i} y_{i,3} \\\\
+\sum_{i=1}^3 x_{2,i} y_{i,1} & \sum_{i=1}^3 x_{2,i} y_{i,2} &
+\sum_{i=1}^3 x_{2,i} y_{i,3} \\\\
+\sum_{i=1}^3 x_{3,i} y_{i,1} & \sum_{i=1}^3 x_{3,i} y_{i,2} &
+\sum_{i=1}^3 x_{3,i} y_{i,3}
+\end{bmatrix}
+$$
+
+However, in **triangular (causal) matrix multiplication**, we only compute the
+**lower triangle**:
+
+$$
+Z_{\text{causal}} =
+\begin{bmatrix}
+\sum_{i=1}^3 x_{1,i} y_{i,1} & 0 & 0 \\\\
+\sum_{i=1}^3 x_{2,i} y_{i,1} & \sum_{i=1}^3 x_{2,i} y_{i,2} & 0 \\\\
+\sum_{i=1}^3 x_{3,i} y_{i,1} & \sum_{i=1}^3 x_{3,i} y_{i,2} &
+\sum_{i=1}^3 x_{3,i} y_{i,3}
+\end{bmatrix}
+$$
+
+This ensures that each row \\(i\\) only attends to columns \\(j \leq i\\),
+enforcing the causal constraint.
+
+## Kernel 1: Naive Implementation (`matmul_tri_naive`)
+
+This is the baseline GPU kernel, designed for clarity and correctness rather
+than performance.
+Each thread is responsible for computing an **8×8 tile** of the output
+attention matrix using a straightforward triple-nested loop.
+There are **no memory optimizations**; all reads are done directly from global
+memory.
+It is intentionally simple and mirrors a CPU-style nested loop structure to
+show what an unoptimized CUDA implementation looks like.
+
+### Key Characteristics of Kernel 1
+
+- **No shared memory** or caching.
+- **Each thread loads \\(Q[i]\\) and \\(K[j]\\)** directly from global memory.
+- Computes **64 dot products** per thread (8 queries × 8 keys).
+- Causal masking is enforced by skipping blocks where \\(j > i\\).
+- **Upper triangle is ignored**, though some redundant work may occur inside
+ diagonal blocks.
+
+Below is a visualization of how threads compute 8×8 blocks in the output
+matrix:
+
+
+
+
+
+
+
+## Kernel 2: Register Tiling (`matmul_tri_registers`)
+
+This kernel improves performance by leveraging **register tiling**.
+Each thread still computes an **8×8 tile** of the output, but instead of
+reading query and key vectors from global memory multiple times, each thread
+loads its \\(Q\\) and \\(K\\) vectors into registers for reuse.
+
+### Key Characteristics of Kernel 2
+
+- One thread per **8×8 tile**, same as Kernel 1.
+- \\(Q\\) and \\(K\\) values are loaded into **`float lhs[8]` and `float rhs[8]`**
+ arrays in registers.
+- Loops over the head size \\((HS)\\) to compute 64 dot products per thread.
+- **No shared memory**, but much better memory locality than Kernel 1.
+- Still performs some redundant computation above the diagonal (ignored due to
+ masking).
+- Faster than Kernel 1 due to fewer global loads.
+
+See **Figure 2** for a visualization of how registers are used to tile the data
+within a thread:
+
+
+
+
+
+
+
+## Kernel 3: Vectorized Loads (`matmul_tri3`)
+
+This kernel builds on Kernel 2 by introducing **vectorized and coalesced
+memory access** using `float4` loads.
+The goal is to improve global memory bandwidth utilization by aligning reads
+and writes to 16-byte boundaries.
+
+### Key Characteristics of Kernel 3
+
+- Each thread still computes an **8×8 tile** (64 dot products).
+- \\(Q\\) and \\(K\\) values are loaded using `float4` for better memory coalescing.
+- Improves memory access patterns by reducing the number of memory
+ transactions.
+- No shared memory; only register reuse + vectorized reads and writes.
+- Uses `ld_vec()` and `st_vec()` helper functions to safely cast pointers to
+ `float4`.
+- Faster than Kernel 2 due to reduced memory traffic.
+
+## Kernel 4: Shared Memory Tiling (`matmul_tri4`)
+
+This kernel introduces **shared memory tiling** to improve memory reuse across
+threads in a thread block.
+Threads collaborate to load tiles of the \\(Q\\) and \\(K\\) matrices into
+shared memory,
+significantly reducing global memory accesses.
+
+### Key Characteristics of Kernel 4
+
+- Uses shared memory arrays: `lhs_s[128][32]`, `rhs_s[128][32]`.
+- 16×16 threads cooperatively load **128 rows × 32 dimensions** from \\(Q\\)
+ and \\(K\\) into shared memory.
+- Computes **8×8 tiles** per thread, iterating over \\(HS / 32\\) slices to
+ accumulate dot products.
+- Final results are written with **vectorized `float4` stores** for efficient
+ global memory writes.
+
+See **Figure 4** for an illustration of shared memory tiling and accumulation:
+
+
+