2 changes: 2 additions & 0 deletions .gitignore
@@ -69,3 +69,5 @@ DMax
LLaDA
ep_execution_plan.md
parallel_state_redesign.md
experiment/
ckpts/
68 changes: 57 additions & 11 deletions diffulex/engine/engine.py
@@ -21,6 +21,54 @@
logger = get_logger(__name__)


def maybe_override_mask_token_id(
config: Config,
tokenizer,
*,
mask_token_id_explicit: bool = False,
) -> None:
"""Resolve mask token id from model artifacts when config still uses a placeholder default.

Some tokenizers, including LLaDA's, do not expose ``tokenizer.mask_token_id`` even though the
model config carries the correct ``mask_token_id``. In that case, server startup would fall
back to the global default and corrupt the denoising-buffer initialization.

When callers provide ``mask_token_id``, keep that value and skip tokenizer / HF overrides.
"""

if mask_token_id_explicit:
return

default_mask_token_id = Config.mask_token_id

tokenizer_mask_token_id = getattr(tokenizer, "mask_token_id", None)
if (
tokenizer_mask_token_id is not None
and config.mask_token_id == default_mask_token_id
and int(tokenizer_mask_token_id) != default_mask_token_id
):
logger.warning(
"Overriding default mask_token_id from %s to tokenizer mask_token_id %s.",
config.mask_token_id,
tokenizer_mask_token_id,
)
config.mask_token_id = int(tokenizer_mask_token_id)
return

hf_mask_token_id = getattr(getattr(config, "hf_config", None), "mask_token_id", None)
if (
hf_mask_token_id is not None
and config.mask_token_id == default_mask_token_id
and int(hf_mask_token_id) != default_mask_token_id
):
logger.warning(
"Overriding default mask_token_id from %s to hf_config mask_token_id %s.",
config.mask_token_id,
hf_mask_token_id,
)
config.mask_token_id = int(hf_mask_token_id)


class DiffulexEngine(DiffulexAsyncEngineMixin):
def __init__(self, model, **kwargs):
config_fields = {field.name for field in fields(Config)}
@@ -38,24 +86,22 @@ def __init__(self, model, **kwargs):
)
self.ps = []
self.events = []
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True)
config.tokenizer_vocab_size = len(self.tokenizer)
config.eos = self.tokenizer.eos_token_id
maybe_override_mask_token_id(
config,
self.tokenizer,
mask_token_id_explicit="mask_token_id" in config_kwargs,
)

ctx = mp.get_context("spawn")
for i in range(1, self.model_parallel_world_size):
event = ctx.Event()
process = ctx.Process(target=AutoModelRunner.from_config, args=(config, i, event))
process.start()
self.ps.append(process)
self.events.append(event)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True)
config.tokenizer_vocab_size = len(self.tokenizer)
config.eos = self.tokenizer.eos_token_id

if getattr(self.tokenizer, "mask_token_id", None) is not None and config.mask_token_id != self.tokenizer.mask_token_id:
logger.warning(
"Overriding mask_token_id from %s to tokenizer mask_token_id %s.",
config.mask_token_id,
self.tokenizer.mask_token_id,
)
config.mask_token_id = self.tokenizer.mask_token_id

self.model_runner = AutoModelRunner.from_config(config, 0, self.events)
self.scheduler: SchedulerBase | DataParallelScheduler = AutoScheduler.from_config(config)
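The helper's precedence is: an explicit caller value always wins; otherwise the tokenizer is consulted first and ``hf_config`` second, and either source only replaces the config value while it still holds the placeholder default. A minimal sketch of that resolution order, using SimpleNamespace stand-ins rather than the real Config and tokenizer classes (the ids 151666 and 126336 are illustrative placeholders, not values taken from this PR):

from types import SimpleNamespace

PLACEHOLDER = 151666  # illustrative stand-in for the Config class default

def resolve(config, tokenizer, *, explicit=False):
    # Mirrors maybe_override_mask_token_id, but returns instead of mutating.
    if explicit:
        return config.mask_token_id
    candidates = (
        getattr(tokenizer, "mask_token_id", None),
        getattr(getattr(config, "hf_config", None), "mask_token_id", None),
    )
    for candidate in candidates:
        if (
            candidate is not None
            and config.mask_token_id == PLACEHOLDER
            and int(candidate) != PLACEHOLDER
        ):
            return int(candidate)
    return config.mask_token_id

# LLaDA-style case: the tokenizer lacks mask_token_id, hf_config carries it.
cfg = SimpleNamespace(mask_token_id=PLACEHOLDER,
                      hf_config=SimpleNamespace(mask_token_id=126336))
tok = SimpleNamespace(mask_token_id=None)
assert resolve(cfg, tok) == 126336
assert resolve(cfg, tok, explicit=True) == PLACEHOLDER  # caller's value is kept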
16 changes: 7 additions & 9 deletions diffulex/sampler/base/no_shift.py
@@ -37,15 +37,9 @@ def _prefill_mask_token_local_ids(req, block, req_logits: torch.Tensor)
return local_ids

if min(local_ids) < 0 or max(local_ids) >= req_logits.shape[0]:
raise IndexError(
"Prefill mask-token logits index out of bounds: "
f"req_id={getattr(req, 'req_id', '?')}, "
f"block_id={getattr(block, 'block_id', '?')}, "
f"in_cache_len={prefix_offset}, "
f"global_ids={block.mask_token_global_ids}, "
f"local_ids={local_ids}, "
f"req_logits_len={req_logits.shape[0]}"
)
# Mixed prefill batches can yield a partial q_len for a req in one step.
# Skip this block for the current step and retry once its logits slice is present.
return []
return local_ids

def forward(
@@ -86,11 +80,15 @@ def forward(
continue

if attn_metadata.is_prefill[idx]:
if getattr(req, "_resume_prefill_until", 0) > 0 and getattr(block, "start", 0) >= req.running_len:
continue
# Prefix-cache prefill can produce q_len=0 for some requests in mixed batches.
# In that case there are no logits to sample for this req in this step.
if req_logits.shape[0] == 0:
continue
local_ids = self._prefill_mask_token_local_ids(req, block, req_logits)
if not local_ids:
continue
mask_token_logits = req_logits[local_ids, ...]
else:
buf_offset = block.start - req.dllm_block_buffer.first_running_block.start
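Both samplers now treat an out-of-range logits index during prefill as "not ready yet" rather than an error. A minimal standalone sketch of the same bounds contract (the function name and tensor shapes are invented for illustration):

import torch

def prefill_local_ids(global_ids: list[int], prefix_offset: int, num_logits: int) -> list[int]:
    # Map global mask-token positions onto this request's logits slice.
    local_ids = [g - prefix_offset for g in global_ids]
    if not local_ids:
        return local_ids
    if min(local_ids) < 0 or max(local_ids) >= num_logits:
        # Mixed prefill batches can leave a request with a partial q_len in
        # one step, so the slice may not yet cover every mask token:
        # report "nothing to sample" and let the next step retry.
        return []
    return local_ids

req_logits = torch.randn(4, 32)  # only four positions materialized this step
print(prefill_local_ids([2, 3], 0, req_logits.shape[0]))  # [2, 3] -> sample now
print(prefill_local_ids([6, 7], 0, req_logits.shape[0]))  # []     -> skip and retry

Returning an empty list pairs with the caller-side ``if not local_ids: continue`` guards added in both forward passes.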
6 changes: 5 additions & 1 deletion diffulex/sampler/base/shift.py
@@ -31,7 +31,9 @@ def evict_req_states(self, req_ids: list[int] | list[str]) -> None:
def _fetch_last_logits(self, logits: torch.Tensor, req: DllmReq) -> torch.Tensor:
req_id_str = str(req.req_id)
if req.has_to_cache_block:
return self._cache_last_logits(req_id_str, logits[req.to_cache_last_token_id])
idx = int(req.to_cache_last_token_id)
if 0 <= idx < logits.shape[0]:
return self._cache_last_logits(req_id_str, logits[idx])

if req_id_str in self.req_last_logits_map:
return self.req_last_logits_map[req_id_str]
@@ -113,6 +115,8 @@ def forward(
if shifted_logits.shape[0] == 0:
continue
local_ids = DllmSamplerNoShiftBase._prefill_mask_token_local_ids(req, block, shifted_logits)
if not local_ids:
continue
mask_token_logits = shifted_logits[local_ids, ...]
else:
buf_offset = block.start - req.dllm_block_buffer.first_running_block.start
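``_fetch_last_logits`` applies the same defensive pattern: index into the current logits slice only when the cached position is actually inside it, otherwise serve the row cached on an earlier step. A minimal dict-backed sketch (class and method names are illustrative, not the PR's API):

import torch

class LastLogitsCache:
    """Illustrative sketch of the guarded last-logits lookup."""

    def __init__(self) -> None:
        self._by_req: dict[str, torch.Tensor] = {}

    def fetch(self, req_id: str, logits: torch.Tensor, last_idx: int) -> torch.Tensor | None:
        if 0 <= last_idx < logits.shape[0]:
            # The position is inside this step's slice: refresh the cache.
            self._by_req[req_id] = logits[last_idx]
        # Otherwise (partial slice), fall back to the last cached row, if any.
        return self._by_req.get(req_id)

cache = LastLogitsCache()
cache.fetch("req-0", torch.randn(8, 32), 7)        # in range: cached and returned
row = cache.fetch("req-0", torch.randn(2, 32), 7)  # out of range: cached row returned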
8 changes: 7 additions & 1 deletion diffulex/server/args.py
@@ -23,6 +23,7 @@ class ServerArgs:
model_name: str = "dream"
decoding_strategy: str = "d2f"
sampling_mode: str = "naive"
mask_token_id: int | None = None
tensor_parallel_size: int = 1
data_parallel_size: int = 1
master_addr: str = "localhost"
@@ -53,7 +54,7 @@ class ServerArgs:
pre_merge_lora: bool = False

def engine_kwargs(self) -> dict:
return {
kwargs = {
"model_name": self.model_name,
"decoding_strategy": self.decoding_strategy,
"sampling_mode": self.sampling_mode,
@@ -88,6 +89,9 @@ def engine_kwargs(self) -> dict:
"lora_path": self.lora_path,
"pre_merge_lora": self.pre_merge_lora,
}
if self.mask_token_id is not None:
kwargs["mask_token_id"] = self.mask_token_id
return kwargs


def build_arg_parser() -> argparse.ArgumentParser:
Expand All @@ -102,6 +106,7 @@ def build_arg_parser() -> argparse.ArgumentParser:
parser.add_argument("--model-name", default="dream")
parser.add_argument("--decoding-strategy", default="d2f")
parser.add_argument("--sampling-mode", default="naive", choices=["naive", "edit"])
parser.add_argument("--mask-token-id", type=int, default=None)
parser.add_argument("--tensor-parallel-size", type=int, default=1)
parser.add_argument("--data-parallel-size", type=int, default=1)
parser.add_argument("--master-addr", default="localhost")
@@ -145,6 +150,7 @@ def parse_args(argv: Sequence[str] | None = None) -> ServerArgs:
model_name=ns.model_name,
decoding_strategy=ns.decoding_strategy,
sampling_mode=ns.sampling_mode,
mask_token_id=ns.mask_token_id,
tensor_parallel_size=ns.tensor_parallel_size,
data_parallel_size=ns.data_parallel_size,
master_addr=ns.master_addr,
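End to end, the flag stays optional: ``engine_kwargs()`` forwards ``mask_token_id`` only when it was set, so an omitted flag leaves resolution to the engine-side helper, while a supplied one reaches the engine as an explicit kwarg. A quick sanity check, assuming the remaining ServerArgs fields all carry defaults as the visible ones do (126336 is just an example id):

from diffulex.server.args import ServerArgs

explicit = ServerArgs(mask_token_id=126336)
assert explicit.engine_kwargs()["mask_token_id"] == 126336  # forwarded as explicit

implicit = ServerArgs()
assert "mask_token_id" not in implicit.engine_kwargs()  # engine resolves it instead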