@@ -10,7 +10,7 @@
 from .dequant import is_quantized, dequantize_tensor
 
 IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan", "lumina2", "qwen_image"}
-TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl"}
+TXT_ARCH_LIST = {"t5", "t5encoder", "llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}
 VIS_TYPE_LIST = {"clip-vision", "mmproj"}
 
 def get_orig_shape(reader, tensor_name):
@@ -199,6 +199,15 @@ def gguf_sd_loader(path, handle_prefix="model.diffusion_model.", is_text_model=False):
     "output.weight": "lm_head.weight",
 }
 
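+# Gemma3 mostly reuses the llama tensor names; only the extra feed-forward
+# and post-attention norm layers need their own HF-side names.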
+GEMMA3_SD_MAP = LLAMA_SD_MAP.copy()
+GEMMA3_SD_MAP.update({
+    "ffn_norm": "pre_feedforward_layernorm",
+    "post_ffw_norm": "post_feedforward_layernorm",
+    "post_attention_norm": "post_attention_layernorm",
+})
+
 CLIP_VISION_SD_MAP = {
     "mm.": "visual.merger.mlp.",
     "v.post_ln.": "visual.merger.ln_q.",
@@ -232,6 +239,30 @@ def llama_permute(raw_sd, n_head, n_head_kv):
         sd[k] = v
     return sd
 
+def gemma3_norm_corrections(sd):
+    # Reverse the change made by Gemma3Model.modify_tensors in the llama.cpp convert script.
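+    # llama.cpp stores these weights as (weight + 1) to match Gemma's
+    # x * (1 + weight) RMSNorm; subtracting 1.0 restores the HF values.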
+    norm_patterns = [
+        "input_layernorm.weight",
+        "post_attention_layernorm.weight",
+        "pre_feedforward_layernorm.weight",
+        "post_feedforward_layernorm.weight",
+        "self_attn.q_norm.weight",
+        "self_attn.k_norm.weight",
+        "model.norm.weight",
+    ]
+    corrected = 0
+    for key in list(sd.keys()):
+        if any(p in key for p in norm_patterns):
+            if is_quantized(sd[key]):
+                sd[key] = dequantize_tensor(sd[key], dtype=torch.float32) - 1.0
+            else:
+                sd[key] = sd[key].float() - 1.0
+            corrected += 1
+    logging.debug(f"Gemma3: applied -1.0 norm correction to {corrected} tensors")
+    return sd
+
 def strip_quant_suffix(name):
     pattern = r"[-_]?(?:ud-)?i?q[0-9]_[a-z0-9_\-]{1,8}$"
     match = re.search(pattern, name, re.IGNORECASE)
@@ -396,6 +425,52 @@ def gguf_tekken_tokenizer_loader(path, temb_shape):
     del reader
     return torch.ByteTensor(list(json.dumps(data).encode('utf-8')))
 
+def gguf_gemma3_tokenizer_loader(path):
+    # TODO: merge into gguf_tokenizer_loader
+    logging.info("Attempting to recreate sentencepiece tokenizer from GGUF file metadata...")
+    try:
+        from sentencepiece import sentencepiece_model_pb2 as model
+    except ImportError:
+        raise ImportError("Please install sentencepiece and protobuf.\n pip install sentencepiece protobuf")
+    spm = model.ModelProto()
+    reader = gguf.GGUFReader(path)
+
+    spm.normalizer_spec.name = "identity"
+    spm.normalizer_spec.add_dummy_prefix = False
+    spm.trainer_spec.model_type = 2  # 2 = BPE
+    spm.trainer_spec.input_format = "tsv"
+    spm.trainer_spec.byte_fallback = True
+    spm.trainer_spec.max_sentence_length = 4192
+    spm.trainer_spec.bos_piece = "<bos>"
+
+    tokens = get_list_field(reader, "tokenizer.ggml.tokens", str)
+    scores = get_list_field(reader, "tokenizer.ggml.scores", float)
+    toktype = get_list_field(reader, "tokenizer.ggml.token_type", int)
+
+    if not tokens or not scores or not toktype:
+        raise ValueError("GGUF file is missing tokenizer metadata (tokens/scores/token_type)")
+
+    for idx in range(len(tokens)):
+        piece = spm.SentencePiece()
+        piece.piece = tokens[idx]
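+        # Gemma tokenizers keep <unk> at id 3; force its type so the proto
+        # always carries a valid UNK piece regardless of the GGUF metadata.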
+        if idx == 3:  # UNK position
+            piece.type = 2  # UNK token
+            piece.score = 0.0  # UNK score
+        else:
+            piece.type = toktype[idx]
+            piece.score = scores[idx]
+        spm.pieces.append(piece)
+
+    spm.trainer_spec.vocab_size = len(spm.pieces)
+    logging.info(f"Created tokenizer with vocab size of {len(spm.pieces)}")
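+    # Sanity-check sketch (hypothetical, not run here): the serialized proto
+    # loads via sentencepiece.SentencePieceProcessor(model_proto=spm.SerializeToString())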
+
+    del reader
+    return torch.ByteTensor(list(spm.SerializeToString()))
+
 def gguf_clip_loader(path):
     sd, extra = gguf_sd_loader(path, is_text_model=True)
     arch = extra.get("arch_str", None)
@@ -408,17 +479,26 @@ def gguf_clip_loader(path):
         logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
         sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
         sd = sd_map_replace(sd, T5_SD_MAP)
-    elif arch in {"llama", "qwen2vl", "qwen3", "qwen3vl"}:
+    elif arch in {"llama", "qwen2vl", "qwen3", "qwen3vl", "gemma3"}:
         # TODO: pass model_options["vocab_size"] to loader somehow
         temb_key = "token_embd.weight"
         if temb_key in sd and sd[temb_key].shape[0] >= (64 * 1024):
             if arch == "llama" and sd[temb_key].shape == (131072, 5120):
                 # non-standard Comfy-Org tokenizer
                 sd["tekken_model"] = gguf_tekken_tokenizer_loader(path, sd[temb_key].shape)
+            elif arch == "gemma3":
+                sd["spiece_model"] = gguf_gemma3_tokenizer_loader(path)
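+                # (only reached because Gemma3's vocab exceeds the 64k check above)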
             # See note above for T5.
             logging.warning(f"Dequantizing {temb_key} to prevent runtime OOM.")
             sd[temb_key] = dequantize_tensor(sd[temb_key], dtype=torch.float16)
-        sd = sd_map_replace(sd, LLAMA_SD_MAP)
+        if arch == "gemma3":
+            sd = sd_map_replace(sd, GEMMA3_SD_MAP)
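+            # gemma3_norm_corrections matches the renamed HF layer names,
+            # so it must run after the map replacement above.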
+            sd = gemma3_norm_corrections(sd)
+        else:
+            sd = sd_map_replace(sd, LLAMA_SD_MAP)
         if arch == "llama":
             sd = llama_permute(sd, 32, 8)  # L3 / Mistral
         if arch == "qwen2vl":
@@ -427,4 +504,3 @@ def gguf_clip_loader(path):
     else:
         pass
     return sd
-