fix hunyuan video 1.5 weight-load bug #1052
Code Review
This pull request updates the checkpoint loading logic in base_model.py to support specific directory structures for hunyuan_video_1.5 and initializes sequence parallel groups in pre_infer.py. Review feedback highlights potential KeyError and AttributeError risks when accessing configuration values directly and recommends using safer access patterns with .get() and null checks to ensure stability across different environments.
    if self.config["seq_parallel"]:
        self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
    else:
        self.seq_p_group = None
This block has two potential issues:
1. self.config["seq_parallel"] will raise a KeyError if the key is missing from the configuration.
2. self.config.get("device_mesh").get_group(...) will raise an AttributeError if device_mesh is not present in the config, as get() returns None by default, and None has no attribute get_group.
Using a safer access pattern ensures the code doesn't crash in distributed environments with incomplete configurations.
Suggested change:
    - if self.config["seq_parallel"]:
    -     self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
    - else:
    -     self.seq_p_group = None
    + if self.config.get("seq_parallel"):
    +     device_mesh = self.config.get("device_mesh")
    +     self.seq_p_group = device_mesh.get_group(mesh_dim="seq_p") if device_mesh else None
    + else:
    +     self.seq_p_group = None
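The guarded lookup can be exercised in isolation. The following is a minimal sketch, not the project's code: `StubDeviceMesh` and `resolve_seq_p_group` are hypothetical names introduced here, and the stub only imitates a `get_group(mesh_dim=...)` method so the example runs without a distributed setup.

```python
from typing import Any, Mapping, Optional


class StubDeviceMesh:
    """Hypothetical stand-in for a device mesh; it only mimics get_group()."""

    def get_group(self, mesh_dim: str) -> str:
        return f"group:{mesh_dim}"


def resolve_seq_p_group(config: Mapping[str, Any]) -> Optional[Any]:
    """Return the sequence-parallel group, or None when it cannot be resolved."""
    # .get() returns None for a missing key instead of raising KeyError.
    if not config.get("seq_parallel"):
        return None
    device_mesh = config.get("device_mesh")
    # Guard against a missing mesh so get_group() is never called on None.
    if device_mesh is None:
        return None
    return device_mesh.get_group(mesh_dim="seq_p")


# A missing key and a missing mesh both resolve to None instead of raising.
print(resolve_seq_p_group({}))
print(resolve_seq_p_group({"seq_parallel": True}))
print(resolve_seq_p_group({"seq_parallel": True, "device_mesh": StubDeviceMesh()}))
```

One trade-off worth weighing: when seq_parallel is enabled but device_mesh is absent, this pattern falls back to None silently, which may hide a misconfiguration. Raising an explicit error in that branch is a reasonable alternative if the maintainers prefer to fail fast.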
    if self.config["model_cls"] == "hunyuan_video_1.5":
        safetensors_files = glob.glob(os.path.join(safetensors_path, "transformer", self.config["transformer_model_name"], "*.safetensors"))
    else:
        safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
Using self.config["model_cls"] can raise a KeyError if the key is missing from the configuration. Since BaseTransformerModel is a base class used by multiple model types (such as z_image or qwen_image), this change could cause regressions for other models that do not define model_cls in their config. It is safer to use .get() to check the model class.
Suggested change:
    - if self.config["model_cls"] == "hunyuan_video_1.5":
    -     safetensors_files = glob.glob(os.path.join(safetensors_path, "transformer", self.config["transformer_model_name"], "*.safetensors"))
    - else:
    -     safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
    + if self.config.get("model_cls") == "hunyuan_video_1.5":
    +     safetensors_files = glob.glob(os.path.join(safetensors_path, "transformer", self.config.get("transformer_model_name", ""), "*.safetensors"))
    + else:
    +     safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))