
Add EXAONE 4.5 implementations #21733

Open
nuxlear wants to merge 5 commits into ggml-org:master from nuxlear:add-exaone4_5

Conversation


nuxlear (Contributor) commented Apr 10, 2026

Overview

Add support for the EXAONE 4.5 architecture, used by the EXAONE 4.5 model released by LG AI Research.

Additional information

This PR adds the modeling code for EXAONE 4.5, which uses the same LLM architecture as EXAONE 4.
It also adds n_kv_heads to the CLIP model to make the ViT compatible with the GQA structure.

Requirements

  • I have read and agree with the contributing guidelines
  • AI usage disclosure:
    YES. The modeling code was implemented with the help of an AI assistant.

  • EXAONE 4.5 uses <vision> and </vision> for image boundaries; Qwen keeps <|vision_start|> and <|vision_end|>.
  • Route EXAONE 4.5 through the Qwen2.5-VL-style encode path (window attention pattern, optional mmproj input norm), and update the exaone4_5 projector weights and convert_hf_to_gguf for mmproj export.
  • Align EXAONE4 tensor registration with EXAONE_MOE for NextN/MTP slots, and avoid skip-flag propagation on duplicated rope_freqs so model loading succeeds for EXAONE 4.5 GGUF.
@nuxlear nuxlear requested review from a team and CISC as code owners April 10, 2026 16:02

ngxson (Contributor) commented Apr 10, 2026

please rebase before requesting a review

@ngxson ngxson marked this pull request as draft April 10, 2026 16:06
@github-actions github-actions bot added model Model specific examples python python script changes labels Apr 10, 2026
@nuxlear nuxlear marked this pull request as ready for review April 10, 2026 20:00
Comment on lines +122 to +123
Kcur = clip_repeat_kv_heads(ctx0, Kcur, d_head, n_kv_head, n_head, n_patches);
Vcur = clip_repeat_kv_heads(ctx0, Vcur, d_head, n_kv_head, n_head, n_patches);

this should be redundant; ggml supports broadcasting automatically
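The broadcasting point can be illustrated with a small pure-Python sketch (hypothetical shapes and values, not the real clip.cpp tensors): explicitly repeating each KV head for its group of query heads produces exactly the same attention scores as indexing the smaller K directly, so the repeat step materializes copies without changing the result.

```python
# GQA: 4 query heads share 2 KV heads (hypothetical sizes for illustration).
n_head, n_kv_head = 4, 2
group = n_head // n_kv_head  # 2 query heads per KV head

q = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # one query vector per head
k = [[1.0, 0.0], [0.0, 1.0]]                          # one key vector per KV head

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# Explicit repeat (what a clip_repeat_kv_heads-style helper does):
# materialize a K entry for every query head.
k_repeated = [k[h // group] for h in range(n_head)]
scores_repeat = [dot(q[h], k_repeated[h]) for h in range(n_head)]

# Implicit broadcast: index the smaller K directly, no copies made.
scores_broadcast = [dot(q[h], k[h // group]) for h in range(n_head)]

assert scores_repeat == scores_broadcast  # identical scores either way
```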

Comment on lines +177 to +180
window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
ggml_set_name(window_idx, "window_idx");
ggml_set_input(window_idx);


move this to the top; any inputs should be defined at the top of the cgraph

hparams.n_merge = 2;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, false);
hparams.set_limit_image_tokens(8, 4096);

set_limit_image_tokens is only used by older models; if you already know the min/max pixels supported, write them to GGUF metadata instead

Comment on lines +1473 to +1476
get_u32(KEY_N_HEAD_KV, hparams.n_kv_head, false);
if (hparams.n_kv_head <= 0) {
    hparams.n_kv_head = 8;
}

I see no reason for this being an optional param; the GGUF must have it, otherwise it's a faulty file

Suggested change
- get_u32(KEY_N_HEAD_KV, hparams.n_kv_head, false);
- if (hparams.n_kv_head <= 0) {
-     hparams.n_kv_head = 8;
- }
+ get_u32(KEY_N_HEAD_KV, hparams.n_kv_head);


please specify exactly how this differs from qwen2, so that we can merge the two models into one file in the future

Comment on lines +10481 to +10485
remapper = {
    "mtp.fc": "model.layers.{bid}.eh_proj",
    "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
    "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
    "mtp.norm": "model.layers.{bid}.shared_head.norm",

please use proper tensor_mapping

Comment on lines +10530 to +10551
if name.startswith("visual.") and ".qkv." in name:
    assert self.hparams_vision is not None
    hv = self.hparams_vision
    n_heads = hv["num_heads"]
    n_kv = int(hv.get("num_key_value_heads", n_heads))
    hidden = hv["hidden_size"]
    head_dim = hidden // n_heads
    q_dim = n_heads * head_dim
    kv_dim = n_kv * head_dim
    total_out = q_dim + 2 * kv_dim
    out_dim = data_torch.shape[0]
    if out_dim != total_out:
        raise ValueError(f"EXAONE 4.5 vision qkv out dim mismatch: got {out_dim}, expected {total_out} ({name})")
    wq = data_torch[:q_dim]
    wk = data_torch[q_dim : q_dim + kv_dim]
    wv = data_torch[q_dim + kv_dim :]
    nq = name.replace("qkv", "q", 1)
    nk = name.replace("qkv", "k", 1)
    nv = name.replace("qkv", "v", 1)
    yield from ModelBase.modify_tensors(self, wq, nq, bid)
    yield from ModelBase.modify_tensors(self, wk, nk, bid)
    yield from ModelBase.modify_tensors(self, wv, nv, bid)

we should not split qkv; instead, use ggml_view to split the result (not the weight)

see build_vit in clip.cpp for an example
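The equivalence behind this suggestion can be sketched in numpy (hypothetical dimensions, not EXAONE 4.5's real ones): projecting once with the fused QKV weight and then taking views of the result gives the same Q, K, and V as splitting the weight into three matrices first. The ggml analogue would be one mul_mat followed by ggml_view-style offsets into the output, as the build_vit example referenced above does.

```python
import numpy as np

# Hypothetical sizes for illustration only.
n_heads, n_kv, head_dim, n_tokens = 4, 2, 8, 3
hidden = n_heads * head_dim
q_dim, kv_dim = n_heads * head_dim, n_kv * head_dim

rng = np.random.default_rng(0)
w_qkv = rng.standard_normal((q_dim + 2 * kv_dim, hidden))  # fused weight, as stored
x = rng.standard_normal((n_tokens, hidden))                # token activations

# Approach in the PR: split the weight, then do three matmuls.
wq, wk, wv = np.split(w_qkv, [q_dim, q_dim + kv_dim], axis=0)
q_a, k_a, v_a = x @ wq.T, x @ wk.T, x @ wv.T

# Suggested approach: one fused matmul, then views of the result.
qkv = x @ w_qkv.T
q_b = qkv[:, :q_dim]
k_b = qkv[:, q_dim : q_dim + kv_dim]
v_b = qkv[:, q_dim + kv_dim :]

assert np.allclose(q_a, q_b) and np.allclose(k_a, k_b) and np.allclose(v_a, v_b)
```

The views are zero-copy slices of the fused projection, so the conversion script can keep the checkpoint's fused tensor intact and leave the split to the compute graph.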
