diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 8d6b0a97a02..0eaf5b0dfec 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -2223,6 +2223,9 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
         self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
         self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "heads", "vt_num_attention_heads"]))
+        n_kv = self.find_vparam(["num_key_value_heads"], optional=True)
+        if n_kv is not None:
+            self.gguf_writer.add_vision_head_count_kv(n_kv)

         # preprocessor config
         image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@@ -10435,6 +10438,128 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")


+@ModelBase.register("Exaone4_5_ForConditionalGeneration")
+class Exaone4_5_VLTextModel(Exaone4Model):
+    """Text tower of EXAONE 4.5; tensors match EXAONE4"""
+
+    model_arch = gguf.MODEL_ARCH.EXAONE4
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0)
+        if n_nextn > 0:
+            self.block_count = self.hparams["num_hidden_layers"] + n_nextn
+            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0)
+        if n_nextn > 0:
+            self.gguf_writer.add_nextn_predict_layers(n_nextn)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.visual."):
+            return
+        if name.startswith("model.language_model."):
+            name = name.replace("model.language_model.", "model.", 1)
+
+        if name.startswith("mtp."):
+            n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0)
+            if n_nextn <= 0:
+                return
+            nh = self.hparams["num_hidden_layers"]
+            if ".layers." in name:
+                share = self.hparams.get("mtp_share_layers", False)
+                mtp_bid = bid if bid is not None else 0
+                if share:
+                    for k in range(n_nextn):
+                        nn = name.replace(f"mtp.layers.{mtp_bid}", f"model.layers.{nh + k}")
+                        yield from super().modify_tensors(data_torch, nn, nh + k)
+                    return
+                name = name.replace(f"mtp.layers.{mtp_bid}", f"model.layers.{mtp_bid + nh}")
+            else:
+                remapper = {
+                    "mtp.fc": "model.layers.{bid}.eh_proj",
+                    "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
+                    "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
+                    "mtp.norm": "model.layers.{bid}.shared_head.norm",
+                }
+                _n = Path(name)
+                key = _n.stem
+                if key not in remapper:
+                    return
+                new_name = remapper[key] + _n.suffix
+                for bid_mtp in range(nh, self.block_count):
+                    yield from super().modify_tensors(data_torch, new_name.format(bid=bid_mtp), bid_mtp)
+                return
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("Exaone4_5_ForConditionalGeneration")
+class Exaone4_5VLVisionModel(Qwen2VLVisionModel):
+    """Vision tower for EXAONE 4.5; Qwen2-VL-style ViT (GQA) + patch merger"""
+
+    def set_gguf_parameters(self):
+        MmprojModel.set_gguf_parameters(self)
+        assert self.hparams_vision is not None
+        hparams = self.hparams_vision
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.EXAONE4_5)
+        self.gguf_writer.add_vision_use_silu(True)
+        eps = hparams.get("rms_norm_eps", self.global_config.get("rms_norm_eps", 1e-6))
+        self.gguf_writer.add_vision_attention_layernorm_eps(eps)
+        if (window_size := hparams.get("window_size")) is not None:
+            self.gguf_writer.add_vision_window_size(window_size)
+        fullatt_block_indexes = hparams.get("fullatt_block_indexes")
+        if fullatt_block_indexes:
+            n_wa_pattern = fullatt_block_indexes[0] + 1
+            for i in range(1, len(fullatt_block_indexes)):
+                if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern:
+                    raise ValueError(f"Invalid EXAONE4.5 fullatt_block_indexes: {fullatt_block_indexes}")
+            self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.startswith("model.language_model.") or name.startswith("lm_head."):
+            return
+        if name.startswith("mtp."):
+            return
+        if name.startswith("model.visual."):
+            name = name.replace("model.visual.", "visual.", 1)
+
+        # blueprint_exaone4_5: ViT uses GQA (HF fused qkv = q_dim + kv_dim + kv_dim), not MQA / equal thirds.
+        if name.startswith("visual.") and ".qkv." in name:
+            assert self.hparams_vision is not None
+            hv = self.hparams_vision
+            n_heads = hv["num_heads"]
+            n_kv = int(hv.get("num_key_value_heads", n_heads))
+            hidden = hv["hidden_size"]
+            head_dim = hidden // n_heads
+            q_dim = n_heads * head_dim
+            kv_dim = n_kv * head_dim
+            total_out = q_dim + 2 * kv_dim
+            out_dim = data_torch.shape[0]
+            if out_dim != total_out:
+                raise ValueError(f"EXAONE 4.5 vision qkv out dim mismatch: got {out_dim}, expected {total_out} ({name})")
+            wq = data_torch[:q_dim]
+            wk = data_torch[q_dim : q_dim + kv_dim]
+            wv = data_torch[q_dim + kv_dim :]
+            nq = name.replace("qkv", "q", 1)
+            nk = name.replace("qkv", "k", 1)
+            nv = name.replace("qkv", "v", 1)
+            yield from ModelBase.modify_tensors(self, wq, nq, bid)
+            yield from ModelBase.modify_tensors(self, wk, nk, bid)
+            yield from ModelBase.modify_tensors(self, wv, nv, bid)
+            return
+
+        # EXAONE4.5 PatchMerger includes ln_q (RMSNorm), but generic Qwen2-VL mapping can miss it.
+        # Keep explicit mapping to mm.input_norm for activation-level parity with HF.
+        if name == "visual.merger.ln_q.weight":
+            yield ("mm.input_norm.weight", data_torch)
+            return
+
+        yield from Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid)
+
+
 @ModelBase.register("GraniteForCausalLM")
 class GraniteModel(LlamaModel):
     """Conversion for IBM's GraniteForCausalLM"""
@@ -13118,7 +13243,9 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st

     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
-        arch = text_config["architectures"][0]
+        # Multimodal EXAONE 4.5 stores the inner causal LM class in text_config; HF→GGUF must use the VL root arch.
+        if hparams.get("model_type") != "exaone4_5":
+            arch = text_config["architectures"][0]
     elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
     if arch is None:
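
Reviewer note: a minimal standalone sketch of the fused-qkv split that `Exaone4_5VLVisionModel.modify_tensors` performs above, assuming made-up dimensions (`n_heads=16`, `n_kv=4`, `hidden=1024`); only the `[q | k | v]` row layout with GQA-sized k/v blocks is taken from the converter.

```python
# Illustrative only: dimensions are hypothetical, not from a real EXAONE 4.5 checkpoint.
import torch

n_heads, n_kv, hidden = 16, 4, 1024
head_dim = hidden // n_heads          # 64
q_dim    = n_heads * head_dim         # 1024
kv_dim   = n_kv * head_dim            # 256

# HF fuses the projections row-wise as [q | k | v]; k and v shrink under GQA.
qkv = torch.randn(q_dim + 2 * kv_dim, hidden)

wq = qkv[:q_dim]                      # (1024, 1024)
wk = qkv[q_dim : q_dim + kv_dim]      # (256, 1024)
wv = qkv[q_dim + kv_dim :]            # (256, 1024)
assert wq.shape[0] + wk.shape[0] + wv.shape[0] == qkv.shape[0]
```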
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 53ce138fce8..273a1e5914a 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -321,6 +321,7 @@ class ClipVision:

         class Attention:
             HEAD_COUNT    = "clip.vision.attention.head_count"
+            HEAD_COUNT_KV = "clip.vision.attention.head_count_kv"
             LAYERNORM_EPS = "clip.vision.attention.layer_norm_epsilon"

         class Projector:
@@ -456,6 +457,7 @@ class MODEL_ARCH(IntEnum):
     EXAONE         = auto()
     EXAONE4        = auto()
     EXAONE_MOE     = auto()
+    EXAONE4_5      = auto()
     GRANITE        = auto()
     GRANITE_MOE    = auto()
     GRANITE_HYBRID = auto()
@@ -939,6 +941,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.EXAONE:         "exaone",
     MODEL_ARCH.EXAONE4:        "exaone4",
     MODEL_ARCH.EXAONE_MOE:     "exaone-moe",
+    MODEL_ARCH.EXAONE4_5:      "exaone4_5",
     MODEL_ARCH.GRANITE:        "granite",
     MODEL_ARCH.GRANITE_MOE:    "granitemoe",
     MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
@@ -3137,6 +3140,13 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.FFN_POST_NORM,
+        # NextN/MTP tensors - preserved but unused
+        MODEL_TENSOR.NEXTN_EH_PROJ,
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+        MODEL_TENSOR.NEXTN_ENORM,
+        MODEL_TENSOR.NEXTN_HNORM,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
     ],
     MODEL_ARCH.EXAONE_MOE: [
         MODEL_TENSOR.TOKEN_EMBD,
@@ -3170,6 +3180,30 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
         MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
     ],
+    MODEL_ARCH.EXAONE4_5: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_POST_NORM,
+        # NextN/MTP tensors - preserved but unused
+        MODEL_TENSOR.NEXTN_EH_PROJ,
+        MODEL_TENSOR.NEXTN_EMBED_TOKENS,
+        MODEL_TENSOR.NEXTN_ENORM,
+        MODEL_TENSOR.NEXTN_HNORM,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
+        MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
+    ],
     MODEL_ARCH.GRANITE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -4107,6 +4141,7 @@ class VisionProjectorType:
    LLAMA4    = "llama4"
    QWEN2VL   = "qwen2vl_merger"
    QWEN25VL  = "qwen2.5vl_merger"
+   EXAONE4_5 = "exaone4_5"
    QWEN3VL   = "qwen3vl_merger"
    STEP3VL   = "step3vl"
    ULTRAVOX  = "ultravox"
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 90d500dc771..2ff4af15c1f 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -1145,6 +1145,9 @@ def add_vision_block_count(self, value: int) -> None:
     def add_vision_head_count(self, value: int) -> None:
         self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT, value)

+    def add_vision_head_count_kv(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.Attention.HEAD_COUNT_KV, value)
+
     def add_vision_attention_layernorm_eps(self, value: float) -> None:
         self.add_float32(Keys.ClipVision.Attention.LAYERNORM_EPS, value)

diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 82af6b6bee3..3ef996ac50d 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2248,7 +2248,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_EXAONE4:
             {
-                if (hparams.n_layer == 64) { // 32B
+                ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false);
+
+                const uint32_t n_layer_text = hparams.n_layer > hparams.nextn_predict_layers
+                    ? hparams.n_layer - hparams.nextn_predict_layers
+                    : hparams.n_layer;
+
+                if (n_layer_text == 64) { // 32B transformer stack (MTP blocks excluded)
                     hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
                     hparams.n_swa = 4096;
                     uint32_t swa_period = 4;
@@ -2263,7 +2269,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

-                switch (hparams.n_layer) {
+                switch (n_layer_text) {
                     case 30: type = LLM_TYPE_1_2B; break;
                     case 64: type = LLM_TYPE_32B; break;
                     default: type = LLM_TYPE_UNKNOWN;
@@ -6188,23 +6194,38 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }

                     for (int i = 0; i < n_layer; ++i) {
+                        int flags = 0;
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            flags |= TENSOR_SKIP;
+                        }
+
                         auto & layer = layers[i];

-                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
-                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, flags);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, flags);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, flags);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, flags);

                         layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));

-                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
-                        layer.attn_q_norm    = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
-                        layer.attn_k_norm    = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags);
+                        layer.attn_q_norm    = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags);
+                        layer.attn_k_norm    = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags);

-                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, flags);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, flags);
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags);
+
+                        if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) {
+                            layer.nextn.eh_proj          = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags);
+                            layer.nextn.enorm            = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags);
+                            layer.nextn.hnorm            = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags);
+
+                            layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED);
+                            layer.nextn.embed_tokens     = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED);
+                            layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED);
+                        }
                     }
                 } break;
            case LLM_ARCH_EXAONE_MOE:
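
Reviewer note: the loader change above flags the trailing `nextn_predict_layers` blocks with `TENSOR_SKIP`, so the MTP weights stay in the GGUF but never execute. A sketch of the cut-off arithmetic with hypothetical counts (30 main layers plus one MTP layer appended by the converter):

```python
# Hypothetical layer counts, mirroring: i >= n_layer - nextn_predict_layers -> TENSOR_SKIP
def layer_roles(n_layer: int, nextn_predict_layers: int) -> list[str]:
    return [
        "mtp-skip" if nextn_predict_layers > 0 and i >= n_layer - nextn_predict_layers else "run"
        for i in range(n_layer)
    ]

roles = layer_roles(31, 1)            # 30 transformer layers + 1 MTP layer
assert roles[:30] == ["run"] * 30
assert roles[30] == "mtp-skip"
```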
diff --git a/src/models/exaone4.cpp b/src/models/exaone4.cpp
index 755af3b747b..d78956e9115 100644
--- a/src/models/exaone4.cpp
+++ b/src/models/exaone4.cpp
@@ -27,7 +27,11 @@ llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_
     }

     ggml_tensor * inp_out_ids = build_inp_out_ids();
-    for (int il = 0; il < n_layer; ++il) {
+    // MTP / NextN tail blocks are loaded for compatibility but not executed (same as exaone-moe).
+    const int n_layer_main = int(n_layer) - int(hparams.nextn_predict_layers);
+    GGML_ASSERT(n_layer_main > 0);
+
+    for (int il = 0; il < n_layer_main; ++il) {
         ggml_tensor * inpSA = inpL;

         // use RoPE for SWA layers or non-SWA models
@@ -73,7 +77,7 @@ llm_build_exaone4::llm_build_exaone4(const llama_model & model, const llm_
                     Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
             cb(cur, "attn_out", il);
         }
-        if (il == n_layer - 1 && inp_out_ids) {
+        if (il == n_layer_main - 1 && inp_out_ids) {
             cur = ggml_get_rows(ctx0, cur, inp_out_ids);
             inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
         }
diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt
index 6a4267d2e1d..2006a969142 100644
--- a/tools/mtmd/CMakeLists.txt
+++ b/tools/mtmd/CMakeLists.txt
@@ -18,6 +18,7 @@ add_library(mtmd
     models/cogvlm.cpp
     models/conformer.cpp
     models/dotsocr.cpp
+    models/exaone4_5.cpp
     models/gemma4v.cpp
     models/glm4v.cpp
     models/hunyuanocr.cpp
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index c812e6c4b5d..2128d9f8abe 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -47,6 +47,7 @@
 #define KEY_FEATURE_LAYER       "clip.vision.feature_layer"
 #define KEY_PROJ_SCALE_FACTOR   "clip.vision.projector.scale_factor"
 #define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
+#define KEY_N_HEAD_KV           "clip.vision.attention.head_count_kv"
 #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"

 #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
@@ -274,6 +275,7 @@ enum projector_type {
     PROJECTOR_TYPE_KIMIK25,
     PROJECTOR_TYPE_NEMOTRON_V2_VL,
     PROJECTOR_TYPE_HUNYUANOCR,
+    PROJECTOR_TYPE_EXAONE4_5,
     PROJECTOR_TYPE_UNKNOWN,
 };

@@ -317,6 +319,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_KIMIK25,        "kimik25"},
     { PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"},
     { PROJECTOR_TYPE_HUNYUANOCR,     "hunyuanocr"},
+    { PROJECTOR_TYPE_EXAONE4_5,      "exaone4_5"},
 };

 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h
index b2cd27dcbf7..59d723170ea 100644
--- a/tools/mtmd/clip-model.h
+++ b/tools/mtmd/clip-model.h
@@ -42,6 +42,7 @@ struct clip_hparams {
     int32_t n_ff = 0;
     int32_t projection_dim = 0;
     int32_t n_head = 0;
+    int32_t n_kv_head = 0;
     int32_t n_layer = 0;
     // idefics3
     int32_t n_merge = 0; // number of patch merges **per-side**
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index b947a4183ed..27cde002e6d 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -866,6 +866,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 builder = std::make_unique(ctx, img);
             } break;
+        case PROJECTOR_TYPE_EXAONE4_5:
+            {
+                builder = std::make_unique<clip_graph_exaone4_5>(ctx, img);
+            } break;
         case PROJECTOR_TYPE_STEP3VL:
             {
                 builder = std::make_unique(ctx, img);
             } break;
@@ -1342,7 +1346,7 @@ struct clip_model_loader {
                     hparams.n_merge = 2; // default value for Qwen 2 and 2.5
                     hparams.image_resize_algo = RESIZE_ALGO_BILINEAR;
                     get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
-                    get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, model.proj_type == PROJECTOR_TYPE_QWEN25VL); // only 2.5 requires it
+                    get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, model.proj_type == PROJECTOR_TYPE_QWEN25VL);
                     // ref: https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json
                     hparams.set_limit_image_tokens(8, 4096);
                     hparams.set_warmup_n_tokens(46*46); // avoid OOM on warmup
@@ -1456,6 +1460,21 @@ struct clip_model_loader {
                     hparams.audio_window_len = 400;
                     hparams.audio_hop_len = 160;
                 } break;
+            case PROJECTOR_TYPE_EXAONE4_5:
+                {
+                    hparams.n_merge = 2;
+                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                    get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, false);
+                    hparams.set_limit_image_tokens(8, 4096);
+                    hparams.set_warmup_n_tokens(46 * 46);
+                    if (hparams.rope_theta <= 0.0f) {
+                        hparams.rope_theta = 10000.0f;
+                    }
+                    get_u32(KEY_N_HEAD_KV, hparams.n_kv_head, false);
+                    if (hparams.n_kv_head <= 0) {
+                        hparams.n_kv_head = 8;
+                    }
+                } break;
             case PROJECTOR_TYPE_JANUS_PRO:
                 {
                     hparams.image_pad_color = {127, 127, 127};
@@ -1663,6 +1682,7 @@ struct clip_model_loader {
                 || model.proj_type == PROJECTOR_TYPE_LDPV2
                 || model.proj_type == PROJECTOR_TYPE_QWEN2VL
                 || model.proj_type == PROJECTOR_TYPE_QWEN25VL
+                || model.proj_type == PROJECTOR_TYPE_EXAONE4_5
                 || model.proj_type == PROJECTOR_TYPE_GLM_EDGE
                 || model.proj_type == PROJECTOR_TYPE_GEMMA3
                 || model.proj_type == PROJECTOR_TYPE_IDEFICS3
@@ -1783,7 +1803,11 @@ struct clip_model_loader {
                 } break;
             case PROJECTOR_TYPE_QWEN2VL:
             case PROJECTOR_TYPE_QWEN25VL:
+            case PROJECTOR_TYPE_EXAONE4_5:
                 {
+                    if (model.proj_type == PROJECTOR_TYPE_EXAONE4_5) {
+                        model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
+                    }
                     model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
                     model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
                     model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
@@ -2656,6 +2680,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_EXAONE4_5:
         case PROJECTOR_TYPE_GLM4V:
         case PROJECTOR_TYPE_PADDLEOCR:
         case PROJECTOR_TYPE_HUNYUANOCR:
@@ -2676,6 +2701,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_EXAONE4_5:
         case PROJECTOR_TYPE_GLM4V:
         case PROJECTOR_TYPE_PADDLEOCR:
         case PROJECTOR_TYPE_YOUTUVL:
@@ -2744,6 +2770,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
+        case PROJECTOR_TYPE_EXAONE4_5:
         case PROJECTOR_TYPE_GLM4V:
         case PROJECTOR_TYPE_YOUTUVL:
             {
@@ -3120,11 +3147,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 set_input_i32("positions", positions);
             } break;
         case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_EXAONE4_5:
         case PROJECTOR_TYPE_YOUTUVL:
             {
                 // pw * ph = number of tokens output by ViT after apply patch merger
                 // ipw * ipw = number of vision token been processed inside ViT
-                const bool use_window_attn = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? hparams.n_wa_pattern > 0 : !hparams.wa_layer_indexes.empty();
+                const bool use_window_attn =
+                    (ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL || ctx->model.proj_type == PROJECTOR_TYPE_EXAONE4_5)
+                        ? hparams.n_wa_pattern > 0
+                        : !hparams.wa_layer_indexes.empty();
                 const int merge_ratio = 2;
                 const int pw = image_size_width / patch_size / merge_ratio;
                 const int ph = image_size_height / patch_size / merge_ratio;
@@ -3444,6 +3475,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->model.mm_model_mlp_3_w->ne[1];
         case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_EXAONE4_5:
         case PROJECTOR_TYPE_JANUS_PRO:
         case PROJECTOR_TYPE_YOUTUVL:
             return ctx->model.mm_1_b->ne[0];
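
Reviewer note: the window-attention plumbing above hinges on one convention shared by the converter and the runtime: `fullatt_block_indexes` collapses to a single period `n_wa_pattern`, and block `il` runs full attention iff `(il + 1) % n_wa_pattern == 0`. A sketch with a hypothetical index list showing the round trip:

```python
# fullatt_block_indexes here is hypothetical; real values come from the HF vision config.
def derive_n_wa_pattern(fullatt_block_indexes: list[int]) -> int:
    n = fullatt_block_indexes[0] + 1
    for prev, cur in zip(fullatt_block_indexes, fullatt_block_indexes[1:]):
        if cur - prev != n:  # the converter rejects irregular spacing
            raise ValueError(f"invalid fullatt_block_indexes: {fullatt_block_indexes}")
    return n

n_wa = derive_n_wa_pattern([7, 15, 23, 31])            # -> 8
assert [il for il in range(32) if (il + 1) % n_wa == 0] == [7, 15, 23, 31]
```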
diff --git a/tools/mtmd/models/exaone4_5.cpp b/tools/mtmd/models/exaone4_5.cpp
new file mode 100644
index 00000000000..ea226717611
--- /dev/null
+++ b/tools/mtmd/models/exaone4_5.cpp
@@ -0,0 +1,190 @@
+#include "models.h"
+
+static ggml_tensor * clip_repeat_kv_heads(
+        ggml_context * ctx,
+        ggml_tensor  * cur,
+        int64_t        d_head,
+        int64_t        n_kv_head,
+        int64_t        n_head,
+        int64_t        n_tok) {
+    GGML_ASSERT(n_head % n_kv_head == 0);
+    const int64_t n_rep = n_head / n_kv_head;
+    if (n_rep == 1) {
+        return cur;
+    }
+    // Match PyTorch repeat_interleave(dim=head):
+    //   [d, n_kv, n_tok] -> [d, 1, n_kv, n_tok] -> repeat on dim1 -> [d, n_rep, n_kv, n_tok]
+    //   flatten -> [d, n_head, n_tok] with head order [kv0 x rep, kv1 x rep, ...].
+    cur = ggml_reshape_4d(ctx, cur, d_head, 1, n_kv_head, n_tok);
+    cur = ggml_repeat_4d(ctx, cur, d_head, n_rep, n_kv_head, n_tok);
+    cur = ggml_reshape_3d(ctx, cur, d_head, n_head, n_tok);
+    return cur;
+}
+
+ggml_cgraph * clip_graph_exaone4_5::build() {
+    GGML_ASSERT(model.patch_bias == nullptr);
+    GGML_ASSERT(model.class_embedding == nullptr);
+
+    const int batch_size = 1;
+    const bool use_window_attn = hparams.n_wa_pattern > 0;
+    const int n_wa_pattern = hparams.n_wa_pattern;
+    const int n_pos = n_patches;
+    const int num_position_ids = n_pos * 4;
+
+    const norm_type norm_t = NORM_TYPE_RMS;
+
+    const int64_t n_kv_head = hparams.n_kv_head > 0 ? hparams.n_kv_head : n_head;
+    GGML_ASSERT(n_head % n_kv_head == 0);
+
+    int rope_sections[4] = { d_head / 4, d_head / 4, d_head / 4, d_head / 4 };
+    const float rope_freq_base = hparams.rope_theta > 0.0f ? hparams.rope_theta : 10000.0f;
+
+    ggml_tensor * inp_raw = build_inp_raw();
+    ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+
+    GGML_ASSERT(img.nx % (patch_size * 2) == 0);
+    GGML_ASSERT(img.ny % (patch_size * 2) == 0);
+
+    {
+        ggml_tensor * inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+        inp = ggml_add(ctx0, inp, inp_1);
+        inp = ggml_permute(ctx0, inp, 1, 2, 0, 3);
+        inp = ggml_cont_4d(
+            ctx0, inp,
+            n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+        inp = ggml_reshape_4d(
+            ctx0, inp,
+            n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+        inp = ggml_permute(ctx0, inp, 0, 2, 1, 3);
+        inp = ggml_cont_3d(
+            ctx0, inp,
+            n_embd, n_patches_x * n_patches_y, batch_size);
+    }
+
+    ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+    ggml_set_name(positions, "positions");
+    ggml_set_input(positions);
+
+    ggml_tensor * inpL = build_norm(inp, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+    ggml_tensor * window_mask = nullptr;
+    ggml_tensor * window_idx = nullptr;
+    ggml_tensor * inv_window_idx = nullptr;
+
+    if (use_window_attn) {
+        inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+        ggml_set_name(inv_window_idx, "inv_window_idx");
+        ggml_set_input(inv_window_idx);
+
+        window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
+        ggml_set_name(window_mask, "window_mask");
+        ggml_set_input(window_mask);
+
+        if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) {
+            window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16);
+        }
+
+        GGML_ASSERT(batch_size == 1);
+        inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
+        inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
+        inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
+    }
+
+    for (int il = 0; il < n_layer; il++) {
+        const auto & layer = model.layers[il];
+        const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
+        ggml_tensor * cur = inpL;
+
+        cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+        cb(cur, "ln1", il);
+
+        {
+            ggml_tensor * Qcur = build_mm(layer.q_w, cur);
+            ggml_tensor * Kcur = build_mm(layer.k_w, cur);
+            ggml_tensor * Vcur = build_mm(layer.v_w, cur);
+            if (layer.q_b) {
+                Qcur = ggml_add(ctx0, Qcur, layer.q_b);
+            }
+            if (layer.k_b) {
+                Kcur = ggml_add(ctx0, Kcur, layer.k_b);
+            }
+            if (layer.v_b) {
+                Vcur = ggml_add(ctx0, Vcur, layer.v_b);
+            }
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_kv_head, n_patches);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_kv_head, n_patches);
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            Kcur = clip_repeat_kv_heads(ctx0, Kcur, d_head, n_kv_head, n_head, n_patches);
+            Vcur = clip_repeat_kv_heads(ctx0, Vcur, d_head, n_kv_head, n_head, n_patches);
+
+            Qcur = ggml_rope_multi(
+                ctx0, Qcur, positions, nullptr,
+                d_head / 2, rope_sections, GGML_ROPE_TYPE_VISION, 32768, rope_freq_base, 1, 0, 1, 32, 1);
+            Kcur = ggml_rope_multi(
+                ctx0, Kcur, positions, nullptr,
+                d_head / 2, rope_sections, GGML_ROPE_TYPE_VISION, 32768, rope_freq_base, 1, 0, 1, 32, 1);
+
+            cb(Qcur, "Qcur_rope", il);
+            cb(Kcur, "Kcur_rope", il);
+            cb(Vcur, "Vcur_rep", il);
+
+            ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
+            cur = build_attn(layer.o_w, layer.o_b, Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
+            cb(cur, "attn_out", il);
+        }
+
+        cur = ggml_add(ctx0, cur, inpL);
+        inpL = cur;
+
+        cb(cur, "ffn_inp", il);
+
+        cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+        cb(cur, "ffn_inp_normed", il);
+
+        cur = build_ffn(cur,
+            layer.ff_up_w, layer.ff_up_b,
+            layer.ff_gate_w, layer.ff_gate_b,
+            layer.ff_down_w, layer.ff_down_b,
+            hparams.ffn_op, il);
+
+        cb(cur, "ffn_out", il);
+
+        cur = ggml_add(ctx0, inpL, cur);
+        cb(cur, "layer_out", il);
+
+        inpL = cur;
+    }
+
+    ggml_tensor * embeddings = inpL;
+    // EXAONE4.5 merger follows HF PatchMerger: ln_q over context_dim before 2x2 flatten.
+    if (model.mm_input_norm_w) {
+        embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, eps, -1);
+    }
+    embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
+    embeddings = build_ffn(embeddings,
+        model.mm_0_w, model.mm_0_b,
+        nullptr, nullptr,
+        model.mm_1_w, model.mm_1_b,
+        FFN_GELU,
+        -1);
+
+    if (use_window_attn) {
+        window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+        ggml_set_name(window_idx, "window_idx");
+        ggml_set_input(window_idx);
+
+        GGML_ASSERT(batch_size == 1);
+        embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4);
+        embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+        embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4, batch_size);
+    }
+
+    ggml_build_forward_expand(gf, embeddings);
+
+    return gf;
+}
diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h
index 5f5b76040de..ed1e516471c 100644
--- a/tools/mtmd/models/models.h
+++ b/tools/mtmd/models/models.h
@@ -152,3 +152,8 @@ struct clip_graph_kimik25 : clip_graph {

     ggml_tensor * resize_position_embeddings_3d(uint32_t interpolation_mode);
 };
+
+struct clip_graph_exaone4_5 : clip_graph {
+    clip_graph_exaone4_5(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {}
+    ggml_cgraph * build() override;
+};
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 41c5211375b..fa4d92a606f 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -436,6 +436,13 @@ struct mtmd_context {
                     img_end = "<|hy_place▁holder▁no▁101|>";
                     image_preproc = std::make_unique(ctx_v);
                 } break;
+            case PROJECTOR_TYPE_EXAONE4_5:
+                {
+                    // ... (image embeddings) ...
+                    img_beg = "";
+                    img_end = "";
+                    image_preproc = std::make_unique(ctx_v);
+                } break;
             default:
                 throw std::runtime_error(string_format("%s: unexpected vision projector type %d\n", __func__, proj));
         }
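
Reviewer note: `clip_repeat_kv_heads` in `models/exaone4_5.cpp` claims parity with PyTorch `repeat_interleave` on the head dimension. A torch transliteration of its reshape/repeat/reshape sequence with toy sizes, handy for verifying the head order it produces:

```python
# Toy sizes only; checks reshape/expand/reshape == repeat_interleave(dim=head).
import torch

d_head, n_kv, n_head, n_tok = 4, 2, 6, 3
n_rep = n_head // n_kv                                  # 3

kv  = torch.randn(n_tok, n_kv, d_head)                  # [tok, kv_head, dim]
ref = kv.repeat_interleave(n_rep, dim=1)                # heads: kv0,kv0,kv0,kv1,kv1,kv1

# ggml: [d, n_kv, tok] -> [d, 1, n_kv, tok] -> repeat dim1 -> flatten to [d, n_head, tok];
# the same trick written with torch's row-major dims:
alt = kv.unsqueeze(2).expand(n_tok, n_kv, n_rep, d_head).reshape(n_tok, n_head, d_head)
assert torch.equal(ref, alt)
```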