src/llama-graph.cpp (79 additions, 0 deletions)
@@ -1,6 +1,7 @@
#include "llama-graph.h"

#include "llama-impl.h"
#include "llama-model.h"
#include "llama-batch.h"
#include "llama-cparams.h"

@@ -1043,6 +1044,84 @@ ggml_tensor * llm_graph_context::build_norm(
return cur;
}


llm_graph_qkv llm_graph_context::build_qkv(
const llama_layer & layer,
ggml_tensor * cur,
int64_t n_embd_head,
int64_t n_head,
int64_t n_head_kv,
int il) const {
const int64_t n_embd_q = n_embd_head * n_head;
const int64_t n_embd_kv = n_embd_head * n_head_kv;

ggml_tensor * Qcur, * Kcur, * Vcur;

if (layer.wqkv) {
// fused QKV path
ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur, layer.wqkv_s);
cb(qkv, "wqkv", il);
if (layer.bqkv) {
qkv = ggml_add(ctx0, qkv, layer.bqkv);
cb(qkv, "bqkv", il);
}
if (hparams.f_clamp_kqv > 0.0f) {
qkv = ggml_clamp(ctx0, qkv, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
cb(qkv, "wqkv_clamped", il);
}
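        // the fused projection is a [n_embd_q + 2*n_embd_kv, n_tokens] tensor whose
        // first n_embd_q values per token are Q, the next n_embd_kv are K and the
        // last n_embd_kv are V; the views below slice these ranges out without copying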
Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens,
ggml_element_size(qkv) * n_embd_head, qkv->nb[1], 0);
Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens,
ggml_element_size(qkv) * n_embd_head, qkv->nb[1],
ggml_element_size(qkv) * n_embd_q);
Vcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens,
ggml_element_size(qkv) * n_embd_head, qkv->nb[1],
ggml_element_size(qkv) * (n_embd_q + n_embd_kv));
} else {
// separate Q/K/V path
Qcur = build_lora_mm(layer.wq, cur, layer.wq_s);
cb(Qcur, "Qcur", il);
if (layer.bq) {
Qcur = ggml_add(ctx0, Qcur, layer.bq);
cb(Qcur, "Qcur", il);
}
if (hparams.f_clamp_kqv > 0.0f) {
Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
cb(Qcur, "Qcur_clamped", il);
}
Kcur = build_lora_mm(layer.wk, cur, layer.wk_s);
cb(Kcur, "Kcur", il);
if (layer.bk) {
Kcur = ggml_add(ctx0, Kcur, layer.bk);
cb(Kcur, "Kcur", il);
}
if (hparams.f_clamp_kqv > 0.0f) {
Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
cb(Kcur, "Kcur_clamped", il);
}
Vcur = build_lora_mm(layer.wv, cur, layer.wv_s);
cb(Vcur, "Vcur", il);
if (layer.bv) {
Vcur = ggml_add(ctx0, Vcur, layer.bv);
cb(Vcur, "Vcur", il);
}
if (hparams.f_clamp_kqv > 0.0f) {
Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
cb(Vcur, "Vcur_clamped", il);
}
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
}

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

return { Qcur, Kcur, Vcur };
}


ggml_tensor * llm_graph_context::build_ffn(
ggml_tensor * cur,
ggml_tensor * up,
src/llama-graph.h (18 additions, 0 deletions)
@@ -17,6 +17,7 @@ struct ggml_context;
struct ggml_tensor;

struct llama_cparams;
struct llama_layer;

struct llama_memory_context_i;

@@ -705,6 +706,12 @@ using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>;
// used in build_rs to properly order writes and avoid unnecessary copies
using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>;

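// holds the Q/K/V tensors produced by build_qkv: with a fused wqkv these are
// views into a single projection tensor, otherwise standalone reshaped tensors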
struct llm_graph_qkv {
ggml_tensor * q; // [n_embd_head, n_head, n_tokens]
ggml_tensor * k; // [n_embd_head, n_head_kv, n_tokens]
ggml_tensor * v; // [n_embd_head, n_head_kv, n_tokens]
};

struct llm_graph_context {
const llm_arch arch;

@@ -791,6 +798,17 @@ struct llm_graph_context {
llm_norm_type type,
int il) const;


    // compute the Q, K and V projections with optional bias, KQV clamping and reshape
    // supports both the fused wqkv tensor and separate wq/wk/wv projections
llm_graph_qkv build_qkv(
const llama_layer & layer,
ggml_tensor * cur,
int64_t n_embd_head,
int64_t n_head,
int64_t n_head_kv,
int il) const;
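    // a minimal usage sketch (hypothetical call site, not part of this change):
    //
    //   llm_graph_qkv qkv = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il);
    //   ggml_tensor * Qcur = qkv.q; // [n_embd_head, n_head,    n_tokens]
    //   ggml_tensor * Kcur = qkv.k; // [n_embd_head, n_head_kv, n_tokens]
    //   ggml_tensor * Vcur = qkv.v; // [n_embd_head, n_head_kv, n_tokens]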

ggml_tensor * build_ffn(
ggml_tensor * cur,
ggml_tensor * up,