Skip to content

Commit 9ecc3c4

Browse files
authored
Add Gemma3 12B Support (#402)
* Add Gemma3 12B Support * Update loader.py * Support tokenizer recreation from metadata * Update loader.py * update loader.py
2 parents c2e3b0a + 2e7f529 commit 9ecc3c4

1 file changed

Lines changed: 80 additions & 4 deletions

File tree

loader.py

Lines changed: 80 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .dequant import is_quantized, dequantize_tensor
1111

1212
# GGUF architectures loaded as image/video diffusion models.
IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
# GGUF architectures loaded as text encoder models.
TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}
# GGUF file types treated as vision (CLIP-vision / multimodal projector) models.
VIS_TYPE_LIST = {"clip-vision", "mmproj"}
1515

1616
def get_orig_shape(reader, tensor_name):
@@ -199,6 +199,13 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=F
199199
"output.weight": "lm_head.weight",
200200
}
201201

202+
# Gemma3 reuses the llama key mapping, overriding/adding its extra norm layers.
GEMMA3_SD_MAP = {
    **LLAMA_SD_MAP,
    "ffn_norm": "pre_feedforward_layernorm",
    "post_ffw_norm": "post_feedforward_layernorm",
    "post_attention_norm": "post_attention_layernorm",
}
208+
202209
CLIP_VISION_SD_MAP = {
203210
"mm.": "visual.merger.mlp.",
204211
"v.post_ln.": "visual.merger.ln_q.",
@@ -232,6 +239,28 @@ def llama_permute(raw_sd, n_head, n_head_kv):
232239
sd[k] = v
233240
return sd
234241

242+
def gemma3_norm_corrections(sd):
    """Subtract 1.0 from Gemma3 norm weights, in place, and return the state dict.

    Reverses the offset applied by ``Gemma3Model.modify_tensors`` in the
    llama.cpp convert script, so the norm tensors match the expected layout.
    Quantized norm tensors are dequantized to float32 first; unquantized ones
    are cast to float. Only keys containing one of the known norm-weight
    substrings are touched; all other tensors pass through untouched.
    """
    norm_patterns = (
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
        "pre_feedforward_layernorm.weight",
        "post_feedforward_layernorm.weight",
        "self_attn.q_norm.weight",
        "self_attn.k_norm.weight",
        "model.norm.weight",
    )
    corrected = 0
    # list() so we can reassign values while iterating the keys.
    for key in list(sd.keys()):
        if not any(p in key for p in norm_patterns):
            continue
        if is_quantized(sd[key]):
            sd[key] = dequantize_tensor(sd[key], dtype=torch.float32) - 1.0
        else:
            sd[key] = sd[key].float() - 1.0
        corrected += 1
    # Was commented-out dead code; use lazy %-args so the message only formats
    # when debug logging is enabled.
    logging.debug("Gemma3: applied -1 norm correction to %d tensors", corrected)
    return sd
263+
235264
def strip_quant_suffix(name):
236265
pattern = r"[-_]?(?:ud-)?i?q[0-9]_[a-z0-9_\-]{1,8}$"
237266
match = re.search(pattern, name, re.IGNORECASE)
@@ -396,6 +425,48 @@ def gguf_tekken_tokenizer_loader(path, temb_shape):
396425
del reader
397426
return torch.ByteTensor(list(json.dumps(data).encode('utf-8')))
398427

428+
def gguf_gemma3_tokenizer_loader(path):
    """Recreate a sentencepiece tokenizer model from GGUF file metadata.

    GGUF stores the tokenizer as flat metadata arrays (tokens/scores/types)
    rather than the original ``tokenizer.model``; rebuild a sentencepiece
    ``ModelProto`` from those arrays and return it serialized as a
    ``torch.ByteTensor`` so it can travel inside the state dict.

    Raises:
        ImportError: if sentencepiece/protobuf are not installed.
        ValueError: if the tokenizer metadata arrays are missing or mismatched.
    """
    # TODO: merge into gguf_tokenizer_loader
    logging.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
    try:
        from sentencepiece import sentencepiece_model_pb2 as model
    except ImportError as e:
        # Chain the original error so the actual failing import stays visible.
        raise ImportError("Please install sentencepiece and protobuf.\npip install sentencepiece protobuf") from e

    spm = model.ModelProto()
    reader = gguf.GGUFReader(path)

    # Fixed spec values for the reconstructed model (not read from the GGUF).
    spm.normalizer_spec.name = "identity"
    spm.normalizer_spec.add_dummy_prefix = False
    spm.trainer_spec.model_type = 2  # TrainerSpec.ModelType: 2 == BPE
    spm.trainer_spec.input_format = "tsv"
    spm.trainer_spec.byte_fallback = True
    spm.trainer_spec.max_sentence_length = 4192
    spm.trainer_spec.bos_piece = "<bos>"

    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
    toktype = get_list_field(reader, "tokenizer.ggml.token_type", int)

    if not tokens or not scores or not toktype:
        raise ValueError("Missing tokenizer metadata")
    # Parallel arrays must agree; fail clearly instead of an IndexError later.
    if not (len(tokens) == len(scores) == len(toktype)):
        raise ValueError("Missing tokenizer metadata")

    for idx, (token, score, ttype) in enumerate(zip(tokens, scores, toktype)):
        piece = spm.pieces.add()  # idiomatic protobuf repeated-message append
        piece.piece = token
        if idx == 3:  # UNK position
            piece.type = 2     # UNK Token
            piece.score = 0.0  # UNK Score
        else:
            piece.type = ttype
            piece.score = score

    spm.trainer_spec.vocab_size = len(spm.pieces)
    # Lazy %-args: only formatted if this log level is enabled.
    logging.info("Created tokenizer with vocab size of %d", len(spm.pieces))

    del reader
    return torch.ByteTensor(list(spm.SerializeToString()))
469+
399470
def gguf_clip_loader(path):
400471
sd, extra = gguf_sd_loader(path, is_text_model=True)
401472
arch = extra.get("arch_str", None)
@@ -408,17 +479,23 @@ def gguf_clip_loader(path):
408479
logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
409480
sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
410481
sd = sd_map_replace(sd, T5_SD_MAP)
411-
elif arch in {"llama", "qwen2vl", "qwen3", "qwen3vl"}:
482+
elif arch in {"llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}:
412483
# TODO: pass model_options["vocab_size"] to loader somehow
413484
temb_key = "token_embd.weight"
414485
if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
415486
if arch == "llama" and sd[temb_key].shape == (131072, 5120):
416487
# non-standard Comfy-Org tokenizer
417488
sd["tekken_model"] = gguf_tekken_tokenizer_loader(path, sd[temb_key].shape)
489+
elif arch == "gemma3":
490+
sd["spiece_model"] = gguf_gemma3_tokenizer_loader(path)
418491
# See note above for T5.
419492
logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
420493
sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
421-
sd = sd_map_replace(sd, LLAMA_SD_MAP)
494+
if arch == "gemma3":
495+
sd = sd_map_replace(sd, GEMMA3_SD_MAP)
496+
sd = gemma3_norm_corrections(sd)
497+
else:
498+
sd = sd_map_replace(sd, LLAMA_SD_MAP)
422499
if arch == "llama":
423500
sd = llama_permute(sd, 32, 8) # L3 / Mistral
424501
if arch == "qwen2vl":
@@ -427,4 +504,3 @@ def gguf_clip_loader(path):
427504
else:
428505
pass
429506
return sd
430-

0 commit comments

Comments
 (0)