Minimal change for latest transformer usage. #31
Open
CenjhihLi wants to merge 1 commit into Shark-NLP:main from
Conversation
CenjhihLi
commented
May 6, 2026
- Modify the tokenizer encode usage.
- Modify the pad_token assignment: only assign pad_token = eos_token when eos_token is not None.
- Modify the ce_loss computation in ppl_inferencer to support bfloat16.
- Add cone_retriever from https://github.com/Romainpkq/revisit_demon_selection_in_ICL
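The pad-token change in the list above can be sketched as follows. This is a hedged illustration, not the actual OpenICL code: `DummyTokenizer` and `ensure_pad_token` are stand-in names for a Hugging Face tokenizer and the guarded assignment the PR describes.

```python
# Illustrative stand-in for a Hugging Face tokenizer (not OpenICL code).
class DummyTokenizer:
    def __init__(self, pad_token=None, eos_token=None,
                 pad_token_id=None, eos_token_id=None):
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.pad_token_id = pad_token_id
        self.eos_token_id = eos_token_id


def ensure_pad_token(tokenizer):
    # Only fall back to EOS when it actually exists, so a model without an
    # EOS token does not get its pad settings overwritten with None.
    if tokenizer.pad_token is None and tokenizer.eos_token is not None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    return tokenizer
```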
Pull request overview
This PR updates OpenICL’s Hugging Face Transformers integration to be more compatible with newer tokenizer/model behaviors (padding/tokenization APIs and bfloat16 outputs), and adds a new ConE-based retriever implementation.
Changes:
- Adjust device transfer in the padding+CUDA collator to handle non-tensor values in the batch.
- Make pad_token/pad_token_id assignment conditional on eos_token(_id) being present, to avoid overwriting valid pad settings with None.
- Update PPL CE-loss computation to better handle bfloat16 logits, and add a new ConERetriever.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| openicl/utils/collators.py | Changes how collated batches are moved onto device (now via per-item .to(...)). |
| openicl/icl_retriever/icl_topk_retriever.py | Makes pad token assignment conditional on the presence of EOS token/id. |
| openicl/icl_retriever/icl_cone_retriever.py | Adds a new ConE retriever that reranks TopK candidates using CE loss. |
| openicl/icl_retriever/__init__.py | Exposes ConERetriever from the retriever package. |
| openicl/icl_inferencer/icl_ppl_inferencer.py | Updates CE-loss normalization and dtype handling to support bfloat16 outputs. |
| openicl/icl_inferencer/icl_base_inferencer.py | Makes pad token assignment conditional on the presence of EOS token/id. |
| openicl/icl_dataset_reader.py | Updates dataset tokenization to use tokenizer(...) with a fallback to encode_plus. |
Comment on lines +65 to +68

```python
batch = {
    k: v.to(self.device) if hasattr(v, "to") else v
    for k, v in batch.items()
}
```
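The collator hunk above moves only the values that support `.to(...)` onto the device and passes everything else through. A minimal self-contained sketch of that pattern, using a `FakeTensor` stand-in for `torch.Tensor` (the real code calls `tensor.to(self.device)`):

```python
# FakeTensor mimics the .to(device) interface of a torch.Tensor for the sketch.
class FakeTensor:
    def __init__(self, data, device="cpu"):
        self.data = data
        self.device = device

    def to(self, device):
        # torch.Tensor.to returns a tensor on the target device; mimic that.
        return FakeTensor(self.data, device)


def move_batch(batch, device):
    # Non-tensor values (e.g. lists of raw strings) lack .to() and are kept as-is.
    return {k: v.to(device) if hasattr(v, "to") else v for k, v in batch.items()}
```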
```diff
- ce_loss = loss.sum(-1).cpu().detach().numpy() / lens
+ lens -= torch.tensor(mask_length, device=lens.device, dtype=lens.dtype)
+ # Some new hf models are bfloat16
+ ce_loss = (loss.sum(-1) / lens.to(loss.dtype)).detach().to(torch.float32).cpu()
```
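The motivation for the hunk above is that NumPy has no bfloat16 dtype, so the old `.numpy()` path breaks for bfloat16 models; dividing on-device and converting to float32 before moving to CPU avoids that. The underlying arithmetic is just length-normalized cross-entropy, shown here in plain Python (names are illustrative, not OpenICL's):

```python
def normalized_ce(token_losses, total_len, mask_length):
    # Exclude the masked in-context prefix from the normalizer, then average
    # the per-token losses over the remaining (test) tokens.
    effective_len = total_len - mask_length
    return sum(token_losses) / effective_len
```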
```diff
- tokenized_data = self.tokenizer.encode_plus(data, truncation=True, return_tensors='pt', verbose=False)
+ try:
+     tokenized_data = self.tokenizer(data, truncation=True, return_tensors='pt', verbose=False)
+ except:
```
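A narrower variant of the fallback above (a bare `except:` also swallows `KeyboardInterrupt` and real bugs). This sketch assumes the failure mode is an uncallable legacy tokenizer, so it catches `TypeError` only; `OldTokenizer` is a hypothetical stand-in that exposes only `encode_plus`:

```python
# Hypothetical legacy tokenizer: not callable, only has encode_plus.
class OldTokenizer:
    def encode_plus(self, data, **kwargs):
        return {"input_ids": [101] + [len(w) for w in data.split()]}


def tokenize(tokenizer, data, **kwargs):
    try:
        # Modern Hugging Face tokenizers are callable via __call__.
        return tokenizer(data, **kwargs)
    except TypeError:
        # Legacy path: fall back to encode_plus without masking other errors.
        return tokenizer.encode_plus(data, **kwargs)
```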
Comment on lines +101 to +112

```python
embed = np.expand_dims(entry['embed'], axis=0)
near_ids = self.index.search(embed, min(self.candidate_num, len(self.index_ds)))[1][0].tolist()
candidates = []
mdl_scores = []

prompts = []
mask_lengths = []
test_lengths = []

for j in range(self.candidate_num):
    rand_idx_list = [near_ids[j]]
    candidates.append(rand_idx_list)
```
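In the hunk above, each nearest neighbour becomes a single-demonstration candidate that will later be reranked by CE loss. A self-contained sketch of that candidate construction (function name is illustrative; the capping mirrors the `min(...)` guard in the index search):

```python
def build_candidates(near_ids, candidate_num):
    # Each candidate is a one-element list: one demonstration per candidate,
    # capped at however many neighbours the index actually returned.
    candidates = []
    for j in range(min(candidate_num, len(near_ids))):
        candidates.append([near_ids[j]])
    return candidates
```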
Comment on lines
+140
to
+153
| for batch_id in range(self.candidate_num // self.ppl_batch_size): | ||
| with torch.no_grad(): | ||
| loss_list = self.cal_ce(prompts[batch_id * self.ppl_batch_size: (batch_id + 1) * self.ppl_batch_size], mask_lengths=mask_lengths[batch_id * self.ppl_batch_size: (batch_id + 1) * self.ppl_batch_size], test_lengths=test_lengths[batch_id * self.ppl_batch_size: (batch_id + 1) * self.ppl_batch_size]) | ||
| mdl_scores.extend(loss_list) | ||
|
|
||
| if self.candidate_num % self.ppl_batch_size != 0: | ||
| with torch.no_grad(): | ||
| end_pos = self.candidate_num // self.ppl_batch_size * self.ppl_batch_size | ||
| loss_list = self.cal_ce(prompts[end_pos:], mask_lengths=mask_lengths[end_pos:], test_lengths=test_lengths[end_pos:]) | ||
| mdl_scores.extend(loss_list) | ||
|
|
||
| ppl_scores = list(sorted(list(enumerate(mdl_scores)), key=lambda x: x[1])) | ||
| # get the most lower ppl demonstrations for each test input | ||
| rtr_idx_list[idx] = [int(candidates[ppl_scores[i][0]][0]) for i in range(self.ice_num)] |
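The hunk above scores candidates in fixed-size batches (with a tail batch for the remainder), then sorts candidate indices by score ascending. A condensed sketch of that control flow, with the CE call replaced by a pluggable `score_fn` (the real code calls `self.cal_ce`); stepping `range` by the batch size covers full batches and the remainder in one loop:

```python
def score_and_rank(prompts, batch_size, score_fn, top_k):
    scores = []
    # One slice per batch; the final slice is the remainder, if any.
    for start in range(0, len(prompts), batch_size):
        scores.extend(score_fn(prompts[start:start + batch_size]))
    # Sort candidate indices by score ascending (lower CE loss = lower PPL).
    ranked = sorted(enumerate(scores), key=lambda x: x[1])
    return [idx for idx, _ in ranked[:top_k]]
```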
Comment on lines +162 to +164

```python
logger.info(f'Load model {self.metric_model} for calculating MDL...')
self.metric_model = AutoModelForCausalLM.from_pretrained(self.ce_model_name)
self.metric_model.to(self.device)
```
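The hunk above follows a lazy-loading pattern: the heavyweight scoring model is constructed only when first needed and cached on the instance. A minimal sketch of that pattern, with `load_fn` standing in for `AutoModelForCausalLM.from_pretrained` (class and method names here are illustrative):

```python
class Scorer:
    def __init__(self, model_name, load_fn):
        self.model_name = model_name
        self._load_fn = load_fn      # stand-in for from_pretrained
        self.metric_model = None     # not loaded until first use

    def get_model(self):
        # Load once on first call, reuse the cached model afterwards.
        if self.metric_model is None:
            self.metric_model = self._load_fn(self.model_name)
        return self.metric_model
```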
Comment on lines +1 to +21

```python
"""MDL Retriever"""

from openicl import DatasetReader, PromptTemplate
from openicl.icl_retriever.icl_topk_retriever import TopkRetriever
from openicl.utils.calculate import entropy
from openicl.utils.logging import get_logger
from typing import List, Union, Optional, Tuple
from transformers import AutoTokenizer, AutoModelForCausalLM
import tqdm
import torch
import numpy as np
from accelerate import Accelerator

logger = get_logger(__name__)


class ConERetriever(TopkRetriever):
    """PPL In-context Learning Retriever Class
    Class of ConE retriever.
```