From c1add0f3505ca27365d948efeda862a6b800422f Mon Sep 17 00:00:00 2001 From: "jian.chen03" Date: Fri, 10 Apr 2026 18:33:17 +0800 Subject: [PATCH 1/6] add token_type_ids input for rerank model with type-embedding(like cross-encoder/ms-marco-MiniLM-L6-v2 ...) Signed-off-by: jian.chen03 --- common/common.cpp | 12 ++++- common/common.h | 8 ++++ convert_hf_to_gguf.py | 9 +++- examples/embedding/embedding.cpp | 67 +++++++++++++++++++-------- examples/parallel/parallel.cpp | 1 + include/llama.h | 2 + src/llama-batch.cpp | 21 ++++++++- src/llama-batch.h | 3 ++ src/llama-graph.cpp | 20 ++++++++ src/llama-graph.h | 13 ++++++ src/models/bert.cpp | 6 +-- tools/batched-bench/batched-bench.cpp | 1 + tools/mtmd/mtmd-helper.cpp | 5 ++ tools/perplexity/perplexity.cpp | 1 + tools/server/server-context.cpp | 1 + 15 files changed, 142 insertions(+), 28 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 16f78debd02..dbdd135ddc6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1513,10 +1513,21 @@ void common_batch_add( llama_pos pos, const std::vector & seq_ids, bool logits) { + common_batch_add(batch,id,pos,0,seq_ids,logits); +} + +void common_batch_add( + struct llama_batch & batch, + llama_token id, + llama_pos pos, + int32_t token_type, + const std::vector & seq_ids, + bool logits) { GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); batch.token [batch.n_tokens] = id; batch.pos [batch.n_tokens] = pos; + batch.token_type[batch.n_tokens] = token_type; batch.n_seq_id[batch.n_tokens] = seq_ids.size(); for (size_t i = 0; i < seq_ids.size(); ++i) { batch.seq_id[batch.n_tokens][i] = seq_ids[i]; @@ -1525,7 +1536,6 @@ void common_batch_add( batch.n_tokens++; } - // // Vocab utils // diff --git a/common/common.h b/common/common.h index 020b6a721ff..e1d0c34a475 100644 --- a/common/common.h +++ b/common/common.h @@ -874,6 +874,14 @@ void common_batch_add( const std::vector & seq_ids, bool logits); +void common_batch_add( + struct llama_batch & 
batch, + llama_token id, + llama_pos pos, + int32_t token_type, + const std::vector & seq_ids, + bool logits); + // decodes a single batch of tokens for a prompt and manages session tokens // // Note: We save state before the last token so that we can replay it to ensure diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 8d6b0a97a02..3dd78199191 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6416,8 +6416,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith(".beta"): name = name[:-5] + ".bias" - # we are only using BERT for embeddings so we don't need the pooling layer - if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): + if name == "embeddings.position_ids": return # we don't need these if name.startswith("cls.predictions"): @@ -6434,6 +6433,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name == "classifier.bias": name = "classifier.out_proj.bias" + if name == "pooler.dense.weight": + name = "classifier.weight" + + if name == "pooler.dense.bias": + name = "classifier.bias" + yield from super().modify_tensors(data_torch, name, bid) def _xlmroberta_tokenizer_init(self) -> None: diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index f6a20ef9d07..e3111852d08 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -27,10 +27,12 @@ static std::vector split_lines(const std::string & s, const std::st return lines; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, const std::vector & token_types ,llama_seq_id seq_id) { size_t n_tokens = tokens.size(); + bool add_token_type = n_tokens == token_types.size(); for (size_t i = 0; i < n_tokens; i++) { - common_batch_add(batch, tokens[i], i, { seq_id }, true); + int32_t token_type = 
add_token_type? token_types[i]:0; + common_batch_add(batch, tokens[i], i, token_type , { seq_id }, true); } } @@ -180,33 +182,57 @@ int main(int argc, char ** argv) { // tokenize the prompts and trim std::vector> inputs; + std::vector> token_type_ids; for (const auto & prompt : prompts) { std::vector inp; + std::vector token_type; // split classification pairs and insert expected separator tokens if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) { std::vector pairs = split_lines(prompt, params.cls_sep); + const std::string query = pairs[0]; + const std::string doc = pairs[1]; if (rerank_prompt != nullptr) { - const std::string query = pairs[0]; - const std::string doc = pairs[1]; std::string final_prompt = rerank_prompt; - string_replace_all(final_prompt, "{query}" , query); - string_replace_all(final_prompt, "{document}", doc ); - inp = common_tokenize(vocab, final_prompt, true, true); + size_t pos = final_prompt.find("{document}"); + std::string q_prompt = final_prompt.substr(0, pos); + std::string d_prompt = final_prompt.substr(pos); + string_replace_all(q_prompt, "{query}" , query); + string_replace_all(d_prompt, "{document}", doc ); + + auto inp_q= common_tokenize(vocab, q_prompt, false, true); + auto inp_d= common_tokenize(vocab, d_prompt, false, true); + + for(auto token: inp_q){ + inp.emplace_back(token); + token_type.emplace_back(0); + } + for(auto token: inp_d){ + inp.emplace_back(token); + token_type.emplace_back(1); + } } else { - std::string final_prompt; - for (size_t i = 0; i < pairs.size(); i++) { - final_prompt += pairs[i]; - if (i != pairs.size() - 1) { - if (!added_eos_token.empty()) { - final_prompt += added_eos_token; - } - if (!added_sep_token.empty()) { - final_prompt += added_sep_token; - } - } + auto inp_q= common_tokenize(vocab, query, false, false); + auto inp_d= common_tokenize(vocab, doc, false, false); + inp.emplace_back(llama_vocab_bos(vocab)); //add bos + token_type.emplace_back(0); + 
for(auto token: inp_q){ //add seq A + inp.emplace_back(token); + token_type.emplace_back(0); } - inp = common_tokenize(ctx, final_prompt, true, true); + inp.emplace_back(llama_vocab_eos(vocab)); //add eos + token_type.emplace_back(0); + + inp.emplace_back(llama_vocab_sep(vocab)); //add sep + token_type.emplace_back(0); + + for(auto token: inp_d){ //add seq B + inp.emplace_back(token); + token_type.emplace_back(1); + } + + inp.emplace_back(llama_vocab_eos(vocab)); //add eos + token_type.emplace_back(1); } } else { inp = common_tokenize(ctx, prompt, true, true); @@ -216,6 +242,7 @@ int main(int argc, char ** argv) { __func__, (long long int) inp.size(), (long long int) n_batch); return 1; } + token_type_ids.push_back(token_type); inputs.push_back(inp); } @@ -278,7 +305,7 @@ int main(int argc, char ** argv) { } // add to batch - batch_add_seq(batch, inp, s); + batch_add_seq(batch, inp, token_type_ids[k], s); s += 1; } diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index a46400c5b94..211f6890bb4 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -397,6 +397,7 @@ int main(int argc, char ** argv) { batch.token + i, nullptr, batch.pos + i, + batch.token_type + i, batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, diff --git a/include/llama.h b/include/llama.h index ac267b5089a..83e47879e22 100644 --- a/include/llama.h +++ b/include/llama.h @@ -224,6 +224,7 @@ extern "C" { // - embd : token embeddings (i.e. 
float vector of size n_embd) (used when token is NULL) // - pos : the positions of the respective token in the sequence // (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode) + // - token_type: the type of token in rerank models (query tokens type is 0, document tokens type is 1) // - seq_id : the sequence to which the respective token belongs // (if set to NULL, the sequence ID will be assumed to be 0) // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output @@ -238,6 +239,7 @@ extern "C" { llama_token * token; float * embd; llama_pos * pos; + int32_t * token_type; int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 6bf76939cdd..38278e7de97 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -69,6 +69,11 @@ bool llama_batch_allocr::init( // // auto-generate missing fields // + if(!batch.token_type){ + token_type.resize(batch.n_tokens); + std::fill(token_type.begin(),token_type.end(),0); + batch.token_type = token_type.data(); + } if (!batch.n_seq_id) { n_seq_id.resize(batch.n_tokens); @@ -219,6 +224,7 @@ bool llama_batch_allocr::init( /*.token =*/ batch.token, /*.embd =*/ batch.embd, /*.pos =*/ batch.pos, + /*.token_type =*/ batch.token_type, /*.n_seq_id =*/ batch.n_seq_id, /*.seq_id =*/ batch.seq_id, /*.seq_id_unq =*/ this->seq_id_unq.data(), @@ -401,6 +407,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t udata->token .resize(n_tokens); udata->embd .clear(); udata->pos .resize(n_pos_all); + udata->token_type.resize(n_tokens); udata->n_seq_id .resize(n_tokens); udata->seq_id .resize(n_tokens); udata->seq_id_unq.resize(0); @@ -423,6 +430,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t /*.token =*/ udata->token.data(), /*.embd =*/ nullptr, /*.pos =*/ udata->pos.data(), + /*.token_type =*/ 
udata->token_type.data(), /*.n_seq_id =*/ udata->n_seq_id.data(), /*.seq_id =*/ udata->seq_id.data(), /*.seq_id_unq =*/ udata->seq_id_unq.data(), @@ -658,6 +666,7 @@ void llama_batch_allocr::clear() { batch = {}; pos .clear(); + token_type.clear(); n_seq_id .clear(); seq_id .clear(); seq_id_unq.clear(); @@ -691,6 +700,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u udata->token .resize(n_tokens); udata->embd .resize(n_embd_all); udata->pos .resize(n_pos_all); + udata->token_type.resize(n_tokens); udata->n_seq_id .resize(n_tokens); udata->seq_id .resize(n_tokens); udata->seq_id_unq.resize(0); @@ -719,6 +729,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]]; } + udata->token_type[i] = batch.token_type[idxs[i]]; udata->n_seq_id[i] = batch.n_seq_id[idxs[i]]; udata->output[i] = batch.logits[idxs[i]]; @@ -758,6 +769,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u /*.token =*/ batch.token ? udata->token.data() : nullptr, /*.embd =*/ batch.embd ? 
udata->embd.data() : nullptr, /*.pos =*/ udata->pos.data(), + /*token_type =*/ udata->token_type.data(), /*.n_seq_id =*/ udata->n_seq_id.data(), /*.seq_id =*/ udata->seq_id.data(), /*.seq_id_unq =*/ udata->seq_id_unq.data(), @@ -807,6 +819,7 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) { LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token); LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd); LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos); + LLAMA_LOG_DEBUG("%s: token_type = %p\n", __func__, (void *) ubatch.token_type); LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id); LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id); LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str()); @@ -843,9 +856,9 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) { } if (ubatch.token) { - LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", + LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, token_type = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(), - ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]); + ubatch.pos[i], ubatch.token_type[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]); } else { LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]); @@ -868,6 +881,7 @@ struct llama_batch llama_batch_get_one( /*tokens =*/ tokens, /*embd =*/ nullptr, /*pos =*/ nullptr, + /*token_type=*/nullptr, /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, @@ -880,6 +894,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*tokens =*/ nullptr, /*embd =*/ nullptr, /*pos =*/ nullptr, + 
/*token_type=*/nullptr, /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, @@ -892,6 +907,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ } batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); + batch.token_type=(int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); for (int i = 0; i < n_tokens_alloc; ++i) { @@ -908,6 +924,7 @@ void llama_batch_free(struct llama_batch batch) { if (batch.token) free(batch.token); if (batch.embd) free(batch.embd); if (batch.pos) free(batch.pos); + if (batch.token_type) free(batch.token_type); if (batch.n_seq_id) free(batch.n_seq_id); if (batch.seq_id) { for (int i = 0; batch.seq_id[i] != nullptr; ++i) { diff --git a/src/llama-batch.h b/src/llama-batch.h index f77520e86c3..e62b569feb6 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -45,6 +45,7 @@ struct llama_ubatch { llama_token * token; // [n_tokens] | i | id, token float * embd; // [n_embd, n_tokens] | i | embd llama_pos * pos; // [n_tokens*n_pos] | i | pos + int32_t * token_type; // [n_tokens] | i | token_type int32_t * n_seq_id; // [n_tokens] | i | - llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id @@ -55,6 +56,7 @@ struct llama_ubatch { std::vector token; std::vector embd; std::vector pos; + std::vector token_type; std::vector n_seq_id; std::vector seq_id; // these point into the seq_id_data below std::vector seq_id_unq; @@ -139,6 +141,7 @@ class llama_batch_allocr { std::array seq_id_0 = {{ 0 }}; // default sequence id std::vector pos; + std::vector token_type; std::vector n_seq_id; std::vector seq_id; std::vector seq_id_unq; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8e2b6ab8e7e..73f968df0dc 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -126,6 
+126,17 @@ bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_token_type::set_input(const llama_ubatch * ubatch) { + if (ubatch->token_type && type) { + const int64_t n_tokens = ubatch->n_tokens; + ggml_backend_tensor_set(type, ubatch->token_type, 0, n_tokens*ggml_element_size(type)); + } +} + +bool llm_graph_input_token_type::can_reuse(const llm_graph_params & params) { + return type->ne[0] == params.ubatch.n_tokens; +} + void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && attn_scale) { const int64_t n_tokens = ubatch->n_tokens; @@ -1719,6 +1730,15 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { return cur; } +ggml_tensor * llm_graph_context::build_inp_token_type() const { + auto inp = std::make_unique(); + auto & cur = inp->type; + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens); + ggml_set_input(cur); + res->add_input(std::move(inp)); + return cur; +} + ggml_tensor * llm_graph_context::build_inp_attn_scale() const { auto inp = std::make_unique(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset); diff --git a/src/llama-graph.h b/src/llama-graph.h index 29e78451fbb..4753e8375bd 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -133,6 +133,18 @@ class llm_graph_input_pos : public llm_graph_input_i { const uint32_t n_pos_per_embd = 1; }; +class llm_graph_input_token_type: public llm_graph_input_i{ + public: + llm_graph_input_token_type() =default; + virtual ~llm_graph_input_token_type() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * type = nullptr; // I32 [n_batch] +}; + // temperature tuning, used by llama4 class llm_graph_input_attn_temp : public llm_graph_input_i { public: @@ -861,6 +873,7 @@ struct llm_graph_context { ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; ggml_tensor * 
build_inp_pos() const; + ggml_tensor * build_inp_token_type() const; ggml_tensor * build_inp_attn_scale() const; ggml_tensor * build_inp_out_ids() const; ggml_tensor * build_inp_mean() const; diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 6ab8c136858..260447d6803 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -9,6 +9,7 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params ggml_tensor * cur; ggml_tensor * inpL; ggml_tensor * inp_pos = nullptr; + ggml_tensor * inp_token_type; if (model.arch != LLM_ARCH_JINA_BERT_V2) { inp_pos = build_inp_pos(); @@ -17,10 +18,9 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params // construct input embeddings (token, type, position) inpL = build_inp_embd(model.tok_embd); - // token types are hardcoded to zero ("Sentence A") if (model.type_embd) { - ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); - inpL = ggml_add(ctx0, inpL, type_row0); + inp_token_type = build_inp_token_type(); + inpL = ggml_add(ctx0, inpL, ggml_get_rows(ctx0, model.type_embd, inp_token_type)); } if (model.arch == LLM_ARCH_BERT) { inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp index 3964ef25955..4a3b664f0f9 100644 --- a/tools/batched-bench/batched-bench.cpp +++ b/tools/batched-bench/batched-bench.cpp @@ -85,6 +85,7 @@ int main(int argc, char ** argv) { batch.token + i, nullptr, batch.pos + i, + batch.token_type +i, batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 778aacb61d2..bf01ea58268 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -122,6 +122,7 @@ struct decode_embd_batch { std::vector pos; std::vector pos_view; // used by mrope std::vector n_seq_id; + std::vector token_type_ids; std::vector seq_id_0; std::vector seq_ids; 
std::vector logits; @@ -129,6 +130,8 @@ struct decode_embd_batch { decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) { GGML_ASSERT(n_tokens > 0 && n_pos_per_embd > 0 && n_mmproj_embd > 0); pos .resize(n_tokens * n_pos_per_embd); + token_type_ids.resize(n_tokens); + std::fill(token_type_ids.begin(),token_type_ids.end(),0); n_seq_id.resize(n_tokens); seq_ids .resize(n_tokens + 1); logits .resize(n_tokens); @@ -139,6 +142,7 @@ struct decode_embd_batch { /*tokens =*/ nullptr, /*embd =*/ embd, /*pos =*/ pos.data(), + /*token_type= =*/ token_type_ids.data(), /*n_seq_id =*/ n_seq_id.data(), /*seq_id =*/ seq_ids.data(), /*logits =*/ logits.data(), @@ -221,6 +225,7 @@ struct decode_embd_batch { /*tokens =*/ nullptr, /*embd =*/ batch.embd + offset * n_mmproj_embd, /*pos =*/ pos_ptr, + /*token_type =*/ batch.token_type + offset, /*n_seq_id =*/ batch.n_seq_id + offset, /*seq_id =*/ batch.seq_id + offset, /*logits =*/ batch.logits + offset, diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index 6e319ce55d4..c7ac35f7234 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -670,6 +670,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< batch.token + i, nullptr, batch.pos + i, + batch.token_type +i, batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index b31981c5628..c725c674c98 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2733,6 +2733,7 @@ struct server_context_impl { batch.token + i, nullptr, batch.pos + i, + batch.token_type+i, batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, From 28c1c2804798e12febfd2dd97768cb326f4722a5 Mon Sep 17 00:00:00 2001 From: "jian.chen03" Date: Sat, 11 Apr 2026 11:19:27 +0800 Subject: [PATCH 2/6] Revert "add token_type_ids 
input for rerank model with type-embedding(like cross-encoder/ms-marco-MiniLM-L6-v2 ...)" This reverts commit c1add0f3505ca27365d948efeda862a6b800422f. --- common/common.cpp | 12 +---- common/common.h | 8 ---- convert_hf_to_gguf.py | 9 +--- examples/embedding/embedding.cpp | 67 ++++++++------------------- examples/parallel/parallel.cpp | 1 - include/llama.h | 2 - src/llama-batch.cpp | 21 +-------- src/llama-batch.h | 3 -- src/llama-graph.cpp | 20 -------- src/llama-graph.h | 13 ------ src/models/bert.cpp | 6 +-- tools/batched-bench/batched-bench.cpp | 1 - tools/mtmd/mtmd-helper.cpp | 5 -- tools/perplexity/perplexity.cpp | 1 - tools/server/server-context.cpp | 1 - 15 files changed, 28 insertions(+), 142 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index dbdd135ddc6..16f78debd02 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1513,21 +1513,10 @@ void common_batch_add( llama_pos pos, const std::vector & seq_ids, bool logits) { - common_batch_add(batch,id,pos,0,seq_ids,logits); -} - -void common_batch_add( - struct llama_batch & batch, - llama_token id, - llama_pos pos, - int32_t token_type, - const std::vector & seq_ids, - bool logits) { GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); batch.token [batch.n_tokens] = id; batch.pos [batch.n_tokens] = pos; - batch.token_type[batch.n_tokens] = token_type; batch.n_seq_id[batch.n_tokens] = seq_ids.size(); for (size_t i = 0; i < seq_ids.size(); ++i) { batch.seq_id[batch.n_tokens][i] = seq_ids[i]; @@ -1536,6 +1525,7 @@ void common_batch_add( batch.n_tokens++; } + // // Vocab utils // diff --git a/common/common.h b/common/common.h index e1d0c34a475..020b6a721ff 100644 --- a/common/common.h +++ b/common/common.h @@ -874,14 +874,6 @@ void common_batch_add( const std::vector & seq_ids, bool logits); -void common_batch_add( - struct llama_batch & batch, - llama_token id, - llama_pos pos, - int32_t token_type, - const std::vector & seq_ids, - bool logits); - // decodes a single 
batch of tokens for a prompt and manages session tokens // // Note: We save state before the last token so that we can replay it to ensure diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3dd78199191..8d6b0a97a02 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6416,7 +6416,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name.endswith(".beta"): name = name[:-5] + ".bias" - if name == "embeddings.position_ids": + # we are only using BERT for embeddings so we don't need the pooling layer + if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): return # we don't need these if name.startswith("cls.predictions"): @@ -6433,12 +6434,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name == "classifier.bias": name = "classifier.out_proj.bias" - if name == "pooler.dense.weight": - name = "classifier.weight" - - if name == "pooler.dense.bias": - name = "classifier.bias" - yield from super().modify_tensors(data_torch, name, bid) def _xlmroberta_tokenizer_init(self) -> None: diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index e3111852d08..f6a20ef9d07 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -27,12 +27,10 @@ static std::vector split_lines(const std::string & s, const std::st return lines; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, const std::vector & token_types ,llama_seq_id seq_id) { +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { size_t n_tokens = tokens.size(); - bool add_token_type = n_tokens == token_types.size(); for (size_t i = 0; i < n_tokens; i++) { - int32_t token_type = add_token_type? 
token_types[i]:0; - common_batch_add(batch, tokens[i], i, token_type , { seq_id }, true); + common_batch_add(batch, tokens[i], i, { seq_id }, true); } } @@ -182,57 +180,33 @@ int main(int argc, char ** argv) { // tokenize the prompts and trim std::vector> inputs; - std::vector> token_type_ids; for (const auto & prompt : prompts) { std::vector inp; - std::vector token_type; // split classification pairs and insert expected separator tokens if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) { std::vector pairs = split_lines(prompt, params.cls_sep); - const std::string query = pairs[0]; - const std::string doc = pairs[1]; if (rerank_prompt != nullptr) { + const std::string query = pairs[0]; + const std::string doc = pairs[1]; std::string final_prompt = rerank_prompt; - size_t pos = final_prompt.find("{document}"); - std::string q_prompt = final_prompt.substr(0, pos); - std::string d_prompt = final_prompt.substr(pos); - string_replace_all(q_prompt, "{query}" , query); - string_replace_all(d_prompt, "{document}", doc ); - - auto inp_q= common_tokenize(vocab, q_prompt, false, true); - auto inp_d= common_tokenize(vocab, d_prompt, false, true); - - for(auto token: inp_q){ - inp.emplace_back(token); - token_type.emplace_back(0); - } - for(auto token: inp_d){ - inp.emplace_back(token); - token_type.emplace_back(1); - } + string_replace_all(final_prompt, "{query}" , query); + string_replace_all(final_prompt, "{document}", doc ); + inp = common_tokenize(vocab, final_prompt, true, true); } else { - auto inp_q= common_tokenize(vocab, query, false, false); - auto inp_d= common_tokenize(vocab, doc, false, false); - inp.emplace_back(llama_vocab_bos(vocab)); //add bos - token_type.emplace_back(0); - for(auto token: inp_q){ //add seq A - inp.emplace_back(token); - token_type.emplace_back(0); - } - inp.emplace_back(llama_vocab_eos(vocab)); //add eos - token_type.emplace_back(0); - - inp.emplace_back(llama_vocab_sep(vocab)); //add sep - 
token_type.emplace_back(0); - - for(auto token: inp_d){ //add seq B - inp.emplace_back(token); - token_type.emplace_back(1); + std::string final_prompt; + for (size_t i = 0; i < pairs.size(); i++) { + final_prompt += pairs[i]; + if (i != pairs.size() - 1) { + if (!added_eos_token.empty()) { + final_prompt += added_eos_token; + } + if (!added_sep_token.empty()) { + final_prompt += added_sep_token; + } + } } - - inp.emplace_back(llama_vocab_eos(vocab)); //add eos - token_type.emplace_back(1); + inp = common_tokenize(ctx, final_prompt, true, true); } } else { inp = common_tokenize(ctx, prompt, true, true); @@ -242,7 +216,6 @@ int main(int argc, char ** argv) { __func__, (long long int) inp.size(), (long long int) n_batch); return 1; } - token_type_ids.push_back(token_type); inputs.push_back(inp); } @@ -305,7 +278,7 @@ int main(int argc, char ** argv) { } // add to batch - batch_add_seq(batch, inp, token_type_ids[k], s); + batch_add_seq(batch, inp, s); s += 1; } diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 211f6890bb4..a46400c5b94 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -397,7 +397,6 @@ int main(int argc, char ** argv) { batch.token + i, nullptr, batch.pos + i, - batch.token_type + i, batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, diff --git a/include/llama.h b/include/llama.h index 83e47879e22..ac267b5089a 100644 --- a/include/llama.h +++ b/include/llama.h @@ -224,7 +224,6 @@ extern "C" { // - embd : token embeddings (i.e. 
float vector of size n_embd) (used when token is NULL) // - pos : the positions of the respective token in the sequence // (if set to NULL, the token position will be tracked automatically by llama_encode/llama_decode) - // - token_type: the type of token in rerank models (query tokens type is 0, document tokens type is 1) // - seq_id : the sequence to which the respective token belongs // (if set to NULL, the sequence ID will be assumed to be 0) // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output @@ -239,7 +238,6 @@ extern "C" { llama_token * token; float * embd; llama_pos * pos; - int32_t * token_type; int32_t * n_seq_id; llama_seq_id ** seq_id; int8_t * logits; // TODO: rename this to "output" diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 38278e7de97..6bf76939cdd 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -69,11 +69,6 @@ bool llama_batch_allocr::init( // // auto-generate missing fields // - if(!batch.token_type){ - token_type.resize(batch.n_tokens); - std::fill(token_type.begin(),token_type.end(),0); - batch.token_type = token_type.data(); - } if (!batch.n_seq_id) { n_seq_id.resize(batch.n_tokens); @@ -224,7 +219,6 @@ bool llama_batch_allocr::init( /*.token =*/ batch.token, /*.embd =*/ batch.embd, /*.pos =*/ batch.pos, - /*.token_type =*/ batch.token_type, /*.n_seq_id =*/ batch.n_seq_id, /*.seq_id =*/ batch.seq_id, /*.seq_id_unq =*/ this->seq_id_unq.data(), @@ -407,7 +401,6 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t udata->token .resize(n_tokens); udata->embd .clear(); udata->pos .resize(n_pos_all); - udata->token_type.resize(n_tokens); udata->n_seq_id .resize(n_tokens); udata->seq_id .resize(n_tokens); udata->seq_id_unq.resize(0); @@ -430,7 +423,6 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t /*.token =*/ udata->token.data(), /*.embd =*/ nullptr, /*.pos =*/ udata->pos.data(), - /*.token_type =*/ 
udata->token_type.data(), /*.n_seq_id =*/ udata->n_seq_id.data(), /*.seq_id =*/ udata->seq_id.data(), /*.seq_id_unq =*/ udata->seq_id_unq.data(), @@ -666,7 +658,6 @@ void llama_batch_allocr::clear() { batch = {}; pos .clear(); - token_type.clear(); n_seq_id .clear(); seq_id .clear(); seq_id_unq.clear(); @@ -700,7 +691,6 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u udata->token .resize(n_tokens); udata->embd .resize(n_embd_all); udata->pos .resize(n_pos_all); - udata->token_type.resize(n_tokens); udata->n_seq_id .resize(n_tokens); udata->seq_id .resize(n_tokens); udata->seq_id_unq.resize(0); @@ -729,7 +719,6 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]]; } - udata->token_type[i] = batch.token_type[idxs[i]]; udata->n_seq_id[i] = batch.n_seq_id[idxs[i]]; udata->output[i] = batch.logits[idxs[i]]; @@ -769,7 +758,6 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u /*.token =*/ batch.token ? udata->token.data() : nullptr, /*.embd =*/ batch.embd ? 
udata->embd.data() : nullptr, /*.pos =*/ udata->pos.data(), - /*token_type =*/ udata->token_type.data(), /*.n_seq_id =*/ udata->n_seq_id.data(), /*.seq_id =*/ udata->seq_id.data(), /*.seq_id_unq =*/ udata->seq_id_unq.data(), @@ -819,7 +807,6 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) { LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token); LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd); LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos); - LLAMA_LOG_DEBUG("%s: token_type = %p\n", __func__, (void *) ubatch.token_type); LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id); LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id); LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str()); @@ -856,9 +843,9 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) { } if (ubatch.token) { - LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, token_type = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", + LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(), - ubatch.pos[i], ubatch.token_type[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]); + ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]); } else { LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]); @@ -881,7 +868,6 @@ struct llama_batch llama_batch_get_one( /*tokens =*/ tokens, /*embd =*/ nullptr, /*pos =*/ nullptr, - /*token_type=*/nullptr, /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, @@ -894,7 +880,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ /*tokens =*/ nullptr, /*embd =*/ nullptr, /*pos =*/ nullptr, - 
/*token_type=*/nullptr, /*n_seq_id =*/ nullptr, /*seq_id =*/ nullptr, /*logits =*/ nullptr, @@ -907,7 +892,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ } batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc); - batch.token_type=(int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); batch.seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); for (int i = 0; i < n_tokens_alloc; ++i) { @@ -924,7 +908,6 @@ void llama_batch_free(struct llama_batch batch) { if (batch.token) free(batch.token); if (batch.embd) free(batch.embd); if (batch.pos) free(batch.pos); - if (batch.token_type) free(batch.token_type); if (batch.n_seq_id) free(batch.n_seq_id); if (batch.seq_id) { for (int i = 0; batch.seq_id[i] != nullptr; ++i) { diff --git a/src/llama-batch.h b/src/llama-batch.h index e62b569feb6..f77520e86c3 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -45,7 +45,6 @@ struct llama_ubatch { llama_token * token; // [n_tokens] | i | id, token float * embd; // [n_embd, n_tokens] | i | embd llama_pos * pos; // [n_tokens*n_pos] | i | pos - int32_t * token_type; // [n_tokens] | i | token_type int32_t * n_seq_id; // [n_tokens] | i | - llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id @@ -56,7 +55,6 @@ struct llama_ubatch { std::vector token; std::vector embd; std::vector pos; - std::vector token_type; std::vector n_seq_id; std::vector seq_id; // these point into the seq_id_data below std::vector seq_id_unq; @@ -141,7 +139,6 @@ class llama_batch_allocr { std::array seq_id_0 = {{ 0 }}; // default sequence id std::vector pos; - std::vector token_type; std::vector n_seq_id; std::vector seq_id; std::vector seq_id_unq; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 73f968df0dc..8e2b6ab8e7e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ 
-126,17 +126,6 @@ bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) { return res; } -void llm_graph_input_token_type::set_input(const llama_ubatch * ubatch) { - if (ubatch->token_type && type) { - const int64_t n_tokens = ubatch->n_tokens; - ggml_backend_tensor_set(type, ubatch->token_type, 0, n_tokens*ggml_element_size(type)); - } -} - -bool llm_graph_input_token_type::can_reuse(const llm_graph_params & params) { - return type->ne[0] == params.ubatch.n_tokens; -} - void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && attn_scale) { const int64_t n_tokens = ubatch->n_tokens; @@ -1730,15 +1719,6 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { return cur; } -ggml_tensor * llm_graph_context::build_inp_token_type() const { - auto inp = std::make_unique(); - auto & cur = inp->type; - cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, (int64_t)n_tokens); - ggml_set_input(cur); - res->add_input(std::move(inp)); - return cur; -} - ggml_tensor * llm_graph_context::build_inp_attn_scale() const { auto inp = std::make_unique(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset); diff --git a/src/llama-graph.h b/src/llama-graph.h index 4753e8375bd..29e78451fbb 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -133,18 +133,6 @@ class llm_graph_input_pos : public llm_graph_input_i { const uint32_t n_pos_per_embd = 1; }; -class llm_graph_input_token_type: public llm_graph_input_i{ - public: - llm_graph_input_token_type() =default; - virtual ~llm_graph_input_token_type() = default; - - void set_input(const llama_ubatch * ubatch) override; - - bool can_reuse(const llm_graph_params & params) override; - - ggml_tensor * type = nullptr; // I32 [n_batch] -}; - // temperature tuning, used by llama4 class llm_graph_input_attn_temp : public llm_graph_input_i { public: @@ -873,7 +861,6 @@ struct llm_graph_context { ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; ggml_tensor 
* build_inp_pos() const; - ggml_tensor * build_inp_token_type() const; ggml_tensor * build_inp_attn_scale() const; ggml_tensor * build_inp_out_ids() const; ggml_tensor * build_inp_mean() const; diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 260447d6803..6ab8c136858 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -9,7 +9,6 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params ggml_tensor * cur; ggml_tensor * inpL; ggml_tensor * inp_pos = nullptr; - ggml_tensor * inp_token_type; if (model.arch != LLM_ARCH_JINA_BERT_V2) { inp_pos = build_inp_pos(); @@ -18,9 +17,10 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params // construct input embeddings (token, type, position) inpL = build_inp_embd(model.tok_embd); + // token types are hardcoded to zero ("Sentence A") if (model.type_embd) { - inp_token_type = build_inp_token_type(); - inpL = ggml_add(ctx0, inpL, ggml_get_rows(ctx0, model.type_embd, inp_token_type)); + ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); + inpL = ggml_add(ctx0, inpL, type_row0); } if (model.arch == LLM_ARCH_BERT) { inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); diff --git a/tools/batched-bench/batched-bench.cpp b/tools/batched-bench/batched-bench.cpp index 4a3b664f0f9..3964ef25955 100644 --- a/tools/batched-bench/batched-bench.cpp +++ b/tools/batched-bench/batched-bench.cpp @@ -85,7 +85,6 @@ int main(int argc, char ** argv) { batch.token + i, nullptr, batch.pos + i, - batch.token_type +i, batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index bf01ea58268..778aacb61d2 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -122,7 +122,6 @@ struct decode_embd_batch { std::vector pos; std::vector pos_view; // used by mrope std::vector n_seq_id; - std::vector token_type_ids; std::vector seq_id_0; std::vector seq_ids; 
std::vector logits; @@ -130,8 +129,6 @@ struct decode_embd_batch { decode_embd_batch(float * embd, int32_t n_tokens, int n_pos_per_embd, int n_mmproj_embd) : n_pos_per_embd(n_pos_per_embd), n_mmproj_embd(n_mmproj_embd) { GGML_ASSERT(n_tokens > 0 && n_pos_per_embd > 0 && n_mmproj_embd > 0); pos .resize(n_tokens * n_pos_per_embd); - token_type_ids.resize(n_tokens); - std::fill(token_type_ids.begin(),token_type_ids.end(),0); n_seq_id.resize(n_tokens); seq_ids .resize(n_tokens + 1); logits .resize(n_tokens); @@ -142,7 +139,6 @@ struct decode_embd_batch { /*tokens =*/ nullptr, /*embd =*/ embd, /*pos =*/ pos.data(), - /*token_type= =*/ token_type_ids.data(), /*n_seq_id =*/ n_seq_id.data(), /*seq_id =*/ seq_ids.data(), /*logits =*/ logits.data(), @@ -225,7 +221,6 @@ struct decode_embd_batch { /*tokens =*/ nullptr, /*embd =*/ batch.embd + offset * n_mmproj_embd, /*pos =*/ pos_ptr, - /*token_type =*/ batch.token_type + offset, /*n_seq_id =*/ batch.n_seq_id + offset, /*seq_id =*/ batch.seq_id + offset, /*logits =*/ batch.logits + offset, diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index c7ac35f7234..6e319ce55d4 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -670,7 +670,6 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector< batch.token + i, nullptr, batch.pos + i, - batch.token_type +i, batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index c725c674c98..b31981c5628 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -2733,7 +2733,6 @@ struct server_context_impl { batch.token + i, nullptr, batch.pos + i, - batch.token_type+i, batch.n_seq_id + i, batch.seq_id + i, batch.logits + i, From f7765be6f578f8053c6367d868945438885da286 Mon Sep 17 00:00:00 2001 From: "jian.chen03" Date: Sat, 11 Apr 2026 15:38:57 +0800 Subject: [PATCH 3/6] refactor Signed-off-by: 
jian.chen03 --- examples/embedding/embedding.cpp | 47 ++++++++++++++++++++------------ src/llama-batch.cpp | 16 +++++------ src/llama-graph.cpp | 12 +++++++- src/llama-graph.h | 4 +++ src/models/bert.cpp | 8 ++++-- 5 files changed, 58 insertions(+), 29 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index f6a20ef9d07..8d089f3d7f9 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -169,6 +169,7 @@ int main(int argc, char ** argv) { // split the prompt into lines std::vector prompts = split_lines(params.prompt, params.embd_sep); + int32_t token_type_offset = llama_vocab_n_tokens(vocab); // max batch size const uint64_t n_batch = params.n_batch; @@ -186,27 +187,39 @@ int main(int argc, char ** argv) { // split classification pairs and insert expected separator tokens if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) { std::vector pairs = split_lines(prompt, params.cls_sep); + const std::string query = pairs[0]; + const std::string doc = pairs[1]; if (rerank_prompt != nullptr) { - const std::string query = pairs[0]; - const std::string doc = pairs[1]; std::string final_prompt = rerank_prompt; - string_replace_all(final_prompt, "{query}" , query); - string_replace_all(final_prompt, "{document}", doc ); - inp = common_tokenize(vocab, final_prompt, true, true); + size_t pos = final_prompt.find("{document}"); + std::string q_prompt = final_prompt.substr(0, pos); + std::string d_prompt = final_prompt.substr(pos); + string_replace_all(q_prompt, "{query}" , query); + string_replace_all(d_prompt, "{document}", doc ); + + auto inp_q= common_tokenize(vocab, q_prompt, false, true); + auto inp_d= common_tokenize(vocab, d_prompt, false, true); + + for(auto token: inp_q){ + inp.emplace_back(token); + + } + for(auto token: inp_d){ + inp.emplace_back(token + token_type_offset); + } } else { - std::string final_prompt; - for (size_t i = 0; i < pairs.size(); 
i++) { - final_prompt += pairs[i]; - if (i != pairs.size() - 1) { - if (!added_eos_token.empty()) { - final_prompt += added_eos_token; - } - if (!added_sep_token.empty()) { - final_prompt += added_sep_token; - } - } + auto inp_q= common_tokenize(vocab, query, false, false); + auto inp_d= common_tokenize(vocab, doc, false, false); + inp.emplace_back(llama_vocab_bos(vocab)); //add bos + inp.insert(inp.end(), inp_q.begin(), inp_q.end());//add seq A + inp.emplace_back(llama_vocab_eos(vocab)); //add eos + inp.emplace_back(llama_vocab_sep(vocab)); //add sep + + for(auto token: inp_d){ //add seq B + inp.emplace_back(token + token_type_offset); + } - inp = common_tokenize(ctx, final_prompt, true, true); + inp.emplace_back(llama_vocab_eos(vocab) +token_type_offset); //add eos } } else { inp = common_tokenize(ctx, prompt, true, true); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 6bf76939cdd..92b68b8b347 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -46,14 +46,14 @@ bool llama_batch_allocr::init( return false; } - if (batch.token) { - for (int32_t i = 0; i < batch.n_tokens; ++i) { - if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { - LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); - return false; - } - } - } + // if (batch.token) { + // for (int32_t i = 0; i < batch.n_tokens; ++i) { + // if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { + // LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + // return false; + // } + // } + // } if (batch.seq_id) { for (int32_t i = 0; i < batch.n_tokens; ++i) { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8e2b6ab8e7e..988e0a122ca 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1633,7 +1633,6 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); cb(inp->tokens, "inp_tokens", 
-1); ggml_set_input(inp->tokens); - res->t_inp_tokens = inp->tokens; inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens); cb(inp->embd, "inp_embd", -1); @@ -1643,6 +1642,17 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { // ref: https://github.com/ggml-org/llama.cpp/pull/18550 std::array inps; + if (res->remove_token_offset) { + auto tokens_f32 = ggml_cast(ctx0, inp->tokens,GGML_TYPE_F32); + res->t_token_types = ggml_scale(ctx0, tokens_f32,1.0f/res->token_offset); + res->t_token_types = ggml_floor_inplace(ctx0, res->t_token_types); + inp->tokens = ggml_sub(ctx0, tokens_f32, ggml_scale(ctx0, res->t_token_types,res->token_offset)); + inp->tokens = ggml_cast (ctx0, inp->tokens,GGML_TYPE_I32); + res->t_token_types = ggml_cast(ctx0,res->t_token_types,GGML_TYPE_I32); + ggml_build_forward_expand(gf, res->t_token_types); + } + + res->t_inp_tokens = inp->tokens; // token embeddings path (ubatch.token != nullptr) { auto & cur = inps[0]; diff --git a/src/llama-graph.h b/src/llama-graph.h index 29e78451fbb..014446acb08 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -671,6 +671,10 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_token_types = nullptr; + + bool remove_token_offset = false; + int32_t token_offset = 0; std::map t_sampled_logits; std::map t_candidates; diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 6ab8c136858..1b3537edaae 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -14,13 +14,15 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params inp_pos = build_inp_pos(); } + if (model.type_embd) { + res->remove_token_offset = true; + res->token_offset = llama_vocab_n_tokens(&model.vocab); + } // construct input embeddings (token, type, position) inpL = build_inp_embd(model.tok_embd); - // token types are hardcoded to zero ("Sentence A") if 
(model.type_embd) { - ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); - inpL = ggml_add(ctx0, inpL, type_row0); + inpL = ggml_add(ctx0, inpL, ggml_get_rows(ctx0, model.type_embd, res->t_token_types)); } if (model.arch == LLM_ARCH_BERT) { inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); From 3919494beedcb4cd9641ca8016cd3c8054fec1a5 Mon Sep 17 00:00:00 2001 From: "jian.chen03" Date: Sat, 11 Apr 2026 16:30:08 +0800 Subject: [PATCH 4/6] update convert script Signed-off-by: jian.chen03 --- convert_hf_to_gguf.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 8d6b0a97a02..2dc3d4ce8a5 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -6417,7 +6417,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name[:-5] + ".bias" # we are only using BERT for embeddings so we don't need the pooling layer - if name in ("embeddings.position_ids", "pooler.dense.weight", "pooler.dense.bias"): + if name == "embeddings.position_ids": return # we don't need these if name.startswith("cls.predictions"): @@ -6434,6 +6434,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if name == "classifier.bias": name = "classifier.out_proj.bias" + if name == "pooler.dense.weight": + name = "classifier.weight" + + if name == "pooler.dense.bias": + name = "classifier.bias" + yield from super().modify_tensors(data_torch, name, bid) def _xlmroberta_tokenizer_init(self) -> None: From bc4d2c14b147ffc1a4ec588e5e2cb3483155d1dd Mon Sep 17 00:00:00 2001 From: "jian.chen03" Date: Mon, 13 Apr 2026 11:23:18 +0800 Subject: [PATCH 5/6] Revert "refactor" This reverts commit f7765be6f578f8053c6367d868945438885da286.
--- examples/embedding/embedding.cpp | 47 ++++++++++++-------------------- src/llama-batch.cpp | 16 +++++------ src/llama-graph.cpp | 12 +------- src/llama-graph.h | 4 --- src/models/bert.cpp | 8 ++---- 5 files changed, 29 insertions(+), 58 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 8d089f3d7f9..f6a20ef9d07 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -169,7 +169,6 @@ int main(int argc, char ** argv) { // split the prompt into lines std::vector prompts = split_lines(params.prompt, params.embd_sep); - int32_t token_type_offset = llama_vocab_n_tokens(vocab); // max batch size const uint64_t n_batch = params.n_batch; @@ -187,39 +186,27 @@ int main(int argc, char ** argv) { // split classification pairs and insert expected separator tokens if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) { std::vector pairs = split_lines(prompt, params.cls_sep); - const std::string query = pairs[0]; - const std::string doc = pairs[1]; if (rerank_prompt != nullptr) { + const std::string query = pairs[0]; + const std::string doc = pairs[1]; std::string final_prompt = rerank_prompt; - size_t pos = final_prompt.find("{document}"); - std::string q_prompt = final_prompt.substr(0, pos); - std::string d_prompt = final_prompt.substr(pos); - string_replace_all(q_prompt, "{query}" , query); - string_replace_all(d_prompt, "{document}", doc ); - - auto inp_q= common_tokenize(vocab, q_prompt, false, true); - auto inp_d= common_tokenize(vocab, d_prompt, false, true); - - for(auto token: inp_q){ - inp.emplace_back(token); - - } - for(auto token: inp_d){ - inp.emplace_back(token + token_type_offset); - } + string_replace_all(final_prompt, "{query}" , query); + string_replace_all(final_prompt, "{document}", doc ); + inp = common_tokenize(vocab, final_prompt, true, true); } else { - auto inp_q= common_tokenize(vocab, query, false, false); - auto inp_d= 
common_tokenize(vocab, doc, false, false); - inp.emplace_back(llama_vocab_bos(vocab)); //add bos - inp.insert(inp.end(), inp_q.begin(), inp_q.end());//add seq A - inp.emplace_back(llama_vocab_eos(vocab)); //add eos - inp.emplace_back(llama_vocab_sep(vocab)); //add sep - - for(auto token: inp_d){ //add seq B - inp.emplace_back(token + token_type_offset); - + std::string final_prompt; + for (size_t i = 0; i < pairs.size(); i++) { + final_prompt += pairs[i]; + if (i != pairs.size() - 1) { + if (!added_eos_token.empty()) { + final_prompt += added_eos_token; + } + if (!added_sep_token.empty()) { + final_prompt += added_sep_token; + } + } } - inp.emplace_back(llama_vocab_eos(vocab) +token_type_offset); //add eos + inp = common_tokenize(ctx, final_prompt, true, true); } } else { inp = common_tokenize(ctx, prompt, true, true); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 92b68b8b347..6bf76939cdd 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -46,14 +46,14 @@ bool llama_batch_allocr::init( return false; } - // if (batch.token) { - // for (int32_t i = 0; i < batch.n_tokens; ++i) { - // if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { - // LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); - // return false; - // } - // } - // } + if (batch.token) { + for (int32_t i = 0; i < batch.n_tokens; ++i) { + if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { + LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + return false; + } + } + } if (batch.seq_id) { for (int32_t i = 0; i < batch.n_tokens; ++i) { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 988e0a122ca..8e2b6ab8e7e 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1633,6 +1633,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); cb(inp->tokens, "inp_tokens", -1); 
ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens); cb(inp->embd, "inp_embd", -1); @@ -1642,17 +1643,6 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { // ref: https://github.com/ggml-org/llama.cpp/pull/18550 std::array inps; - if (res->remove_token_offset) { - auto tokens_f32 = ggml_cast(ctx0, inp->tokens,GGML_TYPE_F32); - res->t_token_types = ggml_scale(ctx0, tokens_f32,1.0f/res->token_offset); - res->t_token_types = ggml_floor_inplace(ctx0, res->t_token_types); - inp->tokens = ggml_sub(ctx0, tokens_f32, ggml_scale(ctx0, res->t_token_types,res->token_offset)); - inp->tokens = ggml_cast (ctx0, inp->tokens,GGML_TYPE_I32); - res->t_token_types = ggml_cast(ctx0,res->t_token_types,GGML_TYPE_I32); - ggml_build_forward_expand(gf, res->t_token_types); - } - - res->t_inp_tokens = inp->tokens; // token embeddings path (ubatch.token != nullptr) { auto & cur = inps[0]; diff --git a/src/llama-graph.h b/src/llama-graph.h index 014446acb08..29e78451fbb 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -671,10 +671,6 @@ class llm_graph_result { ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; - ggml_tensor * t_token_types = nullptr; - - bool remove_token_offset = false; - int32_t token_offset = 0; std::map t_sampled_logits; std::map t_candidates; diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 1b3537edaae..6ab8c136858 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -14,15 +14,13 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params inp_pos = build_inp_pos(); } - if (model.type_embd) { - res->remove_token_offset = true; - res->token_offset = llama_vocab_n_tokens(&model.vocab); - } // construct input embeddings (token, type, position) inpL = build_inp_embd(model.tok_embd); + // token types are hardcoded to zero ("Sentence A") if 
(model.type_embd) { - inpL = ggml_add(ctx0, inpL, ggml_get_rows(ctx0, model.type_embd, res->t_token_types)); + ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); + inpL = ggml_add(ctx0, inpL, type_row0); } if (model.arch == LLM_ARCH_BERT) { inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); From 99bb3d232265728e880b89ae3a27b7e82371057c Mon Sep 17 00:00:00 2001 From: "jian.chen03" Date: Mon, 13 Apr 2026 14:05:50 +0800 Subject: [PATCH 6/6] update Signed-off-by: jian.chen03 --- examples/embedding/CMakeLists.txt | 1 + examples/embedding/embedding.cpp | 64 +++++++++++++++++++++---------- src/llama-batch.cpp | 25 +++++++++--- src/llama-batch.h | 3 ++ src/llama-graph.cpp | 20 ++++++++++ src/llama-graph.h | 13 +++++++ src/models/bert.cpp | 6 +-- tools/server/server-common.cpp | 33 ++++++++++++---- 8 files changed, 129 insertions(+), 36 deletions(-) diff --git a/examples/embedding/CMakeLists.txt b/examples/embedding/CMakeLists.txt index 809040307d2..523d201a4b1 100644 --- a/examples/embedding/CMakeLists.txt +++ b/examples/embedding/CMakeLists.txt @@ -2,4 +2,5 @@ set(TARGET llama-embedding) add_executable(${TARGET} embedding.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index f6a20ef9d07..2a2a52f771c 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -2,9 +2,9 @@ #include "common.h" #include "log.h" #include "llama.h" - +#include +#include #include -#include #include #if defined(_MSC_VER) @@ -169,13 +169,12 @@ int main(int argc, char ** argv) { // split the prompt into lines std::vector prompts = split_lines(params.prompt, params.embd_sep); + int32_t token_type_offset = llama_vocab_n_tokens(vocab); // max batch size const 
uint64_t n_batch = params.n_batch; // get added sep and eos token, if any - const std::string added_sep_token = llama_vocab_get_add_sep(vocab) ? llama_vocab_get_text(vocab, llama_vocab_sep(vocab)) : ""; - const std::string added_eos_token = llama_vocab_get_add_eos(vocab) ? llama_vocab_get_text(vocab, llama_vocab_eos(vocab)) : ""; const char * rerank_prompt = llama_model_chat_template(model, "rerank"); // tokenize the prompts and trim @@ -186,27 +185,50 @@ int main(int argc, char ** argv) { // split classification pairs and insert expected separator tokens if (pooling_type == LLAMA_POOLING_TYPE_RANK && prompt.find(params.cls_sep) != std::string::npos) { std::vector pairs = split_lines(prompt, params.cls_sep); + const std::string query = pairs[0]; + const std::string doc = pairs[1]; if (rerank_prompt != nullptr) { - const std::string query = pairs[0]; - const std::string doc = pairs[1]; std::string final_prompt = rerank_prompt; - string_replace_all(final_prompt, "{query}" , query); - string_replace_all(final_prompt, "{document}", doc ); - inp = common_tokenize(vocab, final_prompt, true, true); + size_t pos = final_prompt.find("{document}"); + std::string query_prompt = final_prompt.substr(0, pos); + std::string doc_prompt = final_prompt.substr(pos); + string_replace_all(query_prompt, "{query}" , query); + string_replace_all(doc_prompt, "{document}", doc ); + + auto inp_q= common_tokenize(vocab, query_prompt, false, true); + auto inp_d= common_tokenize(vocab, doc_prompt, false, true); + + for(auto token: inp_q){ + inp.emplace_back(token); + } + for(auto token: inp_d){ + inp.emplace_back(model->arch == LLM_ARCH_BERT ? 
token + token_type_offset : token ); + } } else { - std::string final_prompt; - for (size_t i = 0; i < pairs.size(); i++) { - final_prompt += pairs[i]; - if (i != pairs.size() - 1) { - if (!added_eos_token.empty()) { - final_prompt += added_eos_token; - } - if (!added_sep_token.empty()) { - final_prompt += added_sep_token; - } - } + llama_token eos_token = llama_vocab_eos(vocab); + if (eos_token == LLAMA_TOKEN_NULL) { + eos_token = llama_vocab_sep(vocab); + } + + auto inp_q= common_tokenize(vocab, query, false, false); + auto inp_d= common_tokenize(vocab, doc, false, false); + if (llama_vocab_get_add_bos(vocab)) { + inp.emplace_back(llama_vocab_bos(vocab)); //add bos + } + inp.insert(inp.end(), inp_q.begin(), inp_q.end());//add seq A + if (llama_vocab_get_add_eos(vocab)) { + inp.emplace_back(eos_token); //add eos + } + if (llama_vocab_get_add_sep(vocab)) { + inp.emplace_back(llama_vocab_sep(vocab)); //add sep + } + for(auto token: inp_d){ //add seq B + inp.emplace_back(model->arch == LLM_ARCH_BERT ? token + token_type_offset : token); + + } + if (llama_vocab_get_add_eos(vocab)) { + inp.emplace_back(model->arch == LLM_ARCH_BERT ? 
eos_token + token_type_offset : eos_token); //add eos } - inp = common_tokenize(ctx, final_prompt, true, true); } } else { inp = common_tokenize(ctx, prompt, true, true); diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 6bf76939cdd..0695630d1a4 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -46,11 +46,12 @@ bool llama_batch_allocr::init( return false; } + int32_t vocab_size = vocab.n_tokens(); if (batch.token) { for (int32_t i = 0; i < batch.n_tokens; ++i) { - if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { - LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); - return false; + if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab_size) { + LLAMA_LOG_WARN("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]); + break; } } } @@ -69,6 +70,12 @@ bool llama_batch_allocr::init( // // auto-generate missing fields // + token_type_ids.resize(batch.n_tokens); + for (int32_t i = 0; i < batch.n_tokens; ++i) { + int32_t token_type = batch.token[i] / vocab_size; + batch.token[i] = batch.token[i] - token_type * vocab_size; + token_type_ids[i] = token_type; + } if (!batch.n_seq_id) { n_seq_id.resize(batch.n_tokens); @@ -219,6 +226,7 @@ bool llama_batch_allocr::init( /*.token =*/ batch.token, /*.embd =*/ batch.embd, /*.pos =*/ batch.pos, + /*.token_type =*/ token_type_ids.data(), /*.n_seq_id =*/ batch.n_seq_id, /*.seq_id =*/ batch.seq_id, /*.seq_id_unq =*/ this->seq_id_unq.data(), @@ -401,6 +409,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t udata->token .resize(n_tokens); udata->embd .clear(); udata->pos .resize(n_pos_all); + udata->token_type.resize(n_tokens); udata->n_seq_id .resize(n_tokens); udata->seq_id .resize(n_tokens); udata->seq_id_unq.resize(0); @@ -423,6 +432,7 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t /*.token =*/ udata->token.data(), /*.embd =*/ nullptr, /*.pos =*/ udata->pos.data(), + 
/*.token_type =*/ udata->token_type.data(), /*.n_seq_id =*/ udata->n_seq_id.data(), /*.seq_id =*/ udata->seq_id.data(), /*.seq_id_unq =*/ udata->seq_id_unq.data(), @@ -658,6 +668,7 @@ void llama_batch_allocr::clear() { batch = {}; pos .clear(); + token_type_ids.clear(); n_seq_id .clear(); seq_id .clear(); seq_id_unq.clear(); @@ -691,6 +702,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u udata->token .resize(n_tokens); udata->embd .resize(n_embd_all); udata->pos .resize(n_pos_all); + udata->token_type.resize(n_tokens); udata->n_seq_id .resize(n_tokens); udata->seq_id .resize(n_tokens); udata->seq_id_unq.resize(0); @@ -719,6 +731,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]]; } + udata->token_type[i] = token_type_ids[idxs[i]]; udata->n_seq_id[i] = batch.n_seq_id[idxs[i]]; udata->output[i] = batch.logits[idxs[i]]; @@ -758,6 +771,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u /*.token =*/ batch.token ? udata->token.data() : nullptr, /*.embd =*/ batch.embd ? 
udata->embd.data() : nullptr, /*.pos =*/ udata->pos.data(), + /*token_type =*/ udata->token_type.data(), /*.n_seq_id =*/ udata->n_seq_id.data(), /*.seq_id =*/ udata->seq_id.data(), /*.seq_id_unq =*/ udata->seq_id_unq.data(), @@ -807,6 +821,7 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) { LLAMA_LOG_DEBUG("%s: token = %p\n", __func__, (void *) ubatch.token); LLAMA_LOG_DEBUG("%s: embd = %p\n", __func__, (void *) ubatch.embd); LLAMA_LOG_DEBUG("%s: pos = %p\n", __func__, (void *) ubatch.pos); + LLAMA_LOG_DEBUG("%s: token_type = %p\n", __func__, (void *) ubatch.token_type); LLAMA_LOG_DEBUG("%s: n_seq_id = %p\n", __func__, (void *) ubatch.n_seq_id); LLAMA_LOG_DEBUG("%s: seq_id = %p\n", __func__, (void *) ubatch.seq_id); LLAMA_LOG_DEBUG("%s: seq_id_unq = %s\n", __func__, ss_seq_id_unq.str().c_str()); @@ -843,9 +858,9 @@ void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) { } if (ubatch.token) { - LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", + LLAMA_LOG_DEBUG("%s: %4d: id = %6d (%16s), pos = %4d, token_type = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", __func__, i, ubatch.token[i], vocab->token_to_piece(ubatch.token[i]).c_str(), - ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]); + ubatch.pos[i], ubatch.token_type[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]); } else { LLAMA_LOG_DEBUG("%s: %4d: [embd], pos = %4d, n_seq_id = %2d, seq_id = [%s], output = %d\n", __func__, i, ubatch.pos[i], ubatch.n_seq_id[i], ss.str().c_str(), ubatch.output[i]); diff --git a/src/llama-batch.h b/src/llama-batch.h index f77520e86c3..6746c3648b5 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -45,6 +45,7 @@ struct llama_ubatch { llama_token * token; // [n_tokens] | i | id, token float * embd; // [n_embd, n_tokens] | i | embd llama_pos * pos; // [n_tokens*n_pos] | i | pos + int32_t * token_type; // [n_tokens] | i | token_type 
int32_t * n_seq_id; // [n_tokens] | i | - llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id @@ -55,6 +56,7 @@ struct llama_ubatch { std::vector token; std::vector embd; std::vector pos; + std::vector token_type; std::vector n_seq_id; std::vector seq_id; // these point into the seq_id_data below std::vector seq_id_unq; @@ -139,6 +141,7 @@ class llama_batch_allocr { std::array seq_id_0 = {{ 0 }}; // default sequence id std::vector pos; + std::vector token_type_ids; std::vector n_seq_id; std::vector seq_id; std::vector seq_id_unq; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 8e2b6ab8e7e..71fa85232ae 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -126,6 +126,17 @@ bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_token_type::set_input(const llama_ubatch * ubatch) { + if (ubatch->token_type && type) { + const int64_t n_tokens = ubatch->n_tokens; + ggml_backend_tensor_set(type, ubatch->token_type, 0, n_tokens*ggml_element_size(type)); + } +} + +bool llm_graph_input_token_type::can_reuse(const llm_graph_params & params) { + return type->ne[0] == params.ubatch.n_tokens; +} + void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && attn_scale) { const int64_t n_tokens = ubatch->n_tokens; @@ -1719,6 +1730,15 @@ ggml_tensor * llm_graph_context::build_inp_pos() const { return cur; } +ggml_tensor * llm_graph_context::build_inp_token_type() const { + auto inp = std::make_unique(); + auto & cur = inp->type; + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(cur); + res->add_input(std::move(inp)); + return cur; +} + ggml_tensor * llm_graph_context::build_inp_attn_scale() const { auto inp = std::make_unique(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale, hparams.f_attn_temp_offset); diff --git a/src/llama-graph.h b/src/llama-graph.h index 
29e78451fbb..4753e8375bd 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -133,6 +133,18 @@ class llm_graph_input_pos : public llm_graph_input_i { const uint32_t n_pos_per_embd = 1; }; +class llm_graph_input_token_type: public llm_graph_input_i{ + public: + llm_graph_input_token_type() =default; + virtual ~llm_graph_input_token_type() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * type = nullptr; // I32 [n_batch] +}; + // temperature tuning, used by llama4 class llm_graph_input_attn_temp : public llm_graph_input_i { public: @@ -861,6 +873,7 @@ struct llm_graph_context { ggml_tensor * build_inp_embd(ggml_tensor * tok_embd) const; ggml_tensor * build_inp_pos() const; + ggml_tensor * build_inp_token_type() const; ggml_tensor * build_inp_attn_scale() const; ggml_tensor * build_inp_out_ids() const; ggml_tensor * build_inp_mean() const; diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 6ab8c136858..260447d6803 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -9,6 +9,7 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params ggml_tensor * cur; ggml_tensor * inpL; ggml_tensor * inp_pos = nullptr; + ggml_tensor * inp_token_type; if (model.arch != LLM_ARCH_JINA_BERT_V2) { inp_pos = build_inp_pos(); @@ -17,10 +18,9 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params // construct input embeddings (token, type, position) inpL = build_inp_embd(model.tok_embd); - // token types are hardcoded to zero ("Sentence A") if (model.type_embd) { - ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0); - inpL = ggml_add(ctx0, inpL, type_row0); + inp_token_type = build_inp_token_type(); + inpL = ggml_add(ctx0, inpL, ggml_get_rows(ctx0, model.type_embd, inp_token_type)); } if (model.arch == LLM_ARCH_BERT) { inpL = ggml_add(ctx0, ggml_get_rows(ctx0, model.pos_embd, inp_pos), inpL); 
diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index ed5e306fc5b..856a5042a7e 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -7,6 +7,8 @@ #include "chat.h" #include "base64.hpp" +#include +#include #include "server-common.h" #include @@ -503,7 +505,8 @@ bool server_tokens::validate(const struct llama_context * ctx) const { return false; } } else if (t < 0 || t >= n_vocab) { - return false; + // t = t + token_type_id*n_vocab , if arch == bert + return model->arch==LLM_ARCH_BERT; } } return true; @@ -2037,17 +2040,33 @@ server_tokens format_prompt_rerank( server_tokens result = {}; const char * rerank_prompt = llama_model_chat_template(model, "rerank"); - + auto vocab_size = llama_vocab_n_tokens(vocab); if (rerank_prompt != nullptr) { std::string prompt = rerank_prompt; - string_replace_all(prompt, "{query}" , query); - string_replace_all(prompt, "{document}", doc ); - server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true); - result.push_back(tokens); + size_t pos = prompt.find("{document}"); + std::string query_prompt = prompt.substr(0, pos); + std::string doc_prompt = prompt.substr(pos); + string_replace_all(query_prompt, "{query}" , query); + string_replace_all(doc_prompt, "{document}", doc ); + auto query_tokens= tokenize_input_subprompt(vocab, mctx,query_prompt, false, true); + auto doc_tokens= tokenize_input_subprompt(vocab, mctx,doc_prompt, false, true); + if (model->arch == LLM_ARCH_BERT){ + // token_id = token_id + token_type_ids*vocab_size + for (int32_t i = 0; i < doc_tokens.size(); i++) { + doc_tokens.set_token(i,doc_tokens[i] + vocab_size); + } + } + result.push_back(query_tokens); + result.push_back(doc_tokens); } else { // Get EOS token - use SEP token as fallback if EOS is not available server_tokens query_tokens = tokenize_input_subprompt(vocab, mctx, query, false, false); server_tokens doc_tokens = tokenize_input_subprompt(vocab, mctx, doc, false, false); + 
if (model->arch == LLM_ARCH_BERT) { + for (int32_t i = 0; i < doc_tokens.size(); i++) { + doc_tokens.set_token(i, doc_tokens[i] + vocab_size); + } + } llama_token eos_token = llama_vocab_eos(vocab); if (eos_token == LLAMA_TOKEN_NULL) { eos_token = llama_vocab_sep(vocab); } @@ -2065,7 +2084,7 @@ server_tokens format_prompt_rerank( } result.push_back(doc_tokens); if (llama_vocab_get_add_eos(vocab)) { - result.push_back(eos_token); + result.push_back(model->arch == LLM_ARCH_BERT ? eos_token + vocab_size : eos_token); }