I generated my logits with the YAML below:
d: 151936
k: 128
exact_k: 16
exact_dtype: bfloat16
polynomial_terms: [0, 1, 2, 3, 4, "sqrt"]
term_dtype: float32
residual_bins: []
delta_encoding: false
error_diffusion: false
and ran it with this command:
python -m distillkit.sample_logits_vllm \
--model /home/liuhengyu/llm/model/ori/Qwen3/Qwen3-14B \
--dataset /home/liuhengyu/data/2026/KD/xxx/merge/ \
--output /home/liuhengyu/data/2026/KD/Qwen3-14B-Logits/ \
--apply-chat-template \
--dtype bfloat16 \
--gpu-memory-utilization 0.95 \
--tensor-parallel-size 1 \
--max-seq-len 1024 \
--compression-config ./config/logits_compress.yaml
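Before training, I can sanity-check the sampled output for rows with inconsistent sequence lengths. This is only a sketch: I'm assuming the output directory loads as a Hugging Face dataset (e.g. via `datasets.load_from_disk`), and the column name `input_ids` is a guess, not something confirmed from DistillKit's output format.

```python
from collections import Counter

def length_histogram(examples, key="input_ids"):
    """Count how many rows have each sequence length for a given list-valued column."""
    return Counter(len(ex[key]) for ex in examples)

# Usage (path and column name are assumptions):
# from datasets import load_from_disk
# ds = load_from_disk("/home/liuhengyu/data/2026/KD/Qwen3-14B-Logits")
# print(length_histogram(ds))  # more than one key => rows are ragged
```

If the histogram shows more than one length, the rows cannot be stacked into a single tensor without padding or re-packing.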
🚀🚀 But an error occurred when I distilled with the logits data.
My distill command:
distillkit examples/Qwen3-1.7B.yml
The YAML is below:
# config.yaml
project_name: KD-Qwen3-1.7B
model: /home/liuhengyu/llm/model/ori/Qwen3/Qwen3-1.7B
model_auto_class: AutoModelForCausalLM
output_path: /home/liuhengyu/llm/model/train/Qwen3-14B-KD-1.7B
sequence_length: 1024
use_flash_attention: false
dataset:
  train_dataset:
    repo_id: /home/liuhengyu/data/2026/KD/Qwen3-14B-Logits
    split: train
    prepacked: true
teacher:
  kind: dataset
  logprob_compressor:
    d: 151670 # vocabulary size
    delta_encoding: false
    error_diffusion: false
    exact_dtype: bfloat16
    exact_k: 16
    k: 128
    polynomial_terms: [0, 1, 2, 3, 4, "sqrt"]
    residual_bins: []
    term_dtype: float32
loss_functions:
  - function: cross_entropy
    weight: 0.5
  - function: kl
    weight: 0.5
    temperature: 1.0
missing_probability_handling: zero
sparse_chunk_length: 1024
training_args:
  num_train_epochs: 1
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 8
  learning_rate: 2.0e-6
  bf16: true
  optim: adamw_torch
  gradient_checkpointing: true
  report_to: none
  warmup_ratio: 0.025
  save_steps: 512
  save_total_limit: 4
  logging_steps: 2
The error is:
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.
0%| | 0/1729 [00:00<?, ?it/s]Traceback (most recent call last):
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/bin/distillkit", line 7, in <module>
sys.exit(main())
^^^^^^
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/click/core.py", line 1485, in __call__
return self.main(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/click/core.py", line 1406, in main
rv = self.invoke(ctx)
^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/click/core.py", line 1269, in invoke
return ctx.invoke(self.callback, **ctx.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/click/core.py", line 824, in invoke
return callback(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/workspace/python/2026/KD/DistillKit/distillkit/main.py", line 393, in main
do_distill(config)
File "/mnt/md/liuhengyu/workspace/python/2026/KD/DistillKit/distillkit/main.py", line 363, in do_distill
trainer.train(
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/transformers/trainer.py", line 2325, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/transformers/trainer.py", line 2618, in _inner_training_loop
batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches, args.device)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/transformers/trainer.py", line 5654, in get_batch_samples
batch_samples.append(next(epoch_iterator))
^^^^^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/accelerate/data_loader.py", line 567, in __iter__
current_batch = next(dataloader_iter)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 732, in __next__
data = self._next_data()
^^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 788, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/miniconda3/envs/DistillKit_A100-3/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
return self.collate_fn(data)
^^^^^^^^^^^^^^^^^^^^^
File "/mnt/md/liuhengyu/workspace/python/2026/KD/DistillKit/distillkit/main.py", line 276, in collate_packed_batch
key: torch.tensor([example[key] for example in examples])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: expected sequence of length 208 at dim 1 (got 220)
What should I do? Any advice would be appreciated.
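My guess at the failure mode (not verified against DistillKit's actual `collate_packed_batch`): `torch.tensor()` is called on a list of per-example sequences, and it raises exactly this `ValueError` when the sequences have different lengths (208 vs 220 in my run). A minimal right-padding sketch in plain Python, using `151643` as the pad id since that is what the warning in the log reports:

```python
def pad_rows(rows, pad_id=151643):
    """Right-pad every row to the longest row so they can be stacked into one tensor."""
    width = max(len(r) for r in rows)
    return [r + [pad_id] * (width - len(r)) for r in rows]

# Rows of lengths 208 and 220, as in my traceback; after padding,
# every row has length 220 and torch.tensor() would no longer raise.
rows = [[1] * 208, [2] * 220]
padded = pad_rows(rows)
assert {len(r) for r in padded} == {220}
```

Whether padding is the right fix here, or whether the packed dataset itself should have uniform row lengths, is exactly what I'm unsure about.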