diff --git a/.gitignore b/.gitignore index 5d5f16e..f86a7f6 100755 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,5 @@ DMax LLaDA ep_execution_plan.md parallel_state_redesign.md +experiment/ +ckpts/ \ No newline at end of file diff --git a/diffulex/engine/engine.py b/diffulex/engine/engine.py index da0a6d4..d3811a5 100644 --- a/diffulex/engine/engine.py +++ b/diffulex/engine/engine.py @@ -21,6 +21,54 @@ logger = get_logger(__name__) +def maybe_override_mask_token_id( + config: Config, + tokenizer, + *, + mask_token_id_explicit: bool = False, +) -> None: + """Resolve mask token id from model artifacts when config still uses a placeholder default. + + Some tokenizers, including LLaDA, do not expose ``tokenizer.mask_token_id`` even though the + model config carries the correct ``mask_token_id``. In those cases, server startup would fall + back to the global default and corrupt the denoising buffer initialization. + + When callers provide ``mask_token_id``, keep that value and skip tokenizer / HF overrides. + """ + + if mask_token_id_explicit: + return + + default_mask_token_id = Config.mask_token_id + + tokenizer_mask_token_id = getattr(tokenizer, "mask_token_id", None) + if ( + tokenizer_mask_token_id is not None + and config.mask_token_id == default_mask_token_id + and int(tokenizer_mask_token_id) != default_mask_token_id + ): + logger.warning( + "Overriding default mask_token_id from %s to tokenizer mask_token_id %s.", + config.mask_token_id, + tokenizer_mask_token_id, + ) + config.mask_token_id = int(tokenizer_mask_token_id) + return + + hf_mask_token_id = getattr(getattr(config, "hf_config", None), "mask_token_id", None) + if ( + hf_mask_token_id is not None + and config.mask_token_id == default_mask_token_id + and int(hf_mask_token_id) != default_mask_token_id + ): + logger.warning( + "Overriding default mask_token_id from %s to hf_config mask_token_id %s.", + config.mask_token_id, + hf_mask_token_id, + ) + config.mask_token_id = int(hf_mask_token_id) + + class DiffulexEngine(DiffulexAsyncEngineMixin): def __init__(self, model, **kwargs): config_fields = {field.name for field in fields(Config)} @@ -38,6 +86,15 @@ def __init__(self, model, **kwargs): ) self.ps = [] self.events = [] + self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True) + config.tokenizer_vocab_size = len(self.tokenizer) + config.eos = self.tokenizer.eos_token_id + maybe_override_mask_token_id( + config, + self.tokenizer, + mask_token_id_explicit="mask_token_id" in config_kwargs, + ) + ctx = mp.get_context("spawn") for i in range(1, self.model_parallel_world_size): event = ctx.Event() @@ -45,17 +102,6 @@ def __init__(self, model, **kwargs): process.start() self.ps.append(process) self.events.append(event) - self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=True) - config.tokenizer_vocab_size = len(self.tokenizer) - config.eos = self.tokenizer.eos_token_id - - if getattr(self.tokenizer, "mask_token_id", None) is not None and config.mask_token_id != self.tokenizer.mask_token_id: - logger.warning( - "Overriding mask_token_id from %s to tokenizer mask_token_id %s.", - config.mask_token_id, - self.tokenizer.mask_token_id, - ) - config.mask_token_id = self.tokenizer.mask_token_id self.model_runner = AutoModelRunner.from_config(config, 0, self.events) self.scheduler: SchedulerBase | DataParallelScheduler = AutoScheduler.from_config(config) diff --git a/diffulex/sampler/base/no_shift.py b/diffulex/sampler/base/no_shift.py index 
2ee4f37..308d0b8 100644 --- a/diffulex/sampler/base/no_shift.py +++ b/diffulex/sampler/base/no_shift.py @@ -37,15 +37,9 @@ def _prefill_mask_token_local_ids(req: DllmReq, block, req_logits: torch.Tensor) return local_ids if min(local_ids) < 0 or max(local_ids) >= req_logits.shape[0]: - raise IndexError( - "Prefill mask-token logits index out of bounds: " - f"req_id={getattr(req, 'req_id', '?')}, " - f"block_id={getattr(block, 'block_id', '?')}, " - f"in_cache_len={prefix_offset}, " - f"global_ids={block.mask_token_global_ids}, " - f"local_ids={local_ids}, " - f"req_logits_len={req_logits.shape[0]}" - ) + # Mixed prefill batches can yield partial q_len for a req in one step. + # Skip this block for this step and retry once its logits slice is present. + return [] return local_ids def forward( @@ -86,11 +80,16 @@ def forward( continue if attn_metadata.is_prefill[idx]: + # While a preempted req re-runs prefill over its prior context, blocks at or beyond the resume window have no logits this step. + if getattr(req, "_resume_prefill_until", 0) > 0 and getattr(block, "start", 0) >= req.running_len: + continue # Prefix-cache prefill can produce q_len=0 for some requests in mixed batches. # In that case there are no logits to sample for this req in this step. if req_logits.shape[0] == 0: continue local_ids = self._prefill_mask_token_local_ids(req, block, req_logits) + if not local_ids: + continue mask_token_logits = req_logits[local_ids, ...] else: buf_offset = block.start - req.dllm_block_buffer.first_running_block.start diff --git a/diffulex/sampler/base/shift.py b/diffulex/sampler/base/shift.py index c928536..98f2462 100644 --- a/diffulex/sampler/base/shift.py +++ b/diffulex/sampler/base/shift.py @@ -31,7 +31,9 @@ def evict_req_states(self, req_ids: list[int] | list[str]) -> None: def _fetch_last_logits(self, logits: torch.Tensor, req: DllmReq) -> torch.Tensor: req_id_str = str(req.req_id) if req.has_to_cache_block: - return self._cache_last_logits(req_id_str, logits[req.to_cache_last_token_id]) + idx = int(req.to_cache_last_token_id) + if 0 <= idx < logits.shape[0]: + return self._cache_last_logits(req_id_str, logits[idx]) if req_id_str in self.req_last_logits_map: return self.req_last_logits_map[req_id_str] @@ -113,6 +115,8 @@ def forward( if shifted_logits.shape[0] == 0: continue local_ids = DllmSamplerNoShiftBase._prefill_mask_token_local_ids(req, block, shifted_logits) + if not local_ids: + continue mask_token_logits = shifted_logits[local_ids, ...]
else: buf_offset = block.start - req.dllm_block_buffer.first_running_block.start diff --git a/diffulex/server/args.py b/diffulex/server/args.py index 9c3e98f..18ae123 100644 --- a/diffulex/server/args.py +++ b/diffulex/server/args.py @@ -23,6 +23,7 @@ class ServerArgs: model_name: str = "dream" decoding_strategy: str = "d2f" sampling_mode: str = "naive" + mask_token_id: int | None = None tensor_parallel_size: int = 1 data_parallel_size: int = 1 master_addr: str = "localhost" @@ -53,7 +54,7 @@ class ServerArgs: pre_merge_lora: bool = False def engine_kwargs(self) -> dict: - return { + kwargs = { "model_name": self.model_name, "decoding_strategy": self.decoding_strategy, "sampling_mode": self.sampling_mode, @@ -88,6 +89,9 @@ def engine_kwargs(self) -> dict: "lora_path": self.lora_path, "pre_merge_lora": self.pre_merge_lora, } + if self.mask_token_id is not None: + kwargs["mask_token_id"] = self.mask_token_id + return kwargs def build_arg_parser() -> argparse.ArgumentParser: @@ -102,6 +106,7 @@ def build_arg_parser() -> argparse.ArgumentParser: parser.add_argument("--model-name", default="dream") parser.add_argument("--decoding-strategy", default="d2f") parser.add_argument("--sampling-mode", default="naive", choices=["naive", "edit"]) + parser.add_argument("--mask-token-id", type=int, default=None) parser.add_argument("--tensor-parallel-size", type=int, default=1) parser.add_argument("--data-parallel-size", type=int, default=1) parser.add_argument("--master-addr", default="localhost") @@ -145,6 +150,7 @@ def parse_args(argv: Sequence[str] | None = None) -> ServerArgs: model_name=ns.model_name, decoding_strategy=ns.decoding_strategy, sampling_mode=ns.sampling_mode, + mask_token_id=ns.mask_token_id, tensor_parallel_size=ns.tensor_parallel_size, data_parallel_size=ns.data_parallel_size, master_addr=ns.master_addr, diff --git a/diffulex/strategy_template/multi_block/engine/request.py b/diffulex/strategy_template/multi_block/engine/request.py index 2b8d755..008b14e 100644 --- a/diffulex/strategy_template/multi_block/engine/request.py +++ b/diffulex/strategy_template/multi_block/engine/request.py @@ -28,6 +28,8 @@ def init_multi_block(self: DllmReq, config: Config): self.is_multi_block = True self.status_history = [self.status] self.completion_reason = None + self._resume_prefill_until = 0 + self._terminal_context_block_id: int | None = None self.block_size = config.block_size self.buffer_size = config.buffer_size @@ -102,8 +104,6 @@ def init_multi_block(self: DllmReq, config: Config): def eos_token_generated(self) -> bool: if self.ignore_eos: return False - # Only inspect generated segment; prompt tokens may also contain chat delimiters - # such as <|im_end|>, which must not trigger immediate completion. 
generated_seq = self.token_ids[self.prefix_len :] return self.eos_token_id in generated_seq @@ -230,6 +230,8 @@ def chunk_size(self) -> int: @property def running_len(self) -> int: if self.is_prefilling: + if self._resume_prefill_until > 0: + return self._resume_prefill_until return ( (self.padded_prefix_len - self.block_size) + self.dllm_block_buffer.num_valid_blocks * self.block_size if self.is_padded @@ -314,7 +316,7 @@ def to_cache_seq_end(self) -> tuple[int, int]: @property def has_to_cache_blocks(self) -> bool: if self.is_prefilling: - return True + return self._prefill_visible_to_cache_last_global_id() is not None elif self.is_decoding: return len(self.dllm_block_buffer.to_cache_blocks) > 0 @@ -325,14 +327,69 @@ def has_to_cache_block(self) -> bool: @property def to_cache_last_token_id(self) -> int: if self.is_prefilling: - return self.to_cache_len - 1 if self.to_cache_len > 0 else 0 + window_start = int(self.contiguous_in_cache_prefix_len) + last_global = self._prefill_visible_to_cache_last_global_id() + if last_global is None: + return 0 + return int(last_global - window_start) n = len(self.dllm_block_buffer.to_cache_blocks) * self.block_size return n - 1 if n > 0 else 0 + def _prefill_visible_to_cache_last_global_id(self) -> int | None: + if not self.is_prefilling: + return None + + window_start = int(self.contiguous_in_cache_prefix_len) + window_end = int(self.running_len) + if window_end <= window_start: + return None + + last_global = None + for block in self.dllm_blocks: + if block.end <= window_start: + continue + if block.start >= window_end: + break + if not block.is_to_cache: + continue + candidate = min(block.end, window_end) - 1 + if candidate < window_start: + continue + last_global = candidate if last_global is None else max(last_global, candidate) + return last_global + @property def last_block_finished(self) -> bool: - inspected_block = self.dllm_block_buffer.first_running_block.prev_block - return inspected_block is not None and inspected_block.is_complete and inspected_block.is_last_in_context + terminal_block = self.terminal_context_block + return terminal_block is not None and terminal_block.is_complete + + @property + def terminal_context_block(self) -> DllmBlock | None: + block_id = getattr(self, "_terminal_context_block_id", None) + if block_id is None or not (0 <= int(block_id) < len(self.dllm_blocks)): + return None + return self.dllm_blocks[int(block_id)] + + def set_terminal_context_block(self, block: DllmBlock | None) -> None: + # The terminal context boundary is request state and must always point to a + # real block. Dummy blocks are only buffer-capacity placeholders. 
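+ # Walk back through prev_block to the nearest non-dummy block; clear the marker if none exists.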
+ while block is not None and block.is_dummy: + block = block.prev_block + if block is None: + self._terminal_context_block_id = None + return + + terminal_block_id = int(block.block_id) + self._terminal_context_block_id = terminal_block_id + + for dllm_block in self.dllm_blocks: + if dllm_block.is_dummy or dllm_block.block_id > terminal_block_id: + dllm_block.make_out_of_context() + elif dllm_block.block_id < terminal_block_id: + dllm_block.make_in_context() + else: + dllm_block.make_in_context() + dllm_block.make_last_in_context() @property def pure_prefill_without_mask_token(self) -> bool: @@ -357,6 +414,20 @@ def make_pending(self): def preempt(self): self.lazy_activate() self.log_status() + if self.is_multi_block: + rebuild_until = 0 + if self.is_decoding: + rebuild_until = int(self.running_seq_start) + self._resume_prefill_until = rebuild_until + # Scheduler preemption frees this req's page_table immediately after moving it + # back to WAITING. Any block still marked IN_CACHE would then point to KV pages + # that no longer belong to the req. Demote those blocks back to TO_CACHE so a + # resumed req rebuilds its cached prefix instead of reusing stale block state. + for block in self.dllm_blocks: + if rebuild_until > 0 and block.end <= rebuild_until and block.is_to_cache: + continue + if block.is_in_cache: + block.status = DllmBlockStatus.TO_CACHE self.status = DllmReqStatus.WAITING @property @@ -370,6 +441,9 @@ def lazy_activate(self): self.log_status() self.status = self.status_history[-1] + if self._resume_prefill_until > 0: + self.status = DllmReqStatus.PREFILLING + return if self.is_pending: self.status = DllmReqStatus.PREFILLING elif self.is_prefilling: @@ -399,6 +473,19 @@ def deactivate(self, reason: str | None = None): def step(self): self.lazy_activate() + + # Decode can livelock when the running buffer is fully occupied by + # IN_CACHE blocks (no ACTIVE/TO_CACHE block left). In that state we keep + # scheduling decode, but there is no writable frontier so no new token is + # ever accepted. Recycle the cached head block to reopen one slot. 
+ if ( + self.is_decoding + and not self.dllm_block_buffer.active_blocks + and not self.dllm_block_buffer.to_cache_blocks + ): + head_block = self.dllm_block_buffer.first_running_block + if head_block.is_in_cache and not head_block.is_last_in_context: + head_block.status = DllmBlockStatus.TO_CACHE # Condition to activate the next block, when buffer contains active blocks activate_cond = self.dllm_block_buffer.should_add_block and not self.dllm_block_buffer.is_overflow @@ -431,9 +518,13 @@ def push_back_dummy_block(self): ) dllm_block.post_init_dllm_block(self, self.dllm_block_buffer) - if (self.max_new_tokens_reached or self.max_model_len_reached) and dllm_block.prev_block.is_in_context: - dllm_block.make_last_in_context() - elif dllm_block.prev_block.is_out_of_context or dllm_block.prev_block.is_last_in_context: + if ( + self.max_new_tokens_reached + or self.max_model_len_reached + or self.terminal_context_block is not None + or dllm_block.prev_block.is_out_of_context + or dllm_block.prev_block.is_last_in_context + ): dllm_block.make_out_of_context() self.dllm_blocks.append(dllm_block) @@ -446,6 +537,13 @@ def maybe_postprocess_prefix_blocks(self): for block_id in range(self.num_prefix_blocks): self.dllm_blocks[block_id].in_cache() + if self._resume_prefill_until > 0: + for block in self.dllm_blocks: + if block.end > self._resume_prefill_until: + break + if block.is_to_cache and block.is_complete: + block.in_cache() + def apply_cached_prefix_pages(self): if not self.is_multi_block: return @@ -482,10 +580,14 @@ def postprocess(self): elif block.is_dummy or block.is_active or block.is_in_cache: block_id += 1 + if self._resume_prefill_until > 0 and self.contiguous_in_cache_prefix_len >= self._resume_prefill_until: + self._resume_prefill_until = 0 + + if self.is_truncated: + self.set_terminal_context_block(self.dllm_block_buffer.last_valid_block) + if self.eos_token_generated: - self.dllm_block_buffer.last_valid_block.make_last_in_context() self.meet_eos = True - self.dllm_block_buffer.maybe_fix_context_management() if ( self.eos_token_generated diff --git a/diffulex/utils/loader.py b/diffulex/utils/loader.py index e54c6dd..89257bb 100755 --- a/diffulex/utils/loader.py +++ b/diffulex/utils/loader.py @@ -23,19 +23,40 @@ def load_lora_config(lora_path: str) -> dict: return {} -def enable_lora_for_model(model: nn.Module, lora_config: dict): - """Enable LoRA for existing linear layers in the model.""" +def enable_lora_for_model( + model: nn.Module, + lora_config: dict, + packed_modules_mapping: dict | None = None, +): + """Enable LoRA for existing linear layers in the model. + + `target_modules` from PEFT adapter_config refers to the *checkpoint* leaf + names (e.g. `attn_out`). When the local model class re-names a layer (e.g. + LLaDA's `attn_out` is implemented as `self_attn.o_proj`), the mapping is + declared in `packed_modules_mapping` as `{ckpt_leaf: (local_dotted_name, _)}`. + We must consult that mapping here, otherwise renamed targets silently miss + `__init_lora__` and the loaded LoRA tensors get dropped at apply time. 
+ """ r = lora_config.get("r", 16) lora_alpha = lora_config.get("lora_alpha", 32.0) lora_dropout = lora_config.get("lora_dropout", 0.0) target_modules = lora_config.get("target_modules", []) + if isinstance(target_modules, str): + target_modules = [target_modules] + + rev_mapping = {} + if packed_modules_mapping: + for ckpt_leaf, (local_dotted, _) in packed_modules_mapping.items(): + local_leaf = local_dotted.split(".")[-1] + rev_mapping[local_leaf] = ckpt_leaf for name, module in model.named_modules(): if hasattr(module, "__init_lora__"): should_apply = True if target_modules: leaf = name.split(".")[-1] if name else name - should_apply = any(target == leaf for target in target_modules) + effective = rev_mapping.get(leaf, leaf) + should_apply = any(target == effective for target in target_modules) if should_apply: module.__init_lora__(r, lora_alpha, lora_dropout) return model @@ -177,13 +198,14 @@ def load_model(model: nn.Module, config: Config): # Enable LoRA for linear layers if LoRA is enabled if config.use_lora and config.lora_path: lora_config = load_lora_config(config.lora_path) + packed_modules_mapping_for_lora = getattr(model, "packed_modules_mapping", None) if lora_config: logger.info(f"LoRA Config Loaded: {lora_config}") - model = enable_lora_for_model(model, lora_config) + model = enable_lora_for_model(model, lora_config, packed_modules_mapping_for_lora) else: logger.info("No adapter_config.json found, using default LoRA parameters") default_config = {"r": 16, "lora_alpha": 32.0, "lora_dropout": 0.0} - model = enable_lora_for_model(model, default_config) + model = enable_lora_for_model(model, default_config, packed_modules_mapping_for_lora) # Load base model weights packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) diff --git a/diffulex_bench/configs/dream_base_gsm8k.yml b/diffulex_bench/configs/dream_base_gsm8k.yml index 8b8df7d..3fc6d5f 100644 --- a/diffulex_bench/configs/dream_base_gsm8k.yml +++ b/diffulex_bench/configs/dream_base_gsm8k.yml @@ -28,14 +28,14 @@ engine: semi_complete_threshold: 0.9 accept_threshold: 0.95 block_size: 32 - buffer_size: 4 + buffer_size: 1 eval: dataset_name: "gsm8k_diffulex" dataset_split: "test" dataset_limit: null temperature: 0.0 - max_tokens: 256 + max_tokens: 512 add_bos_token: false # Base model: plain Q/A format, no chat template output_dir: "benchmark_results" save_results: true diff --git a/diffulex_bench/configs/fast_dllm_v2_gsm8k.yml b/diffulex_bench/configs/fast_dllm_v2_gsm8k.yml index 7b59a9b..4871b3c 100644 --- a/diffulex_bench/configs/fast_dllm_v2_gsm8k.yml +++ b/diffulex_bench/configs/fast_dllm_v2_gsm8k.yml @@ -34,7 +34,7 @@ eval: dataset_split: "test" dataset_limit: null temperature: 0.0 - max_tokens: 256 + max_tokens: 512 add_bos_token: true # Instruct model: chat template format output_dir: "benchmark_results" save_results: true diff --git a/diffulex_bench/configs/llada_instruct_gsm8k.yml b/diffulex_bench/configs/llada_instruct_gsm8k.yml index 96bcc90..0bfc8d5 100644 --- a/diffulex_bench/configs/llada_instruct_gsm8k.yml +++ b/diffulex_bench/configs/llada_instruct_gsm8k.yml @@ -16,9 +16,9 @@ engine: tensor_parallel_size: 1 data_parallel_size: 1 gpu_memory_utilization: 0.6 - max_model_len: 2048 - max_num_batched_tokens: 4096 - max_num_reqs: 128 + max_model_len: 1024 + max_num_batched_tokens: 1024 + max_num_reqs: 24 enforce_eager: false kv_cache_layout: "unified" @@ -28,14 +28,14 @@ engine: semi_complete_threshold: 0.9 accept_threshold: 0.95 block_size: 32 - buffer_size: 4 + buffer_size: 1 eval: dataset_name: 
"gsm8k_diffulex" dataset_split: "test" dataset_limit: null temperature: 0.0 - max_tokens: 256 + max_tokens: 512 add_bos_token: true # Instruct model: chat template format output_dir: "benchmark_results" save_results: true diff --git a/diffulex_bench/configs/sdar_chat_gsm8k.yml b/diffulex_bench/configs/sdar_chat_gsm8k.yml index df0b2e7..ada3990 100644 --- a/diffulex_bench/configs/sdar_chat_gsm8k.yml +++ b/diffulex_bench/configs/sdar_chat_gsm8k.yml @@ -34,7 +34,7 @@ eval: dataset_split: "test" dataset_limit: null temperature: 0.0 - max_tokens: 256 + max_tokens: 512 add_bos_token: true # Chat model: chat template format output_dir: "benchmark_results" save_results: true diff --git a/docs/cookbook/benchmark.md b/docs/cookbook/benchmark.md index 3cb92a6..604366b 100644 --- a/docs/cookbook/benchmark.md +++ b/docs/cookbook/benchmark.md @@ -16,78 +16,11 @@ The config file provides the engine and evaluation settings. In the repository, - `--config diffulex_bench/configs/.yml` - optional overrides like `--model-path`, `--tokenizer-path`, `--model-name`, `--decoding-strategy`, `--tensor-parallel-size`, `--data-parallel-size`, `--dataset`, `--dataset-limit`, `--temperature`, `--max-tokens`, and `--output-dir` -## Supported models +## Model recipes -### D2F-LLaDA +Use the model pages for recommended benchmark configurations: -```bash -python -m diffulex_bench.main \ - --config diffulex_bench/configs/llada_instruct_gsm8k.yml \ - --model-path /YOUR-CKPT-PATH/GSAI-ML/LLaDA-8B-Instruct \ - --tokenizer-path /YOUR-CKPT-PATH/GSAI-ML/LLaDA-8B-Instruct \ - --model-name llada \ - --decoding-strategy d2f \ - --use-lora \ - --lora-path /YOUR-CKPT-PATH/SJTU-DENG-Lab/D2F_LLaDA_Instruct_8B_Lora \ - --tensor-parallel-size 2 \ - --data-parallel-size 1 \ - --dataset gsm8k_diffulex \ - --dataset-limit 100 \ - --temperature 0.0 \ - --max-tokens 256 -``` - -### D2F-Dream - -```bash -python -m diffulex_bench.main \ - --config diffulex_bench/configs/dream_base_gsm8k.yml \ - --model-path /YOUR-CKPT-PATH/Dream-org/Dream-v0-Base-7B \ - --tokenizer-path /YOUR-CKPT-PATH/Dream-org/Dream-v0-Base-7B \ - --model-name dream \ - --decoding-strategy d2f \ - --use-lora \ - --lora-path /YOUR-CKPT-PATH/SJTU-DENG-Lab/D2F_Dream_Base_7B_Lora \ - --tensor-parallel-size 2 \ - --data-parallel-size 1 \ - --dataset gsm8k_diffulex \ - --dataset-limit 100 \ - --temperature 0.0 \ - --max-tokens 256 -``` - -### Fast-dLLM-v2 - -```bash -python -m diffulex_bench.main \ - --config diffulex_bench/configs/fast_dllm_v2_gsm8k.yml \ - --model-path /YOUR-CKPT-PATH/Efficient-Large-Model/Fast_dLLM_v2_7B \ - --tokenizer-path /YOUR-CKPT-PATH/Efficient-Large-Model/Fast_dLLM_v2_7B \ - --model-name fast_dllm_v2 \ - --decoding-strategy multi_bd \ - --tensor-parallel-size 2 \ - --data-parallel-size 1 \ - --dataset gsm8k_diffulex \ - --dataset-limit 100 \ - --temperature 0.0 \ - --max-tokens 256 -``` - -### SDAR - -```bash -python -m diffulex_bench.main \ - --config diffulex_bench/configs/sdar_chat_gsm8k.yml \ - --model-path /YOUR-CKPT-PATH/JetLM/SDAR-1.7B-Chat-b32 \ - --model-name sdar \ - --decoding-strategy multi_bd \ - --tensor-parallel-size 1 \ - --data-parallel-size 1 \ - --dataset gsm8k_diffulex \ - --temperature 0.0 \ - --max-tokens 256 -``` - -### SDAR-MoE - -Use the same benchmark entry point and a matching `sdar_moe` config. The repository already treats `sdar_moe` as a supported model family; keep the same benchmark structure as SDAR and set the model path to your SDAR-MoE checkpoint. 
+- [LLaDA-8B-Instruct (D2F)](models/llada_instruct) +- [Dream-v0-Base-7B (D2F)](models/dream_base) +- [Fast_dLLM_v2_7B (MultiBD)](models/fast_dllm_v2) +- [SDAR-1.7B-Chat-b32 (MultiBD)](models/sdar_chat) diff --git a/docs/cookbook/index.md b/docs/cookbook/index.md index 4a559b0..19a1176 100644 --- a/docs/cookbook/index.md +++ b/docs/cookbook/index.md @@ -1,12 +1,30 @@ # Cookbook -This section summarizes how to start Diffulex through the supported entry points in this repository: +This section provides model-oriented recipes and entry-point references for Diffulex. -- `benchmark` for evaluation workloads through `diffulex_bench` -- `server` for HTTP serving through `diffulex.server.launch` -- `streamlit` for the sample chat frontend +Use the model pages when you want recommended serving and benchmark configurations for a specific checkpoint. Use the entry-point pages when you want the generic command structure. -Start with the page that matches what you want to do: +## D2F Recipes + +These recipes use `d2f` decoding and include high-concurrency and low-concurrency speed presets where applicable. + +:::{toctree} +:maxdepth: 1 +models/llada_instruct +models/dream_base +::: + +## MultiBD Recipes + +These recipes use `multi_bd` decoding and focus on balanced serving and throughput-oriented presets. + +:::{toctree} +:maxdepth: 1 +models/fast_dllm_v2 +models/sdar_chat +::: + +## Entry Points :::{toctree} :maxdepth: 1 diff --git a/docs/cookbook/models/dream_base.md b/docs/cookbook/models/dream_base.md new file mode 100644 index 0000000..2e0cbf0 --- /dev/null +++ b/docs/cookbook/models/dream_base.md @@ -0,0 +1,138 @@ +# Dream-v0-Base-7B (D2F) + +## 1. Model Introduction + +Dream-v0-Base-7B uses `d2f` decoding with the D2F LoRA adapter. This recipe covers serving and GSM8K benchmark commands, with separate recommendations for high-concurrency serving and low-concurrency speed. + +## 2. Diffulex Installation + +```bash +git clone https://github.com/SJTU-DENG-Lab/Diffulex.git +cd Diffulex +uv pip install -e . +``` + +## 3. Model Deployment + +### 3.1 Basic Configuration + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/Dream-org/Dream-v0-Base-7B \ + --model-name dream \ + --decoding-strategy d2f \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 24 \ + --block-size 32 \ + --buffer-size 1 \ + --accept-threshold 0.95 \ + --semi-complete-threshold 0.9 \ + --add-block-threshold 0.1 \ + --use-lora \ + --lora-path /YOUR-CKPT-PATH/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora \ + --pre-merge-lora +``` + +### 3.2 Configuration Tips + +There is a serving trade-off between high concurrency and low-concurrency speed. 
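+
+Both presets below keep the model path, thresholds, and block size fixed and differ only in `--max-num-reqs` and `--buffer-size`. As a quick side-by-side reference (values taken from the full commands that follow):
+
+```bash
+# High-concurrency serving (general default on this page):
+#   --max-num-reqs 24  --buffer-size 1
+# Low-concurrency speed:
+#   --max-num-reqs 2   --buffer-size 2
+```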
+ +For general high-concurrency serving, use a smaller active block buffer: + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/Dream-org/Dream-v0-Base-7B \ + --model-name dream \ + --decoding-strategy d2f \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 24 \ + --block-size 32 \ + --buffer-size 1 \ + --accept-threshold 0.95 \ + --semi-complete-threshold 0.9 \ + --add-block-threshold 0.1 \ + --use-lora \ + --lora-path /YOUR-CKPT-PATH/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora \ + --pre-merge-lora +``` + +For low-concurrency speed, use a slightly larger active block buffer: + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/Dream-org/Dream-v0-Base-7B \ + --model-name dream \ + --decoding-strategy d2f \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 2 \ + --block-size 32 \ + --buffer-size 2 \ + --accept-threshold 0.95 \ + --semi-complete-threshold 0.9 \ + --add-block-threshold 0.1 \ + --use-lora \ + --lora-path /YOUR-CKPT-PATH/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora \ + --pre-merge-lora +``` + +For interactive serving, keep `max_num_batched_tokens` larger than `max_model_len`; the recommended server preset uses `8192` with `max_model_len` `4096` to leave room for longer chat turns. If GPU memory becomes the bottleneck, reduce `max_num_reqs` before lowering the context length. `block_size` must be one of `4`, `8`, `16`, or `32`. + +## 4. Model Startup + +### 4.1 Server Startup + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/Dream-org/Dream-v0-Base-7B \ + --model-name dream \ + --decoding-strategy d2f \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 24 \ + --block-size 32 \ + --buffer-size 1 \ + --accept-threshold 0.95 \ + --semi-complete-threshold 0.9 \ + --add-block-threshold 0.1 \ + --use-lora \ + --lora-path /YOUR-CKPT-PATH/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora \ + --pre-merge-lora +``` + +### 4.2 Benchmark Startup + +```bash +python -m diffulex_bench.main \ + --config diffulex_bench/configs/dream_base_gsm8k.yml \ + --model-path /YOUR-CKPT-PATH/Dream-org/Dream-v0-Base-7B \ + --tokenizer-path /YOUR-CKPT-PATH/Dream-org/Dream-v0-Base-7B \ + --lora-path /YOUR-CKPT-PATH/SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora \ + --dataset-limit 400 \ + --max-model-len 1024 \ + --max-num-batched-tokens 2048 \ + --max-num-reqs 24 \ + --block-size 32 \ + --engine-arg buffer_size=1 \ + --output-dir /YOUR-OUTPUT-PATH/dream_base_gsm8k +``` + +## 5. Benchmark + +### 5.1 Accuracy Benchmark + +Use the benchmark command above for GSM8K exact-match evaluation. For general serving, prefer the high-concurrency preset. For low-concurrency throughput experiments, compare it with the low-concurrency speed preset. + +### 5.2 Speed Benchmark + +For throughput-focused evaluation, sweep `max_num_reqs` under the two presets. In our tests, `buffer_size=1` is the safer default for high-concurrency serving, while `buffer_size=2` is useful as a low-concurrency speed option. diff --git a/docs/cookbook/models/fast_dllm_v2.md b/docs/cookbook/models/fast_dllm_v2.md new file mode 100644 index 0000000..d579e77 --- /dev/null +++ b/docs/cookbook/models/fast_dllm_v2.md @@ -0,0 +1,113 @@ +# Fast_dLLM_v2_7B (MultiBD) + +## 1. Model Introduction + +Fast_dLLM_v2_7B uses `multi_bd` decoding and is a throughput-oriented baseline for Diffulex. 
+ +## 2. Diffulex Installation + +```bash +git clone https://github.com/SJTU-DENG-Lab/Diffulex.git +cd Diffulex +uv pip install -e . +``` + +## 3. Model Deployment + +### 3.1 Basic Configuration + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/Efficient-Large-Model/Fast_dLLM_v2_7B \ + --model-name fast_dllm_v2 \ + --decoding-strategy multi_bd \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 24 \ + --block-size 32 \ + --buffer-size 1 +``` + +### 3.2 Configuration Tips + +There is a trade-off between balanced serving and maximum throughput. + +For balanced serving with strong accuracy, use moderate concurrency: + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/Efficient-Large-Model/Fast_dLLM_v2_7B \ + --model-name fast_dllm_v2 \ + --decoding-strategy multi_bd \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 24 \ + --block-size 32 \ + --buffer-size 1 +``` + +For throughput-oriented serving, use higher concurrency: + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/Efficient-Large-Model/Fast_dLLM_v2_7B \ + --model-name fast_dllm_v2 \ + --decoding-strategy multi_bd \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 48 \ + --block-size 32 \ + --buffer-size 1 +``` + +For interactive serving, keep `max_num_batched_tokens` larger than `max_model_len`; the recommended server preset uses `8192` with `max_model_len` `4096` to leave room for longer chat turns. If GPU memory becomes the bottleneck, reduce `max_num_reqs` before lowering the context length. + +## 4. Model Startup + +### 4.1 Server Startup + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/Efficient-Large-Model/Fast_dLLM_v2_7B \ + --model-name fast_dllm_v2 \ + --decoding-strategy multi_bd \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 24 \ + --block-size 32 \ + --buffer-size 1 +``` + +### 4.2 Benchmark Startup + +```bash +python -m diffulex_bench.main \ + --config diffulex_bench/configs/fast_dllm_v2_gsm8k.yml \ + --model-path /YOUR-CKPT-PATH/Efficient-Large-Model/Fast_dLLM_v2_7B \ + --tokenizer-path /YOUR-CKPT-PATH/Efficient-Large-Model/Fast_dLLM_v2_7B \ + --dataset-limit 400 \ + --max-model-len 1024 \ + --max-num-batched-tokens 1024 \ + --max-num-reqs 24 \ + --block-size 32 \ + --engine-arg buffer_size=1 \ + --output-dir /YOUR-OUTPUT-PATH/fast_dllm_v2_gsm8k +``` + +## 5. Benchmark + +### 5.1 Accuracy Benchmark + +Use the benchmark command above for GSM8K exact-match evaluation. For balanced serving, use the moderate-concurrency preset. + +### 5.2 Speed Benchmark + +For throughput-focused evaluation, increase `max_num_reqs` and monitor accuracy. The throughput-oriented preset uses higher concurrency while keeping `buffer_size=1`. diff --git a/docs/cookbook/models/llada_instruct.md b/docs/cookbook/models/llada_instruct.md new file mode 100644 index 0000000..9b43d57 --- /dev/null +++ b/docs/cookbook/models/llada_instruct.md @@ -0,0 +1,134 @@ +# LLaDA-8B-Instruct (D2F) + +## 1. Model Introduction + +LLaDA-8B-Instruct uses `d2f` decoding with the D2F LoRA adapter. This recipe covers serving and GSM8K benchmark commands, with separate recommendations for high-concurrency serving and low-concurrency speed. + +## 2. 
Diffulex Installation + +```bash +git clone https://github.com/SJTU-DENG-Lab/Diffulex.git +cd Diffulex +uv pip install -e . +``` + +## 3. Model Deployment + +### 3.1 Basic Configuration + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/GSAI-ML/LLaDA-8B-Instruct \ + --model-name llada \ + --decoding-strategy d2f \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 24 \ + --block-size 32 \ + --buffer-size 1 \ + --accept-threshold 0.95 \ + --semi-complete-threshold 0.9 \ + --add-block-threshold 0.1 \ + --use-lora \ + --lora-path /YOUR-CKPT-PATH/SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora +``` + +### 3.2 Configuration Tips + +There is a serving trade-off between high concurrency and low-concurrency speed. + +For general high-concurrency serving, use a smaller active block buffer: + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/GSAI-ML/LLaDA-8B-Instruct \ + --model-name llada \ + --decoding-strategy d2f \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 24 \ + --block-size 32 \ + --buffer-size 1 \ + --accept-threshold 0.95 \ + --semi-complete-threshold 0.9 \ + --add-block-threshold 0.1 \ + --use-lora \ + --lora-path /YOUR-CKPT-PATH/SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora +``` + +For low-concurrency speed, use a slightly larger active block buffer: + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/GSAI-ML/LLaDA-8B-Instruct \ + --model-name llada \ + --decoding-strategy d2f \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 2 \ + --block-size 32 \ + --buffer-size 2 \ + --accept-threshold 0.95 \ + --semi-complete-threshold 0.9 \ + --add-block-threshold 0.1 \ + --use-lora \ + --lora-path /YOUR-CKPT-PATH/SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora +``` + +For interactive serving, keep `max_num_batched_tokens` larger than `max_model_len`; the recommended server preset uses `8192` with `max_model_len` `4096` to leave room for longer chat turns. If GPU memory becomes the bottleneck, reduce `max_num_reqs` before lowering the context length. `block_size` must be one of `4`, `8`, `16`, or `32`. + +## 4. Model Startup + +### 4.1 Server Startup + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/GSAI-ML/LLaDA-8B-Instruct \ + --model-name llada \ + --decoding-strategy d2f \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 24 \ + --block-size 32 \ + --buffer-size 1 \ + --accept-threshold 0.95 \ + --semi-complete-threshold 0.9 \ + --add-block-threshold 0.1 \ + --use-lora \ + --lora-path /YOUR-CKPT-PATH/SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora +``` + +### 4.2 Benchmark Startup + +```bash +python -m diffulex_bench.main \ + --config diffulex_bench/configs/llada_instruct_gsm8k.yml \ + --model-path /YOUR-CKPT-PATH/GSAI-ML/LLaDA-8B-Instruct \ + --tokenizer-path /YOUR-CKPT-PATH/GSAI-ML/LLaDA-8B-Instruct \ + --lora-path /YOUR-CKPT-PATH/SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora \ + --dataset-limit 400 \ + --max-model-len 1024 \ + --max-num-batched-tokens 1024 \ + --max-num-reqs 24 \ + --block-size 32 \ + --engine-arg buffer_size=1 \ + --output-dir /YOUR-OUTPUT-PATH/llada_instruct_gsm8k +``` + +## 5. Benchmark + +### 5.1 Accuracy Benchmark + +Use the benchmark command above for GSM8K exact-match evaluation. 
For general serving, prefer the high-concurrency preset. For low-concurrency throughput experiments, compare it with the low-concurrency speed preset. + +### 5.2 Speed Benchmark + +For throughput-focused evaluation, sweep `max_num_reqs` under the two presets. In our tests, `buffer_size=1` is the safer default for high-concurrency serving, while `buffer_size=2` is useful as a low-concurrency speed option. diff --git a/docs/cookbook/models/sdar_chat.md b/docs/cookbook/models/sdar_chat.md new file mode 100644 index 0000000..92dc7d1 --- /dev/null +++ b/docs/cookbook/models/sdar_chat.md @@ -0,0 +1,112 @@ +# SDAR-1.7B-Chat-b32 (MultiBD) + +## 1. Model Introduction + +SDAR-1.7B-Chat-b32 uses `multi_bd` decoding and is a stable low-concurrency baseline. + +## 2. Diffulex Installation + +```bash +git clone https://github.com/SJTU-DENG-Lab/Diffulex.git +cd Diffulex +uv pip install -e . +``` + +## 3. Model Deployment + +### 3.1 Basic Configuration + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/SDAR/SDAR-1.7B-Chat-b32 \ + --model-name sdar \ + --decoding-strategy multi_bd \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 1 \ + --block-size 32 \ + --buffer-size 1 +``` + +### 3.2 Configuration Tips + +There is a trade-off between stable baseline quality and higher throughput. + +For a stable quality-oriented baseline, use low concurrency: + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/SDAR/SDAR-1.7B-Chat-b32 \ + --model-name sdar \ + --decoding-strategy multi_bd \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 1 \ + --block-size 32 \ + --buffer-size 1 +``` + +For throughput-oriented serving, use higher concurrency: + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/SDAR/SDAR-1.7B-Chat-b32 \ + --model-name sdar \ + --decoding-strategy multi_bd \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 4 \ + --block-size 32 \ + --buffer-size 1 +``` + +For interactive serving, keep `max_num_batched_tokens` larger than `max_model_len`; the recommended server preset uses `8192` with `max_model_len` `4096` to leave room for longer chat turns. If GPU memory becomes the bottleneck, reduce `max_num_reqs` before lowering the context length. + +## 4. Model Startup + +### 4.1 Server Startup + +```bash +python -m diffulex.server.launch \ + --model /YOUR-CKPT-PATH/SDAR/SDAR-1.7B-Chat-b32 \ + --model-name sdar \ + --decoding-strategy multi_bd \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ + --max-num-reqs 1 \ + --block-size 32 \ + --buffer-size 1 +``` + +### 4.2 Benchmark Startup + +```bash +python -m diffulex_bench.main \ + --config diffulex_bench/configs/sdar_chat_gsm8k.yml \ + --model-path /YOUR-CKPT-PATH/SDAR/SDAR-1.7B-Chat-b32 \ + --dataset-limit 400 \ + --max-model-len 2048 \ + --max-num-batched-tokens 4096 \ + --max-num-reqs 1 \ + --block-size 32 \ + --engine-arg buffer_size=1 \ + --output-dir /YOUR-OUTPUT-PATH/sdar_chat_gsm8k +``` + +## 5. Benchmark + +### 5.1 Accuracy Benchmark + +Use the benchmark command above for GSM8K exact-match evaluation. For stable baseline comparisons, use the low-concurrency preset. + +### 5.2 Speed Benchmark + +For throughput-focused evaluation, increase `max_num_reqs` and monitor accuracy. 
The throughput-oriented preset uses higher concurrency while keeping `buffer_size=1`. diff --git a/docs/cookbook/server.md b/docs/cookbook/server.md index b98200c..7703b15 100644 --- a/docs/cookbook/server.md +++ b/docs/cookbook/server.md @@ -11,12 +11,14 @@ python -m diffulex.server.launch \ --decoding-strategy \ --tensor-parallel-size 1 \ --data-parallel-size 1 \ - --max-model-len 2048 \ - --max-num-batched-tokens 4096 \ + --max-model-len 4096 \ + --max-num-batched-tokens 8192 \ --max-num-reqs 128 \ --gpu-memory-utilization 0.9 ``` +For interactive chat serving, keep `max_num_batched_tokens` larger than `max_model_len`. If you increase `max_model_len` for longer conversations, increase `max_num_batched_tokens` with it and reduce `max_num_reqs` if GPU memory becomes the bottleneck. + The server process accepts the same core engine arguments as the benchmark path, plus HTTP-specific flags: - `--host` @@ -26,72 +28,11 @@ The server process accepts the same core engine arguments as the benchmark path, - `--zmq-command-addr` - `--zmq-event-addr` -## Supported models - -### Fast-dLLM-v2 - -```bash -python -m diffulex.server.launch \ - --model /YOUR-CKPT-PATH/Efficient-Large-Model/Fast_dLLM_v2_7B \ - --model-name fast_dllm_v2 \ - --decoding-strategy multi_bd \ - --sampling-mode naive \ - --tensor-parallel-size 2 \ - --data-parallel-size 1 \ - --max-model-len 1024 \ - --max-num-batched-tokens 1024 \ - --max-num-reqs 24 \ - --gpu-memory-utilization 0.4 \ - --block-size 32 \ - --buffer-size 1 \ - --accept-threshold 0.95 \ - --semi-complete-threshold 0.9 \ - --add-block-threshold 0.1 \ - --enforce-eager -``` +## Model recipes -### D2F-LLaDA +Use the model pages for recommended server configurations: -```bash -python -m diffulex.server.launch \ - --model /YOUR-CKPT-PATH/GSAI-ML/LLaDA-8B-Instruct \ - --model-name llada \ - --decoding-strategy d2f \ - --tensor-parallel-size 2 \ - --data-parallel-size 1 \ - --use-lora \ - --lora-path /YOUR-CKPT-PATH/SJTU-DENG-Lab/D2F_LLaDA_Instruct_8B_Lora \ - --pre-merge-lora \ - --max-model-len 2048 \ - --max-num-batched-tokens 2048 \ - --max-num-reqs 32 \ - --accept-threshold 0.95 \ - --semi-complete-threshold 0.9 \ - --add-block-threshold 0.1 \ - --enforce-eager -``` - -### SDAR - -```bash -python -m diffulex.server.launch \ - --model /YOUR-CKPT-PATH/JetLM/SDAR-1.7B-Chat-b32 \ - --host 0.0.0.0 \ - --port 8000 \ - --model-name sdar \ - --decoding-strategy multi_bd \ - --tensor-parallel-size 1 \ - --data-parallel-size 1 \ - --device-ids 1 \ - --block-size 32 \ - --buffer-size 4 \ - --page-size 32 \ - --max-num-batched-tokens 4096 \ - --max-num-reqs 128 \ - --max-model-len 2048 \ - --gpu-memory-utilization 0.5 \ - --kv-cache-layout unified \ - --add-block-threshold 0.1 \ - --semi-complete-threshold 0.9 \ - --accept-threshold 0.95 -``` +- [LLaDA-8B-Instruct (D2F)](models/llada_instruct) +- [Dream-v0-Base-7B (D2F)](models/dream_base) +- [Fast_dLLM_v2_7B (MultiBD)](models/fast_dllm_v2) +- [SDAR-1.7B-Chat-b32 (MultiBD)](models/sdar_chat) diff --git a/examples/streamlit_block_append_chat.py b/examples/streamlit_block_append_chat.py index 28b64a2..e968da0 100644 --- a/examples/streamlit_block_append_chat.py +++ b/examples/streamlit_block_append_chat.py @@ -17,7 +17,7 @@ DEFAULT_TEMPERATURE = 0.0 MASK_TOKEN_TEXT = "<|MASK|>" DEFAULT_MASK_SYMBOL = "▒" -DISPLAY_STOP_TOKENS = ("<|im_end|>",) +DISPLAY_STOP_TOKENS = ("<|im_end|>", "<|eot_id|>") @dataclass