Failed to run offline distillation #41

@wufenglailai

Description

I generated my logits with the YAML config below:

d: 151936
k: 128
exact_k: 16
exact_dtype: bfloat16
polynomial_terms: [0, 1, 2, 3, 4, "sqrt"]
term_dtype: float32
residual_bins: []
delta_encoding: false
error_diffusion: false

and ran the sampling with this command:

python -m distillkit.sample_logits_vllm \
  --model /home/liuhengyu/llm/model/ori/Qwen3/Qwen3-14B \
  --dataset /home/liuhengyu/data/2026/KD/xxx/merge/ \
  --output /home/liuhengyu/data/2026/KD/Qwen3-14B-Logits/ \
  --apply-chat-template \
  --dtype bfloat16 \
  --gpu-memory-utilization 0.95 \
  --tensor-parallel-size 1 \
  --max-seq-len 1024 \
  --compression-config ./config/logits_compress.yaml

But an error occurred when I was distilling with the logits data.
My distillation command:

distillkit examples/Qwen3-1.7B.yml

The YAML config is below:

# config.yaml
project_name: KD-Qwen3-1.7B
model: /home/liuhengyu/llm/model/ori/Qwen3/Qwen3-1.7B
model_auto_class: AutoModelForCausalLM
output_path: /home/liuhengyu/llm/model/train/Qwen3-14B-KD-1.7B
sequence_length: 1024
use_flash_attention: false

dataset:
  train_dataset:
    repo_id: /home/liuhengyu/data/2026/KD/Qwen3-14B-Logits
    split: train
  prepacked: true


teacher:
  kind: dataset
  logprob_compressor:
    d: 151670 # vocabulary size
    delta_encoding: false
    error_diffusion: false
    exact_dtype: bfloat16
    exact_k: 16
    k: 128
    polynomial_terms: [0, 1, 2, 3, 4, "sqrt"]
    residual_bins: []
    term_dtype: float32

loss_functions:
  - function: cross_entropy
    weight: 0.5
  - function: kl
    weight: 0.5
    temperature: 1.0
    missing_probability_handling: zero
    sparse_chunk_length: 1024

training_args:
  num_train_epochs: 1
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 8
  learning_rate: 2.0e-6
  bf16: true
  optim: adamw_torch
  gradient_checkpointing: true
  report_to: none
  warmup_ratio: 0.025
  save_steps: 512
  save_total_limit: 4
  logging_steps: 2

The error is:

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
  0%|                                                                                                                                               | 0/1729 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/bin/distillkit", line 7, in <module>
    sys.exit(main())
             ^^^^^^
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/click/core.py", line 1485, in __call__
    return self.main(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/click/core.py", line 1406, in main
    rv = self.invoke(ctx)
         ^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/click/core.py", line 1269, in invoke
    return ctx.invoke(self.callback, **ctx.params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/click/core.py", line 824, in invoke
    return callback(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/workspace/python/2026/KD/DistillKit/distillkit/main.py", line 393, in main
    do_distill(config)
  File "/mnt/md/liuhengyu/workspace/python/2026/KD/DistillKit/distillkit/main.py", line 363, in do_distill
    trainer.train(
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/transformers/trainer.py", line 2325, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/transformers/trainer.py", line 2618, in _inner_training_loop
    batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/transformers/trainer.py", line 5654, in get_batch_samples
    batch_samples.append(next(epoch_iterator))
                         ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/accelerate/data_loader.py", line 567, in __iter__
    current_batch = next(dataloader_iter)
                    ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 732, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 788, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/md/liuhengyu/workspace/python/2026/KD/DistillKit/distillkit/main.py", line 276, in collate_packed_batch
    key: torch.tensor([example[key] for example in examples])
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: expected sequence of length 208 at dim 1 (got 220)
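
From what I can tell, torch.tensor raises exactly this error whenever the per-example sequences in a batch have different lengths. A minimal repro using the two lengths from the message above (the row contents are made up; only the lengths matter):

import torch

# Two "packed" rows of unequal length, mimicking two dataset examples
# of 208 and 220 tokens (the lengths reported in the error above).
rows = [[0] * 208, [0] * 220]

# torch.tensor cannot stack ragged lists, so this raises:
#   ValueError: expected sequence of length 208 at dim 1 (got 220)
batch = torch.tensor(rows)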

What should I do? Please advise.
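
In case it helps, this is the kind of padding collate I imagined might avoid the shape mismatch. It is only a sketch: the input_ids key, the pad value of 0, and padding (rather than re-packing the sequences to equal length) are my assumptions, not DistillKit's actual collate logic.

import torch

def pad_collate(examples, pad_id=0):
    # Pad every example's token list to the longest length in the
    # batch before stacking, instead of assuming equal lengths.
    # "input_ids" is an assumed key; real packed examples may carry
    # more fields (labels, teacher logits, ...) that would need the
    # same treatment.
    max_len = max(len(ex["input_ids"]) for ex in examples)
    padded = [ex["input_ids"] + [pad_id] * (max_len - len(ex["input_ids"]))
              for ex in examples]
    return {"input_ids": torch.tensor(padded)}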
