Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions tzrec/utils/checkpoint_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,45 @@ def _a2a(
m._buffers[name].copy_(new_meta)


# Module-name segments that get a `_user` twin when INPUT_TILE=3 export
# duplicates the embedding group into item/user halves. See
# tzrec/modules/embedding.py:EmbeddingGroupImpl /
# SequenceEmbeddingGroupImpl for the construction.
_INPUT_TILE_USER_SEGMENTS = frozenset({"ebc", "mc_ebc", "ec_dict", "mc_ec_dict"})


def _link_dynamicemb_input_tile_user_paths(dynamicemb_path: str) -> None:
    """Symlink user-side dynamicemb dirs added by INPUT_TILE=3 export.

    Each entry under ``dynamicemb_path`` is a dotted module path. For
    every real directory entry containing a segment in
    ``_INPUT_TILE_USER_SEGMENTS``, create one sibling symlink whose first
    matching segment carries a ``_user`` suffix so DynamicEmbLoad can
    resolve the twin path used by the INPUT_TILE=3 model.

    Only the first matching segment of an entry is renamed, so an entry
    yields at most one twin link even if several of its segments are in
    ``_INPUT_TILE_USER_SEGMENTS``.

    Args:
        dynamicemb_path: directory holding the per-module dynamic
            embedding checkpoint directories.

    Entries that are symlinks or are not directories are skipped. A twin
    path that already exists (including a dangling symlink) is left
    untouched. A failure to create a symlink is logged as a warning and
    is not raised.
    """
    for entry in os.listdir(dynamicemb_path):
        full_path = os.path.join(dynamicemb_path, entry)
        # Only real directories are checkpoint entries; links are twins
        # created by an earlier call.
        if os.path.islink(full_path) or not os.path.isdir(full_path):
            continue
        segs = entry.split(".")
        for i, seg in enumerate(segs):
            if seg not in _INPUT_TILE_USER_SEGMENTS:
                continue
            user_segs = list(segs)
            user_segs[i] = f"{seg}_user"
            user_entry = ".".join(user_segs)
            user_path = os.path.join(dynamicemb_path, user_entry)
            # lexists (not exists) so a dangling link is not overwritten.
            if not os.path.lexists(user_path):
                try:
                    # Relative target: the link stays valid if the
                    # checkpoint directory is moved or copied as a whole.
                    os.symlink(entry, user_path)
                    logger.info(
                        f"created INPUT_TILE=3 dynamicemb symlink {user_entry} -> {entry}"
                    )
                except OSError as e:
                    logger.warning(
                        f"failed to create dynamicemb symlink {user_entry}: {e}"
                    )
            # One twin per entry: stop after the first matching segment,
            # whether the link was created, already existed, or failed.
            break
Comment on lines +469 to +484
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor defensive nit: if a single dotted entry ever contained more than one segment in _INPUT_TILE_USER_SEGMENTS, this loop would create N separate symlinks (each renaming a different segment). In practice the construction in EmbeddingGroupImpl / SequenceEmbeddingGroupImpl produces only one such segment per path, so this is theoretical — but adding a break after a successful (or skipped-because-exists) os.symlink would make the intent explicit and match the construction-side rename.



def restore_model(
checkpoint_dir: str,
model: nn.Module,
Expand Down Expand Up @@ -523,6 +562,19 @@ def restore_model(

dynamicemb_path = os.path.join(checkpoint_dir, "dynamicemb")
if os.path.exists(dynamicemb_path):
# Training never sets INPUT_TILE, but exporting with
# INPUT_TILE=3 adds twin user-side modules (`ebc_user`,
# `mc_ebc_user`, `ec_dict_user`, `mc_ec_dict_user`) that
# share dynamic-embedding tables with their non-user
# counterparts. The checkpoint only has the non-user paths,
# so DynamicEmbLoad raises "can't find path to load" for
# the user-side ones. Symlink the missing twin paths.
input_tile = os.environ.get("INPUT_TILE", "")
if input_tile.startswith("3"):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider reusing is_input_tile_emb() from tzrec/acc/utils.py (already used in tzrec/utils/export_util.py) instead of re-deriving the INPUT_TILE=3 check inline. Keeps the env-var convention (input_tile[0] == "3") in one place and avoids drift if it changes.

if int(os.environ.get("RANK", 0)) == 0:
_link_dynamicemb_input_tile_user_paths(dynamicemb_path)
if dist.is_initialized() and dist.get_world_size() > 1:
dist.barrier()
logger.info(
f"RANK[{os.environ.get('RANK', 0)}] restoring dynamic embedding..."
)
Expand Down
Loading