# Module-name segments that get a `_user` twin when INPUT_TILE=3 export
# duplicates the embedding group into item/user halves. See
# tzrec/modules/embedding.py:EmbeddingGroupImpl /
# SequenceEmbeddingGroupImpl for the construction.
_INPUT_TILE_USER_SEGMENTS = frozenset({"ebc", "mc_ebc", "ec_dict", "mc_ec_dict"})


def _link_dynamicemb_input_tile_user_paths(dynamicemb_path: str) -> None:
    """Symlink user-side dynamicemb dirs added by INPUT_TILE=3 export.

    Each entry under ``dynamicemb_path`` is a dotted module path. For
    every real directory entry containing a segment in
    ``_INPUT_TILE_USER_SEGMENTS``, create a sibling symlink whose matching
    segment carries a ``_user`` suffix so DynamicEmbLoad can resolve the
    twin path used by the INPUT_TILE=3 model.

    The operation is best-effort and idempotent: twins that already exist
    (as a directory, file or symlink, even a dangling one) are left
    untouched, and a failure to create a link is logged as a warning
    instead of being raised.

    Args:
        dynamicemb_path (str): directory holding one sub-directory per
            dynamic-embedding module path.
    """
    # Sort the listing so link creation and log order are reproducible;
    # os.listdir returns entries in arbitrary order.
    for entry in sorted(os.listdir(dynamicemb_path)):
        full_path = os.path.join(dynamicemb_path, entry)
        # Only real directories are sources; skipping symlinks keeps
        # previously created twins from being processed again.
        if os.path.islink(full_path) or not os.path.isdir(full_path):
            continue
        segs = entry.split(".")
        for i, seg in enumerate(segs):
            if seg not in _INPUT_TILE_USER_SEGMENTS:
                continue
            user_entry = ".".join([*segs[:i], f"{seg}_user", *segs[i + 1 :]])
            user_path = os.path.join(dynamicemb_path, user_entry)
            if os.path.lexists(user_path):
                continue
            try:
                # Relative target: the link stays valid when the whole
                # checkpoint directory is moved or copied.
                os.symlink(entry, user_path)
            except FileExistsError:
                # Another process created the twin between the lexists
                # check and the symlink call; that is the desired state.
                continue
            except OSError as e:
                logger.warning(f"failed to create dynamicemb symlink {user_entry}: {e}")
            else:
                logger.info(
                    f"created INPUT_TILE=3 dynamicemb symlink {user_entry} -> {entry}"
                )