129 changes: 128 additions & 1 deletion convert_hf_to_gguf.py
@@ -2223,6 +2223,9 @@ def set_gguf_parameters(self):
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "heads", "vt_num_attention_heads"]))
n_kv = self.find_vparam(["num_key_value_heads"], optional=True)
if n_kv is not None:
self.gguf_writer.add_vision_head_count_kv(n_kv)

# preprocessor config
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@@ -10435,6 +10438,128 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed experts: {experts}")


@ModelBase.register("Exaone4_5_ForConditionalGeneration")
class Exaone4_5_VLTextModel(Exaone4Model):
"""Text tower of EXAONE 4.5; Tensors match EXAONE4"""

model_arch = gguf.MODEL_ARCH.EXAONE4

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0)
if n_nextn > 0:
self.block_count = self.hparams["num_hidden_layers"] + n_nextn
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

def set_gguf_parameters(self):
super().set_gguf_parameters()
n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0)
if n_nextn > 0:
self.gguf_writer.add_nextn_predict_layers(n_nextn)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.visual."):
return
if name.startswith("model.language_model."):
name = name.replace("model.language_model.", "model.", 1)

if name.startswith("mtp."):
n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0)
if n_nextn <= 0:
return
nh = self.hparams["num_hidden_layers"]
if ".layers." in name:
share = self.hparams.get("mtp_share_layers", False)
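# mtp_share_layers: the checkpoint stores a single shared MTP block, so its tensors
# are replicated across all n_nextn appended layers (bids nh .. nh + n_nextn - 1)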
mtp_bid = bid if bid is not None else 0
if share:
for k in range(n_nextn):
nn = name.replace(f"mtp.layers.{mtp_bid}", f"model.layers.{nh + k}")
yield from super().modify_tensors(data_torch, nn, nh + k)
return
name = name.replace(f"mtp.layers.{mtp_bid}", f"model.layers.{mtp_bid + nh}")
else:
remapper = {
"mtp.fc": "model.layers.{bid}.eh_proj",
"mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
"mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
"mtp.norm": "model.layers.{bid}.shared_head.norm",
Contributor (comment on lines +10481 to +10485):
please use proper tensor_mapping
}
_n = Path(name)
key = _n.stem
if key not in remapper:
return
new_name = remapper[key] + _n.suffix
for bid_mtp in range(nh, self.block_count):
yield from super().modify_tensors(data_torch, new_name.format(bid=bid_mtp), bid_mtp)
return

yield from super().modify_tensors(data_torch, name, bid)


@ModelBase.register("Exaone4_5_ForConditionalGeneration")
class Exaone4_5VLVisionModel(Qwen2VLVisionModel):
"""Vision tower for EXAONE 4.5; Qwen2-VL-style ViT (GQA) + patch merger"""

def set_gguf_parameters(self):
MmprojModel.set_gguf_parameters(self)
assert self.hparams_vision is not None
hparams = self.hparams_vision
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.EXAONE4_5)
self.gguf_writer.add_vision_use_silu(True)
eps = hparams.get("rms_norm_eps", self.global_config.get("rms_norm_eps", 1e-6))
self.gguf_writer.add_vision_attention_layernorm_eps(eps)
if (window_size := hparams.get("window_size")) is not None:
self.gguf_writer.add_vision_window_size(window_size)
fullatt_block_indexes = hparams.get("fullatt_block_indexes")
if fullatt_block_indexes:
n_wa_pattern = fullatt_block_indexes[0] + 1
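# worked example (indexes assumed for illustration): [7, 15, 23, 31] -> n_wa_pattern = 8,
# i.e. every 8th block uses full attention and the 7 blocks between use windowed attention;
# uneven spacing is rejected below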
for i in range(1, len(fullatt_block_indexes)):
if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
raise ValueError(f"Invalid EXAONE4.5 fullatt_block_indexes: {fullatt_block_indexes}")
self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("model.language_model.") or name.startswith("lm_head."):
return
if name.startswith("mtp."):
return
if name.startswith("model.visual."):
name = name.replace("model.visual.", "visual.", 1)

# blueprint_exaone4_5: ViT uses GQA (HF fused qkv = q_dim + kv_dim + kv_dim), not MQA / equal thirds.
if name.startswith("visual.") and ".qkv." in name:
assert self.hparams_vision is not None
hv = self.hparams_vision
n_heads = hv["num_heads"]
n_kv = int(hv.get("num_key_value_heads", n_heads))
hidden = hv["hidden_size"]
head_dim = hidden // n_heads
q_dim = n_heads * head_dim
kv_dim = n_kv * head_dim
total_out = q_dim + 2 * kv_dim
out_dim = data_torch.shape[0]
if out_dim != total_out:
raise ValueError(f"EXAONE 4.5 vision qkv out dim mismatch: got {out_dim}, expected {total_out} ({name})")
wq = data_torch[:q_dim]
wk = data_torch[q_dim : q_dim + kv_dim]
wv = data_torch[q_dim + kv_dim :]
nq = name.replace("qkv", "q", 1)
nk = name.replace("qkv", "k", 1)
nv = name.replace("qkv", "v", 1)
yield from ModelBase.modify_tensors(self, wq, nq, bid)
yield from ModelBase.modify_tensors(self, wk, nk, bid)
yield from ModelBase.modify_tensors(self, wv, nv, bid)
Contributor (comment on lines +10530 to +10551):
we should not split qkv, instead, use ggml_view to split the result (not the weight)
see build_vit in clip.cpp for an example
return

# EXAONE4.5 PatchMerger includes ln_q (RMSNorm), but generic Qwen2-VL mapping can miss it.
# Keep explicit mapping to mm.input_norm for activation-level parity with HF.
if name == "visual.merger.ln_q.weight":
yield ("mm.input_norm.weight", data_torch)
return

yield from Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid)


@ModelBase.register("GraniteForCausalLM")
class GraniteModel(LlamaModel):
"""Conversion for IBM's GraniteForCausalLM"""
@@ -13118,7 +13243,9 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st

# if "architectures" is found in the sub-config, use that instead
if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
arch = text_config["architectures"][0]
# Multimodal EXAONE 4.5 stores the inner causal LM class in text_config; HF→GGUF must use the VL root arch.
if hparams.get("model_type") != "exaone4_5":
arch = text_config["architectures"][0]
elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
arch = vision_config["architectures"][0]
if arch is None:
35 changes: 35 additions & 0 deletions gguf-py/gguf/constants.py
@@ -321,6 +321,7 @@ class ClipVision:

class Attention:
HEAD_COUNT = "clip.vision.attention.head_count"
HEAD_COUNT_KV = "clip.vision.attention.head_count_kv"
LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon"

class Projector:
@@ -456,6 +457,7 @@ class MODEL_ARCH(IntEnum):
EXAONE = auto()
EXAONE4 = auto()
EXAONE_MOE = auto()
EXAONE4_5 = auto()
GRANITE = auto()
GRANITE_MOE = auto()
GRANITE_HYBRID = auto()
@@ -939,6 +941,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.EXAONE4: "exaone4",
MODEL_ARCH.EXAONE_MOE: "exaone-moe",
MODEL_ARCH.EXAONE4_5: "exaone4_5",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
@@ -3137,6 +3140,13 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_POST_NORM,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.EXAONE_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
@@ -3170,6 +3180,30 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.EXAONE4_5: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_POST_NORM,
# NextN/MTP tensors - preserved but unused
MODEL_TENSOR.NEXTN_EH_PROJ,
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
MODEL_TENSOR.NEXTN_ENORM,
MODEL_TENSOR.NEXTN_HNORM,
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
],
MODEL_ARCH.GRANITE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@@ -4107,6 +4141,7 @@ class VisionProjectorType:
LLAMA4 = "llama4"
QWEN2VL = "qwen2vl_merger"
QWEN25VL = "qwen2.5vl_merger"
EXAONE4_5 = "exaone4_5"
QWEN3VL = "qwen3vl_merger"
STEP3VL = "step3vl"
ULTRAVOX = "ultravox"
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -1145,6 +1145,9 @@ def add_vision_block_count(self, value: int) -> None:
def add_vision_head_count(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)

def add_vision_head_count_kv(self, value: int) -> None:
self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT_KV, value)

def add_vision_attention_layernorm_eps(self, value: float) -> None:
self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)

47 changes: 34 additions & 13 deletions src/llama-model.cpp
@@ -2248,7 +2248,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
} break;
case LLM_ARCH_EXAONE4:
{
if (hparams.n_layer == 64) { // 32B
ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);

const uint32_t n_layer_text = hparams.n_layer > hparams.nextn_predict_layers
? hparams.n_layer - hparams.nextn_predict_layers
: hparams.n_layer;

if (n_layer_text == 64) { // 32B transformer stack (MTP blocks excluded)
hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
hparams.n_swa = 4096;
uint32_t swa_period = 4;
@@ -2263,7 +2269,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

switch (hparams.n_layer) {
switch (n_layer_text) {
case 30: type = LLM_TYPE_1_2B; break;
case 64: type = LLM_TYPE_32B; break;
default: type = LLM_TYPE_UNKNOWN;
@@ -6188,23 +6194,38 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
}

for (int i = 0; i < n_layer; ++i) {
int flags = 0;
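// NextN/MTP tail layers: their tensors are kept for compatibility but skipped by
// the graph (see src/models/exaone4.cpp), matching the exaone-moe handling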
if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
flags |= TENSOR_SKIP;
}

auto & layer = layers[i];

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, flags);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, flags);

layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));

layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags);
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);

layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags);
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags);

if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags);
layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags);
layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags);

layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED);
layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED);
layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED);
}
}
} break;
case LLM_ARCH_EXAONE_MOE:
8 changes: 6 additions & 2 deletions src/models/exaone4.cpp
@@ -27,7 +27,11 @@ llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_
}
ggml_tensor * inp_out_ids = build_inp_out_ids();

for (int il = 0; il < n_layer; ++il) {
// MTP / NextN tail blocks are loaded for compatibility but not executed (same as exaone-moe).
const int n_layer_main = int(n_layer) - int(hparams.nextn_predict_layers);
GGML_ASSERT(n_layer_main > 0);

for (int il = 0; il < n_layer_main; ++il) {
ggml_tensor * inpSA = inpL;

// use RoPE for SWA layers or non-SWA models
@@ -73,7 +77,7 @@ llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
cb(cur, "attn_out", il);
}
if (il == n_layer - 1 && inp_out_ids) {
if (il == n_layer_main - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
1 change: 1 addition & 0 deletions tools/mtmd/CMakeLists.txt
@@ -18,6 +18,7 @@ add_library(mtmd
models/cogvlm.cpp
models/conformer.cpp
models/dotsocr.cpp
models/exaone4_5.cpp
models/gemma4v.cpp
models/glm4v.cpp
models/hunyuanocr.cpp
3 changes: 3 additions & 0 deletions tools/mtmd/clip-impl.h
@@ -47,6 +47,7 @@
#define KEY_FEATURE_LAYER "clip.vision.feature_layer"
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
#define KEY_N_HEAD_KV "clip.vision.attention.head_count_kv"
#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"

#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
@@ -274,6 +275,7 @@ enum projector_type {
PROJECTOR_TYPE_KIMIK25,
PROJECTOR_TYPE_NEMOTRON_V2_VL,
PROJECTOR_TYPE_HUNYUANOCR,
PROJECTOR_TYPE_EXAONE4_5,
PROJECTOR_TYPE_UNKNOWN,
};

@@ -317,6 +319,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_KIMIK25, "kimik25"},
{ PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
{ PROJECTOR_TYPE_HUNYUANOCR, "hunyuanocr"},
{ PROJECTOR_TYPE_EXAONE4_5, "exaone4_5"},
};

static projector_type clip_projector_type_from_string(const std::string & str) {
1 change: 1 addition & 0 deletions tools/mtmd/clip-model.h
@@ -42,6 +42,7 @@ struct clip_hparams {
int32_t n_ff = 0;
int32_t projection_dim = 0;
int32_t n_head = 0;
int32_t n_kv_head = 0;
int32_t n_layer = 0;
// idefics3
int32_t n_merge = 0; // number of patch merges **per-side**
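For context, a hedged sketch of how a loader might consume the new head_count_kv key and the n_kv_head field (helper name and defaulting behavior are assumptions, not this PR's code):

```cpp
// sketch: read the optional KV-head count; fall back to MHA when the key is absent
get_u32(KEY_N_HEAD_KV, hparams.n_kv_head, /*required=*/false);
if (hparams.n_kv_head <= 0) {
    hparams.n_kv_head = hparams.n_head; // no GQA
}
```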