Commit acf8a37

[update] Several Python Models
1 parent: f3a241e

13 files changed: 601 additions & 114 deletions

.vscode/extensions.json

Lines changed: 0 additions & 1 deletion

@@ -5,7 +5,6 @@
     "josetr.cmake-language-support-vscode",
     "ms-vscode.cpptools", // C/C++
     "ms-python.python", // Python
-    "ms-python.black-formatter", // Python formatter
     "njpwerner.autodocstring", // Python docstring generator
   ]
 }

csrc/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ include(${PROJECT_SOURCE_DIR}/cmake/libraries/libtorch.cmake)
 find_package(cxxopts CONFIG REQUIRED)
 find_package(fmt CONFIG REQUIRED)
 find_package(spdlog CONFIG REQUIRED)
-find_package(proxy CONFIG REQUIRED)
+find_package(msft_proxy4 CONFIG REQUIRED)
 find_package(yaml-cpp CONFIG REQUIRED)
 enable_testing()
 find_package(GTest CONFIG REQUIRED)

csrc/cmake/compilers/cxx-compiler-configs.cmake

Lines changed: 1 addition & 1 deletion

@@ -33,7 +33,7 @@ if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
     string(APPEND CMAKE_EXE_LINKER_FLAGS " /STACK:${STACK_SIZE}")
 # Clang ---------------------------------------------------------------------------------------------------------------
 elseif(CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
-    string(APPEND CMAKE_CXX_FLAGS " -fopenmp -Wall -Wextra -Werror")
+    string(APPEND CMAKE_CXX_FLAGS " -stdlib=libc++ -fopenmp -Wall -Wextra -Werror")
     if (WIN32)
         string(APPEND CMAKE_EXE_LINKER_FLAGS " -Wl,-stack,${STACK_SIZE}")
     else()

csrc/lib/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ target_link_libraries(
     Python::Python
     ${TORCH_LIBRARIES}
     fmt::fmt
-    msft_proxy
+    msft_proxy4::proxy
 )

 target_compile_definitions(${LIB_NAME}

pmpp/models/attention.py

Lines changed: 6 additions & 86 deletions

@@ -1,86 +1,10 @@
-import math
 from typing import Optional
 import torch
 from torch import nn
-from torch.nn import functional as F
-import numpy as np
+from .mha_kernels import MHAKernel


-def set_random_seed(
-    seed: int, rank: int = 0, force_deterministic: bool = False
-) -> None:
-    """
-    Set the random seed for numpy and torch.
-    """
-    np.random.seed(seed + rank)
-    torch.manual_seed(seed + rank)
-    if force_deterministic:
-        torch.backends.cudnn.deterministic = True
-        torch.backends.cudnn.benchmark = False
-
-
-class MultiHeadSelfAttentionKernel(nn.Module):
-    def __init__(self, hidden_dim: int, num_heads: int):
-        super().__init__()
-
-        self.hidden_dim: int = hidden_dim
-        self.num_heads: int = num_heads
-        self.head_size: int = hidden_dim // num_heads
-
-    def forward(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
-    ):
-        """
-        Calculates softmax(Q @ KT / sqrt(dk)) @ V .
-
-        Parameters
-        ----------
-        q : torch.Tensor; Shape: (q_len, hidden_dim)
-
-        k : torch.Tensor; Shape: (kv_len, hidden_dim)
-
-        v : torch.Tensor; Shape: (kv_len, hidden_dim)
-
-        mask: torch.Tensor; Shape: (q_len, kv_len), optional
-
-        Note
-        ----
-        When prefilling, q_len equals to seq_len (number of tokens in the input
-        seq);
-        When decoding, q_len equals to 1, refering to the newly generated
-        token. (Based on different sampling strategies, q_len could be larger
-        than 1.)
-        """
-
-        q_len, kv_len = q.size(0), k.size(0)
-        # q -> (num_heads, q_len, head_size)
-        q = q.view(q_len, self.num_heads, self.head_size).transpose(0, 1)
-        # k -> (num_heads, kv_len, head_size)
-        k = k.view(kv_len, self.num_heads, self.head_size).transpose(0, 1)
-        # v -> (num_heads, kv_len, head_size)
-        v = v.view(kv_len, self.num_heads, self.head_size).transpose(0, 1)
-        # scores -> (num_heads, q_len, kv_len)
-        scores = torch.matmul(q, k.transpose(-1, -2)) / (self.head_size**0.5)
-        scores = (
-            scores.masked_fill(mask == 0, float("-inf"))
-            if mask is not None
-            else scores
-        )
-        # scores -> (num_heads, q_len, kv_len)
-        attn_probs = F.softmax(scores, dim=-1)
-        # out -> (num_heads, q_len, head_size)
-        out = torch.matmul(attn_probs, v)
-        # out -> (q_len, num_heads, head_size)
-        out = out.transpose(0, 1).reshape(q_len, self.hidden_dim)
-
-        return out
-
-
-class MultiHeadSelfAttention(nn.Module):
+class MHA(nn.Module):
     def __init__(
         self,
         embed_dim: int,

@@ -97,7 +21,7 @@ def __init__(
         self.Wv = nn.Linear(embed_dim, hidden_dim)
         self.Wo = nn.Linear(hidden_dim, embed_dim)

-        self.attn_kernel = MultiHeadSelfAttentionKernel(hidden_dim, num_heads)
+        self.attn_kernel = MHAKernel(hidden_dim, num_heads)

     def forward(
         self,

@@ -135,7 +59,7 @@ def forward(
         v = self.Wv(seq)

         # k_cache -> (kv_len + seq_len, hidden_dim)
-        k = k if k_cache is None else torch.cat([k_cache, k.detach()], dim=0)
+        k = k if k_cache is None else torch.cat([k_cache, k.detach()], dim=0)
         # v_cache -> (kv_len + seq_len, hidden_dim)
         v = v if v_cache is None else torch.cat([v_cache, v.detach()], dim=0)

@@ -148,9 +72,7 @@ def forward(
 class TransformerBlock(nn.Module):
     def __init__(self, embed_dim, num_heads, hidden_dim, mlp_dim, dropout=0.1):
         super().__init__()
-        self.attention = MultiHeadSelfAttention(
-            embed_dim, num_heads, hidden_dim
-        )
+        self.attention = MHA(embed_dim, num_heads, hidden_dim)
         self.norm1 = nn.RMSNorm(embed_dim)
         self.norm2 = nn.RMSNorm(embed_dim)
         self.mlp = nn.Sequential(

@@ -293,8 +215,6 @@ def forward(


 if __name__ == "__main__":
-    set_random_seed(114514)
-
     seq_len = 4
     vocab_size = 1024
     embed_dim = 128

@@ -347,5 +267,5 @@ def forward(
     for i in range(1, n_generate):
         probs = lm(token, is_prefilling=False)
         token = torch.argmax(probs[-1, :], dim=-1, keepdim=True)
-        print(f"The {i+1}th predicted token: {token}")
+        print(f"The {i + 1}th predicted token: {token}")
     print(f"|- Token Shape: {token.shape}")

pmpp/models/grpo.py

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+"""
+PSEUDO-CODE FOR GRPO TRAINING (Group Relative Policy Optimization):
+    p_model <- policy model (the one we're training)
+    ref_model <- reference policy (frozen; for KL regularization)
+    reward_model <- frozen model / callable that scores (prompt, resp)
+    # [NOTE] No value model / critic. GRPO uses group-relative baselines.
+
+    for i in range(num_iterations):
+        prompts <- [B, prompt_len]; Sampled from prompts dataset
+
+        # 1) Rollout: sample a GROUP of responses per prompt
+        #    Let G = num_generations_per_prompt (a.k.a. group size)
+        resps <- [B, G, T]; Rollout from p_model on prompts (G samples each)
+        old_logp <- [B, G, T]; Logprobs under rollout policy (snapshot of p_model)
+        ref_logp <- [B, G, T]; Logprobs under ref_model for same tokens
+        action_mask <- [B, G, T]; 1 on valid action tokens, 0 on padding
+
+        # 2) Compute rewards (often sequence-level)
+        #    reward_seq is a per-(prompt, resp) scalar score from reward_model
+        reward_seq <- [B, G]; reward_model(prompts, resps)
+
+        # Optional: add formatting penalties, stop-token penalties, etc.
+        # reward_seq <- reward_seq + extra_terms
+
+        # If you want token-shaped rewards, place the seq reward on the last valid token
+        reward_tok <- zeros([B, G, T])
+        last_idx <- last_valid_index(action_mask)         # [B, G]
+        reward_tok[b,g,last_idx[b,g]] += reward_seq[b,g]  # scatter add
+
+        # 3) KL term (tokenwise, rollout policy vs reference)
+        kl_tok <- [B, G, T]; kl_tok = old_logp - ref_logp
+
+        # 4) Optional KL shaping of reward (same idea as PPO RLHF)
+        shaped_reward_tok <- [B, G, T];
+        shaped_reward_tok = reward_tok - beta * kl_tok
+
+        # 5) Construct group-relative advantage (baseline from the GROUP)
+        #    Most common GRPO: baseline is the mean reward within the group
+        #    (per prompt). Use shaped (sequence) reward or unshaped, depending
+        #    on your design. Here: shaped sequence reward = sum over valid
+        #    tokens of shaped_reward_tok.
+        shaped_reward_seq <- [B, G]
+        shaped_reward_seq[b,g] = sum_t(shaped_reward_tok[b,g,t] * action_mask[b,g,t])
+
+        group_mean <- [B, 1]; group_mean[b,1] = mean_g(shaped_reward_seq[b,g])
+        group_std <- [B, 1]; group_std[b,1] = std_g(shaped_reward_seq[b,g]) + eps
+
+        adv <- [B, G]; adv = (shaped_reward_seq - group_mean) / group_std
+        # Alternatively: adv = shaped_reward_seq - group_mean (no normalization)
+
+        # Broadcast to tokens if doing a token-level PPO-style objective
+        adv_tok <- [B, G, T]; adv_tok = adv[..., None] * action_mask
+
+        batch <- {prompts, resps, old_logp, ref_logp, adv_tok, action_mask}
+
+        # 6) Policy optimization (PPO-style clipped objective, but no value loss)
+        for epoch in range(num_epochs_per_rollout):
+            for minibatch in iterate_minibatches(batch, mb_size):
+                new_logp <- [mb, T]; logprobs from (updated) p_model on minibatch prompts+resps
+                old_logp_mb <- [mb, T]; from minibatch old_logp
+                adv_tok_mb <- [mb, T]; from minibatch adv_tok
+                mask_mb <- [mb, T]; from minibatch action_mask
+
+                # PPO ratio per token
+                log_ratio <- new_logp - old_logp_mb
+                ratio <- exp(log_ratio)
+
+                # GRPO policy gradient loss (clipped), averaged over valid tokens
+                unclipped <- -adv_tok_mb * ratio
+                clipped <- -adv_tok_mb * clip(ratio, 1-eps_clip, 1+eps_clip)
+                pg_loss_tok <- max(unclipped, clipped)
+                pg_loss <- sum(pg_loss_tok * mask_mb) / sum(mask_mb)
+
+                # Optional entropy bonus
+                ent_bonus <- entropy_from_logits(...)  # or from new_logp if available
+                loss <- pg_loss - ent_coef * ent_bonus
+
+                optimizer.zero_grad()
+                loss.backward()
+                clip_grad_norm_(p_model.parameters(), max_grad_norm)  # common
+                optimizer.step()
+
+        # Optional: monitor approximate KL to ref, early stop if KL too large
+        # approx_kl = mean( (new_logp - ref_logp_mb) * mask_mb )
+"""

pmpp/models/kldiv.py

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
+import torch
+
+
+def kl_divergence_from_log_probs(
+    log_p: torch.Tensor,
+    log_q: torch.Tensor,
+    reduction: str = "batchmean",
+    eps: float = 0.0,
+) -> torch.Tensor:
+    """
+    KL(P || Q) = sum(P * (log P - log Q))
+    where log_p = log P, log_q = log Q along the last dim.
+
+    Args:
+        log_p: (..., K) log-probabilities for P (target)
+        log_q: (..., K) log-probabilities for Q (prediction)
+        reduction: "none" | "sum" | "mean" | "batchmean"
+        eps: optional additive smoothing in prob space; usually keep 0.0
+
+    Returns:
+        KL divergence with the chosen reduction.
+    """
+    if eps != 0.0:
+        # Smooth in prob space, then re-normalize
+        p = log_p.exp()
+        q = log_q.exp()
+        p = p + eps
+        q = q + eps
+        p = p / p.sum(dim=-1, keepdim=True)
+        q = q / q.sum(dim=-1, keepdim=True)
+        log_p = (p.clamp_min(1e-30)).log()
+        log_q = (q.clamp_min(1e-30)).log()
+
+    p = log_p.exp()
+    kl_per_elem = p * (log_p - log_q)  # (..., K)
+    kl = kl_per_elem.sum(dim=-1)       # (...,)
+
+    if reduction == "none":
+        return kl  # Shape: (...)
+    if reduction == "sum":
+        return kl.sum()  # Scalar
+    if reduction == "mean":
+        return kl.mean()  # Scalar
+    if reduction == "batchmean":
+        return kl.sum() / kl.shape[0]  # Scalar
+    raise ValueError(f"Unknown reduction: {reduction}")
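A quick sanity check for the new helper (my sketch, not in the commit): with log-probability inputs, the "batchmean" path should agree with torch.nn.functional.kl_div when the prediction is passed first and log_target=True marks the target as log-probabilities:

import torch
import torch.nn.functional as F
from pmpp.models.kldiv import kl_divergence_from_log_probs

log_p = torch.log_softmax(torch.randn(8, 100), dim=-1)  # target P
log_q = torch.log_softmax(torch.randn(8, 100), dim=-1)  # prediction Q

ours = kl_divergence_from_log_probs(log_p, log_q, reduction="batchmean")
# F.kl_div(input, target) computes KL(target || input), so the argument
# order is swapped relative to the helper above.
ref = F.kl_div(log_q, log_p, reduction="batchmean", log_target=True)
assert torch.allclose(ours, ref, atol=1e-6)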

pmpp/models/mha_kernels.py

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from typing import Optional
+
+
+class MHAKernel(nn.Module):
+    def __init__(self, hidden_dim: int, num_heads: int):
+        super().__init__()
+
+        self.hidden_dim: int = hidden_dim
+        self.num_heads: int = num_heads
+        self.head_size: int = hidden_dim // num_heads
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        Calculates softmax(Q @ KT / sqrt(dk)) @ V .
+
+        Parameters
+        ----------
+        q : torch.Tensor; Shape: (q_len, hidden_dim)
+
+        k : torch.Tensor; Shape: (kv_len, hidden_dim)
+
+        v : torch.Tensor; Shape: (kv_len, hidden_dim)
+
+        mask: torch.Tensor; Shape: (q_len, kv_len), optional
+
+        Note
+        ----
+        When prefilling, q_len equals seq_len (the number of tokens in the
+        input seq); when decoding, q_len equals 1, referring to the newly
+        generated token. (Depending on the sampling strategy, q_len can be
+        larger than 1.)
+        """
+
+        q_len, kv_len = q.size(0), k.size(0)
+        # q -> (num_heads, q_len, head_size)
+        q = q.reshape(q_len, self.num_heads, self.head_size).transpose(0, 1)
+        # k -> (num_heads, kv_len, head_size)
+        k = k.reshape(kv_len, self.num_heads, self.head_size).transpose(0, 1)
+        # v -> (num_heads, kv_len, head_size)
+        v = v.reshape(kv_len, self.num_heads, self.head_size).transpose(0, 1)
+        # scores -> (num_heads, q_len, kv_len)
+        scores = torch.matmul(q, k.transpose(-1, -2)) / (self.head_size**0.5)
+        scores = (
+            scores.masked_fill(mask == 0, float("-inf"))
+            if mask is not None
+            else scores
+        )
+        # scores -> (num_heads, q_len, kv_len); softmax in float32 for
+        # numerical stability, then cast back to the input dtype
+        attn_probs = F.softmax(scores.to(torch.float32), dim=-1).type_as(
+            scores
+        )
+        # out -> (num_heads, q_len, head_size)
+        out = torch.matmul(attn_probs, v)
+        # out -> (q_len, num_heads, head_size)
+        out = out.transpose(0, 1).reshape(q_len, self.hidden_dim)
+
+        return out
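As an equivalence check one could run against the new kernel (my sketch, not in the commit; F.scaled_dot_product_attention needs PyTorch 2.x), the kernel should match SDPA applied to the same per-head views:

import torch
import torch.nn.functional as F
from pmpp.models.mha_kernels import MHAKernel

hidden_dim, num_heads, seq_len = 64, 4, 5
head_size = hidden_dim // num_heads
kernel = MHAKernel(hidden_dim, num_heads)
q, k, v = (torch.randn(seq_len, hidden_dim) for _ in range(3))

out = kernel(q, k, v)

# Reference: identical math on (num_heads, seq_len, head_size) views.
qh = q.reshape(seq_len, num_heads, head_size).transpose(0, 1)
kh = k.reshape(seq_len, num_heads, head_size).transpose(0, 1)
vh = v.reshape(seq_len, num_heads, head_size).transpose(0, 1)
ref = F.scaled_dot_product_attention(qh, kh, vh)
ref = ref.transpose(0, 1).reshape(seq_len, hidden_dim)
assert torch.allclose(out, ref, atol=1e-5)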
