
fix rerank value error (Add token_type_ids input for rerank model with type-embedding)#21729

Open
cjsdurj wants to merge 5 commits into ggml-org:master from cjsdurj:fix-rerank-error-output

Conversation

@cjsdurj

@cjsdurj cjsdurj commented Apr 10, 2026

Overview

This PR introduces token_type (Token Type IDs) support to llama_batch, primarily to enable full support for BERT-based Cross-Encoders and Reranker models. Previously, token_type_ids were hardcoded to zero, and the pooling/classifier layers were discarded during conversion, which limited BERT to embedding generation only.

Additional information

Key changes include:

  • llama_batch & Graph Updates: Added int32_t * token_type to the llama_batch struct. Implemented llm_graph_input_token_type to pass these IDs as I32 tensors into the computation graph.
  • BERT Model Update: Modified src/models/bert.cpp to utilize model.type_embd with the actual token types instead of ignoring them.
  • Conversion Script: Updated convert_hf_to_gguf.py to map pooler.dense.weight/bias to classifier.weight/bias for BERT models. Rerankers rely on this classification head to output sequence scores.
  • Examples & Tools: Updated examples/embedding/embedding.cpp to correctly assign token_type = 0 for query tokens and token_type = 1 for document tokens when processing rerank prompts. Synced all other tools (server, perplexity, batched-bench, etc.) to match the new llama_batch layout.

Test

All tests use the model ms-marco-MiniLM-L6-v2-f16. Command:

llama-embedding.exe -m D:\models\ms-marco-MiniLM-L6-v2-f16.gguf --pooling rank -p "How many people live in Berlin?\tBerlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.\nHow many people live in Berlin?\tBerlin is well known for its museums."  --embd-normalize -1 

tokenizer.chat_template.rerank : [CLS]{query}[SEP]{document}[SEP]

  • ground truth
from sentence_transformers import CrossEncoder

model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L6-v2')
scores = model.predict([
    ("How many people live in Berlin?", "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers."),
    ("How many people live in Berlin?", "Berlin is well known for its museums."),
])
print(scores)
# [ 8.607138 -4.320078]
  • before
(screenshot)
  • this PR
(screenshot)


jian.chen03 and others added 2 commits April 10, 2026 18:33
…oss-encoder/ms-marco-MiniLM-L6-v2 ...)

Signed-off-by: jian.chen03 <jian.chen03@transwarp.io>
@cjsdurj cjsdurj requested review from a team, CISC and ggerganov as code owners April 10, 2026 14:21
@github-actions github-actions bot added model Model specific examples python python script changes server labels Apr 10, 2026
Contributor

@ngxson ngxson left a comment


This PR introduces an unnecessary breaking change to llama_batch just for a single use case, which is not acceptable.

Instead, the simple solution is to offset the token ID. For example, if n_vocab = 1000, then:

  • token IDs 0 -> 999 are type 0
  • token IDs 1000 -> 1999 are type 1, with the token content wrapped back: 1000 -> 0, 1001 -> 1, etc.

That would be much easier to implement while introducing no breaking changes.

jian.chen03 added 2 commits April 11, 2026 11:19
…(like cross-encoder/ms-marco-MiniLM-L6-v2 ...)"

This reverts commit c1add0f.
Signed-off-by: jian.chen03 <jian.chen03@transwarp.io>
@cjsdurj cjsdurj requested a review from ngxson April 11, 2026 07:42
@cjsdurj
Author

cjsdurj commented Apr 11, 2026

I agree with your suggestion and have changed my code, but it makes some other breaking changes: @ngxson

  1. removed the token validity check in llama-batch.cpp
  2. added 3 member variables (t_token_types, remove_token_offset, token_offset) to class llm_graph_result

Signed-off-by: jian.chen03 <jian.chen03@transwarp.io>
@ngxson
Contributor

ngxson commented Apr 11, 2026

2. added 3 member variables (t_token_types, remove_token_offset, token_offset) to class llm_graph_result

Why? Simply modify llama_batch_allocr and llama_ubatch to store an extra token_type array; they are internal structs, so we are free to modify them. llm_graph_input_token_type from your initial version is also necessary.

The goal is to have no breaking changes to the public API whenever possible.

Comment on lines -49 to -56
if (batch.token) {
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
            LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
            return false;
        }
    }
}
Contributor

@ngxson ngxson Apr 11, 2026


there should be an extra case for BERT, for example:

const uint32_t max_tokens = vocab.n_tokens() * (arch == LLM_ARCH_BERT ? max_token_types : 1);
