fix rerank value error (Add token_type_ids input for rerank model with type-embedding) #21729
cjsdurj wants to merge 5 commits into ggml-org:master from
Conversation
…oss-encoder/ms-marco-MiniLM-L6-v2 ...) Signed-off-by: jian.chen03 <jian.chen03@transwarp.io>
ngxson left a comment
This PR introduces an unnecessary breaking change to llama_batch just for a single use case, which is not acceptable.
Instead, the simple solution is to offset the token ID, example: if n_vocab = 1000, then:
- token IDs 0 -> 999 are type 0
- token IDs 1000 -> 1999 are type 1; the token content wraps: 1000 -> 0, 1001 -> 1, etc.
That will be much easier to implement, while introducing no breaking changes
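Under the stated assumptions (`n_vocab = 1000` as in the example; the helper names are illustrative, not llama.cpp API), the offset scheme above can be sketched as:

```cpp
#include <cstdint>

// Sketch of the offset scheme suggested above: fold the token type into the
// token ID itself so llama_batch needs no new field. n_vocab = 1000 mirrors
// the example; the helper names are illustrative.
static const int32_t n_vocab = 1000;

static int32_t encode_token(int32_t token, int32_t token_type) {
    return token_type * n_vocab + token; // type 1, token 1 -> 1001
}

static int32_t decode_type(int32_t encoded) {
    return encoded / n_vocab; // 0..999 -> type 0, 1000..1999 -> type 1
}

static int32_t decode_token(int32_t encoded) {
    return encoded % n_vocab; // wraps 1000 -> 0, 1001 -> 1, etc.
}
```

Decoding happens inside the model code, so the public `llama_batch` layout stays untouched.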
…(like cross-encoder/ms-marco-MiniLM-L6-v2 ...)" This reverts commit c1add0f.
I agree with your suggestion and have changed my code, but it still makes some other breaking changes: @ngxson
Signed-off-by: jian.chen03 <jian.chen03@transwarp.io>
Why? Simply modify this. The goal is to have no breaking changes to the public API whenever possible.
```cpp
if (batch.token) {
    for (int32_t i = 0; i < batch.n_tokens; ++i) {
        if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
            LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
            return false;
        }
    }
}
```
There should be an extra case for BERT, for example:

```cpp
const uint32_t max_tokens = vocab.n_tokens() * (arch == LLM_ARCH_BERT ? max_token_types : 1);
```
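A standalone sketch of that widened check, where `n_vocab`, `is_bert`, and `max_token_types` stand in for `vocab.n_tokens()`, the `arch == LLM_ARCH_BERT` test, and the BERT type-vocab size:

```cpp
#include <cstdint>

// Sketch of the suggested validation: with offset-encoded token IDs, BERT
// tokens may range up to n_vocab * max_token_types; other archs keep the
// plain n_vocab bound. Names are illustrative, not llama.cpp internals.
static bool token_in_range(int32_t token, uint32_t n_vocab,
                           bool is_bert, uint32_t max_token_types) {
    const uint32_t max_tokens = n_vocab * (is_bert ? max_token_types : 1);
    return token >= 0 && (uint32_t) token < max_tokens;
}
```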
Overview
This PR introduces `token_type` (Token Type IDs) support to `llama_batch`, primarily to enable full support for BERT-based Cross-Encoders and Reranker models. Previously, `token_type_ids` were hardcoded to zero, and the pooling/classifier layers were discarded during conversion, which limited BERT to embedding generation only.

Additional information

Key changes include:
- `llama_batch` & graph updates: Added `int32_t * token_type` to the `llama_batch` struct. Implemented `llm_graph_input_token_type` to pass these IDs as `I32` tensors into the computation graph.
- Updated `src/models/bert.cpp` to utilize `model.type_embd` with the actual token types instead of ignoring them.
- Updated `convert_hf_to_gguf.py` to map `pooler.dense.weight`/`bias` to `classifier.weight`/`bias` for BERT models. Rerankers rely on this classification head to output sequence scores.
- Updated `examples/embedding/embedding.cpp` to assign `token_type = 0` for query tokens and `token_type = 1` for document tokens when processing rerank prompts. Synced all other tools (`server`, `perplexity`, `batched-bench`, etc.) to match the new `llama_batch` layout.

Test
All tests use the model ms-marco-MiniLM-L6-v2-f16. Command:
tokenizer.chat_template.rerank : [CLS]{query}[SEP]{document}[SEP]
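As a minimal sketch of the two-segment template above (hypothetical helper names, not the PR's actual code), the query segment `[CLS]{query}[SEP]` gets token type 0 and the document segment `{document}[SEP]` gets token type 1:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical sketch: pair each token ID with its segment type, matching
// the rerank template [CLS]{query}[SEP]{document}[SEP]. Names here are
// illustrative, not the PR's actual API.
struct typed_token {
    int32_t id;
    int32_t type;
};

static std::vector<typed_token> assign_rerank_types(
        const std::vector<int32_t> & query_toks,  // includes [CLS] ... [SEP]
        const std::vector<int32_t> & doc_toks) {  // includes trailing [SEP]
    std::vector<typed_token> out;
    out.reserve(query_toks.size() + doc_toks.size());
    for (int32_t id : query_toks) out.push_back({id, 0}); // segment A
    for (int32_t id : doc_toks)   out.push_back({id, 1}); // segment B
    return out;
}
```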
Requirements