-
Notifications
You must be signed in to change notification settings - Fork 16.8k
Add EXAONE 4.5 implementations #21733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nuxlear
wants to merge
5
commits into
ggml-org:master
Choose a base branch
from
nuxlear:add-exaone4_5
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
0b80839
Add EXAONE 4.5 and Add GQA for MMproj
lgai-exaone 3b12fcd
mtmd: EXAONE 4.5 vision markers and projector path
lgai-exaone d011393
mtmd: load EXAONE4 nextn tensors correctly
lgai-exaone f1e3ff2
Minor fixes
lgai-exaone 5ba14d5
Merge branch 'master' of https://github.com/ggerganov/llama.cpp into …
lgai-exaone File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2223,6 +2223,9 @@ def set_gguf_parameters(self): | |
| self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"])) | ||
| self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys)) | ||
| self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "heads", "vt_num_attention_heads"])) | ||
| n_kv = self.find_vparam(["num_key_value_heads"], optional=True) | ||
| if n_kv is not None: | ||
| self.gguf_writer.add_vision_head_count_kv(n_kv) | ||
|
|
||
| # preprocessor config | ||
| image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] | ||
|
|
@@ -10435,6 +10438,128 @@ def prepare_tensors(self): | |
| raise ValueError(f"Unprocessed experts: {experts}") | ||
|
|
||
|
|
||
| @ModelBase.register("Exaone4_5_ForConditionalGeneration") | ||
| class Exaone4_5_VLTextModel(Exaone4Model): | ||
| """Text tower of EXAONE 4.5; Tensors match EXAONE4""" | ||
|
|
||
| model_arch = gguf.MODEL_ARCH.EXAONE4 | ||
|
|
||
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
| n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0) | ||
| if n_nextn > 0: | ||
| self.block_count = self.hparams["num_hidden_layers"] + n_nextn | ||
| self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) | ||
|
|
||
| def set_gguf_parameters(self): | ||
| super().set_gguf_parameters() | ||
| n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0) | ||
| if n_nextn > 0: | ||
| self.gguf_writer.add_nextn_predict_layers(n_nextn) | ||
|
|
||
| def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: | ||
| if name.startswith("model.visual."): | ||
| return | ||
| if name.startswith("model.language_model."): | ||
| name = name.replace("model.language_model.", "model.", 1) | ||
|
|
||
| if name.startswith("mtp."): | ||
| n_nextn = int(self.hparams.get("num_nextn_predict_layers", 0) or 0) | ||
| if n_nextn <= 0: | ||
| return | ||
| nh = self.hparams["num_hidden_layers"] | ||
| if ".layers." in name: | ||
| share = self.hparams.get("mtp_share_layers", False) | ||
| mtp_bid = bid if bid is not None else 0 | ||
| if share: | ||
| for k in range(n_nextn): | ||
| nn = name.replace(f"mtp.layers.{mtp_bid}", f"model.layers.{nh + k}") | ||
| yield from super().modify_tensors(data_torch, nn, nh + k) | ||
| return | ||
| name = name.replace(f"mtp.layers.{mtp_bid}", f"model.layers.{mtp_bid + nh}") | ||
| else: | ||
| remapper = { | ||
| "mtp.fc": "model.layers.{bid}.eh_proj", | ||
| "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm", | ||
| "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm", | ||
| "mtp.norm": "model.layers.{bid}.shared_head.norm", | ||
| } | ||
| _n = Path(name) | ||
| key = _n.stem | ||
| if key not in remapper: | ||
| return | ||
| new_name = remapper[key] + _n.suffix | ||
| for bid_mtp in range(nh, self.block_count): | ||
| yield from super().modify_tensors(data_torch, new_name.format(bid=bid_mtp), bid_mtp) | ||
| return | ||
|
|
||
| yield from super().modify_tensors(data_torch, name, bid) | ||
|
|
||
|
|
||
| @ModelBase.register("Exaone4_5_ForConditionalGeneration") | ||
| class Exaone4_5VLVisionModel(Qwen2VLVisionModel): | ||
| """Vision tower for EXAONE 4.5; Qwen2-VL-style ViT (GQA) + patch merger""" | ||
|
|
||
| def set_gguf_parameters(self): | ||
| MmprojModel.set_gguf_parameters(self) | ||
| assert self.hparams_vision is not None | ||
| hparams = self.hparams_vision | ||
| self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.EXAONE4_5) | ||
| self.gguf_writer.add_vision_use_silu(True) | ||
| eps = hparams.get("rms_norm_eps", self.global_config.get("rms_norm_eps", 1e-6)) | ||
| self.gguf_writer.add_vision_attention_layernorm_eps(eps) | ||
| if (window_size := hparams.get("window_size")) is not None: | ||
| self.gguf_writer.add_vision_window_size(window_size) | ||
| fullatt_block_indexes = hparams.get("fullatt_block_indexes") | ||
| if fullatt_block_indexes: | ||
| n_wa_pattern = fullatt_block_indexes[0] + 1 | ||
| for i in range(1, len(fullatt_block_indexes)): | ||
| if fullatt_block_indexes[i] - fullatt_block_indexes[i - 1] != n_wa_pattern: | ||
| raise ValueError(f"Invalid EXAONE4.5 fullatt_block_indexes: {fullatt_block_indexes}") | ||
| self.gguf_writer.add_vision_n_wa_pattern(n_wa_pattern) | ||
|
|
||
| def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: | ||
| if name.startswith("model.language_model.") or name.startswith("lm_head."): | ||
| return | ||
| if name.startswith("mtp."): | ||
| return | ||
| if name.startswith("model.visual."): | ||
| name = name.replace("model.visual.", "visual.", 1) | ||
|
|
||
| # blueprint_exaone4_5: ViT uses GQA (HF fused qkv = q_dim + kv_dim + kv_dim), not MQA / equal thirds. | ||
| if name.startswith("visual.") and ".qkv." in name: | ||
| assert self.hparams_vision is not None | ||
| hv = self.hparams_vision | ||
| n_heads = hv["num_heads"] | ||
| n_kv = int(hv.get("num_key_value_heads", n_heads)) | ||
| hidden = hv["hidden_size"] | ||
| head_dim = hidden // n_heads | ||
| q_dim = n_heads * head_dim | ||
| kv_dim = n_kv * head_dim | ||
| total_out = q_dim + 2 * kv_dim | ||
| out_dim = data_torch.shape[0] | ||
| if out_dim != total_out: | ||
| raise ValueError(f"EXAONE 4.5 vision qkv out dim mismatch: got {out_dim}, expected {total_out} ({name})") | ||
| wq = data_torch[:q_dim] | ||
| wk = data_torch[q_dim : q_dim + kv_dim] | ||
| wv = data_torch[q_dim + kv_dim :] | ||
| nq = name.replace("qkv", "q", 1) | ||
| nk = name.replace("qkv", "k", 1) | ||
| nv = name.replace("qkv", "v", 1) | ||
| yield from ModelBase.modify_tensors(self, wq, nq, bid) | ||
| yield from ModelBase.modify_tensors(self, wk, nk, bid) | ||
| yield from ModelBase.modify_tensors(self, wv, nv, bid) | ||
|
Comment on lines
+10530
to
+10551
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should not split qkv, instead, use see |
||
| return | ||
|
|
||
| # EXAONE4.5 PatchMerger includes ln_q (RMSNorm), but generic Qwen2-VL mapping can miss it. | ||
| # Keep explicit mapping to mm.input_norm for activation-level parity with HF. | ||
| if name == "visual.merger.ln_q.weight": | ||
| yield ("mm.input_norm.weight", data_torch) | ||
| return | ||
|
|
||
| yield from Qwen2VLVisionModel.modify_tensors(self, data_torch, name, bid) | ||
|
|
||
|
|
||
| @ModelBase.register("GraniteForCausalLM") | ||
| class GraniteModel(LlamaModel): | ||
| """Conversion for IBM's GraniteForCausalLM""" | ||
|
|
@@ -13118,7 +13243,9 @@ def get_model_architecture(hparams: dict[str, Any], model_type: ModelType) -> st | |
|
|
||
| # if "architectures" is found in the sub-config, use that instead | ||
| if model_type == ModelType.TEXT and text_config.get("architectures") is not None: | ||
| arch = text_config["architectures"][0] | ||
| # Multimodal EXAONE 4.5 stores the inner causal LM class in text_config; HF→GGUF must use the VL root arch. | ||
| if hparams.get("model_type") != "exaone4_5": | ||
| arch = text_config["architectures"][0] | ||
| elif model_type == ModelType.MMPROJ and vision_config.get("architectures") is not None: | ||
| arch = vision_config["architectures"][0] | ||
| if arch is None: | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please use proper
tensor_mapping