Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
ffc0325
[feat] dynamicemb planner: mode-aware storage formula (HYBRID vs CACH…
tiankongdeguiji May 12, 2026
ead7df2
[feat] dynamicemb planner: mode-aware perf model (CACHING bw >= HYBRID)
tiankongdeguiji May 12, 2026
c88c4f4
[feat] dynamicemb planner: enumerate { HYBRID, CACHING } × factors
tiankongdeguiji May 12, 2026
6ce464b
[feat] dynamicemb planner: 2D DP over HBM × DDR
tiankongdeguiji May 12, 2026
7187c2a
[test] dynamicemb planner: end-to-end enumerate + DP integration test
tiankongdeguiji May 12, 2026
c3bff19
[fix] dynamicemb planner: refine storage docstring + simplify perf wr…
tiankongdeguiji May 13, 2026
38bd7ad
[fix] dynamicemb planner: per-option DDR prune against per-machine cap
tiankongdeguiji May 13, 2026
70b0fcb
[refactor] fold dynamicemb e2e tests + rename DP proposer test + bump…
tiankongdeguiji May 13, 2026
c28e2b2
[fix] dynamicemb planner: empirically fitted CACHING vs HYBRID perf c…
tiankongdeguiji May 13, 2026
5280156
[fix] drop duplicate empirical-fit comment above the constants
tiankongdeguiji May 13, 2026
625e1b7
[fix] dynamicemb planner: drop redundant deepcopy in variant emission
tiankongdeguiji May 13, 2026
3ff8be6
[fix] dynamicemb planner: cap 2D DP per-axis bins to avoid multi-host…
tiankongdeguiji May 13, 2026
984d0df
[doc] dynamicemb planner: tighten three docstrings to match implement…
tiankongdeguiji May 13, 2026
08563bb
[test] dynamicemb planner: multi-table DP + legacy alias + empty sear…
tiankongdeguiji May 13, 2026
0c0d3d7
[fix] dynamicemb planner: restore cache_params in finally + direct wr…
tiankongdeguiji May 13, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 149 additions & 12 deletions tzrec/utils/dynamicemb_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import dataclasses
import math
import os
from typing import List, Optional, Tuple, Type, cast
from typing import Any, List, Optional, Tuple, Type, cast

import torch
from torch import nn
Expand All @@ -21,7 +22,10 @@
planners,
shard_estimators,
)
from torchrec.distributed.planner.estimator.types import HardwarePerfConfig
from torchrec.distributed.planner.estimator.types import (
HardwarePerfConfig,
ShardPerfContext,
)
from torchrec.distributed.planner.types import (
ParameterConstraints,
ShardingOption,
Expand All @@ -44,6 +48,54 @@

from tzrec.protos import feature_pb2

# Empirically fitted parameters of the dynamicemb perf model, consumed by
# ``_dynamicemb_effective_cache_ratio`` when cache_load_factor < 1.0.
# Base effective HBM-hit ratio in CACHING mode (``caching=True``).
_DYNAMICEMB_CACHING_X_EFF_BASE = 0.28
# Base effective HBM-hit ratio in HYBRID mode (``caching=False``).
_DYNAMICEMB_HYBRID_X_EFF_BASE = 0.11
# Slope of the small term proportional to cache_load_factor that is added to
# the base, so options that differ only in ratio do not tie in the planner.
_DYNAMICEMB_X_EFF_TIEBREAK = 0.01


def _dynamicemb_effective_cache_ratio(
    cache_load_factor: Optional[float],
    caching: bool,
    stats: Optional[Any] = None,
) -> float:
    """Effective HBM-hit ratio for the dynamicemb perf model.

    Returns the value passed to torchrec's perf bandwidth formula
    ``bw = x_eff*hbm + (1-x_eff)*hbm_to_ddr_bw``. Larger value = faster path.

    The default ratio is fitted to an on-device perf sweep rather than
    derived analytically. Empirical pattern (alpha=1.05 pow-law access on
    A10):

    * ``x == 1.0``: the runtime *switches kernels* -- when
      ``total_value_memory <= local_hbm_for_values`` the dual-tier
      ``HybridStorage`` / ``DynamicEmbCache`` paths are dropped in favor
      of the HBM-only ``DynamicEmbStorage`` kernel
      (``batched_dynamicemb_tables.py:502-604``). The resulting step in
      ``x_eff`` between ``x=0.9`` and ``x=1.0`` (~8x for HYBRID, ~3.5x for
      CACHING) is intentional and matches measured latency, not a
      smoothing artifact. (A future refactor could lift this to a discrete
      ``mode={HBM_ONLY, HYBRID, CACHING}`` axis on the enumerator side
      rather than packing the discontinuity into ``x``.)
    * ``caching=True``, ``x < 1.0``: 3.3x slower than HBM-only -> base 0.28.
    * ``caching=False``, ``x < 1.0``: 6.8x slower than HBM-only -> base 0.11.

    Within each ``x < 1.0`` block the perf is roughly flat in ratio, but we
    add a tiny monotonic perturbation so the DP can break ties.

    Args:
        cache_load_factor: requested HBM share ``x``. ``None`` and NaN are
            treated as ``0.0``; the value is then clamped to ``[0, 1]``.
        caching: True for CACHING mode, False for HYBRID mode.
        stats: optional object exposing ``expected_miss_rate(x)``. When
            given, ``1 - stats.expected_miss_rate(x)`` (clamped to
            ``[0, 1]``) overrides the fitted value verbatim; the caller
            opts in to their own measurement. A NaN miss rate is ignored
            and the fitted value is used instead.

    Returns:
        float: the effective ratio ``x_eff``.
    """
    x = 0.0 if cache_load_factor is None else float(cache_load_factor)
    if math.isnan(x):
        # min()/max() clamping does not sanitize NaN: min(1.0, nan) is 1.0,
        # which would silently select the fastest (HBM-only) path. Treat an
        # undefined ratio like a missing one.
        x = 0.0
    x = max(0.0, min(1.0, x))
    if stats is not None:
        miss_rate = float(stats.expected_miss_rate(x))
        if not math.isnan(miss_rate):
            return max(0.0, min(1.0, 1.0 - miss_rate))
        # An undefined miss rate carries no information; fall through to the
        # fitted model rather than reporting a perfect hit ratio.
    if x >= 1.0:
        # HBM-only kernel: no host tier is involved.
        return 1.0
    base = (
        _DYNAMICEMB_CACHING_X_EFF_BASE if caching else _DYNAMICEMB_HYBRID_X_EFF_BASE
    )
    return base + _DYNAMICEMB_X_EFF_TIEBREAK * x
Comment on lines +93 to +96
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sharp discontinuity at x=1.0. HYBRID@x=0.99 → 0.1199; HYBRID@x=1.00 → 1.0 — an ~8× jump for a 1% ratio change. The DP will reliably prefer x=1.0 over x=0.99 by a huge perf margin, then prefer CACHING@x=0.5 (0.285) over HYBRID@x=0.9 (0.119) even on workloads where HYBRID@0.9 is plainly faster in reality. If the empirical sweep really shows a step at x=1.0 because the runtime drops the host tier, please call that out as "x=1.0 = HBM-only kernel, not the same algorithm as x=0.99" — and consider whether the enumerator should emit a discrete mode={HBM_ONLY, HYBRID, CACHING} axis rather than packing the discontinuity into the same x knob.



has_dynamicemb = False
try:
import dynamicemb
Expand Down Expand Up @@ -258,26 +310,44 @@ def _calculate_dynamicemb_table_storage_specific_size(
is_hbm: bool = True,
only_values: bool = False,
bucket_capacity: int = 128,
caching: bool = False,
) -> int:
"""Calculate dynamic embedding table storage.

total_value_memory = max_capacity x aligned16(embedding+optimizer states)
num_buckets = max_capacity/bucket_capacity
hbm_budget = min(global_hbm_for_values//world_size, total_value_memory) +
max_capacity x (key<8byte> + score<8byte> + digest<1byte>) +
num_buckets x (bucket_size<4byte>)
ddr_budget = max(total_value_memory - global_hbm_for_values//world_size, 0)
"""Per-shard storage size for a dynamicemb table -- HBM or DDR (bytes).

Byte budget (single shard, rows x dim):

value_bytes_per_row = round_up16(dim * (1 + opt_mult) * element)
total_value_memory = align(rows) * value_bytes_per_row
num_buckets = align(rows) / bucket_capacity

hbm_budget = cache_ratio * total_value_memory # values
+ align(rows) * (key<8B> + score<8B> + digest<1B>) # per-row
+ num_buckets * bucket_header<4B> # per-bucket

ddr_budget = HYBRID (caching=False): (1 - cache_ratio) * total_value_memory
CACHING (caching=True): total_value_memory # full backing

HYBRID hash-partitions values across HBM and host; ``cache_ratio`` is
HBM's value share. CACHING keeps the full backing store on host and
uses HBM as a hot-row cache of size
``cache_ratio * total_value_memory``. Hash-table metadata
(key + score + digest + bucket header) is accounted on HBM only --
matches the existing tzrec convention.
"""
if cache_ratio is None:
cache_ratio = 1.0
if is_hbm:
value_ratio = cache_ratio
else:
value_ratio = 1.0 if caching else (1.0 - cache_ratio)
return math.ceil(
align_to_table_size(size[0])
* (
_round_up(
math.ceil(size[1] * (1 + optimizer_multipler) * element_size),
16,
)
* (cache_ratio if is_hbm else 1 - cache_ratio)
* value_ratio
+ (8 + 8 + 1 + 4 / bucket_capacity) * (is_hbm and not only_values)
)
)
Expand Down Expand Up @@ -413,13 +483,74 @@ def _customized_kernel_aware_get_device_bw(
# pyre-ignore [9]
HardwarePerfConfig.get_device_bw = _customized_kernel_aware_get_device_bw

# Capture torchrec's unpatched implementation before the classmethod is
# replaced. ``__func__`` unwraps the bound classmethod to its underlying
# function so the wrapper can call it with an explicit ``cls``.
_orig_build_shard_perf_contexts = (
ShardPerfContext.build_shard_perf_contexts.__func__
)

def _dynamicemb_aware_build_shard_perf_contexts(
    cls,  # pyre-ignore [2]
    config,  # pyre-ignore [2]
    shard_sizes,  # pyre-ignore [2]
    sharding_option,  # pyre-ignore [2]
    topology,  # pyre-ignore [2]
    constraints,  # pyre-ignore [2]
    sharder,  # pyre-ignore [2]
    *args,  # pyre-ignore [2]
    **kwargs,  # pyre-ignore [2]
):
    """Build shard perf contexts with the dynamicemb effective cache ratio.

    For a sharding option that carries ``dynamicemb_options``, the option's
    ``cache_params`` is swapped for a copy whose ``load_factor`` is the value
    returned by ``_dynamicemb_effective_cache_ratio`` for the option's mode
    and cache_load_factor. The original torchrec implementation is then
    invoked with all arguments passed through unchanged.

    The swap is temporary: the original ``cache_params`` object is put back
    before this function returns or raises, so the separately invoked
    storage estimator keeps seeing the un-boosted ratio.
    """
    options = getattr(sharding_option, "dynamicemb_options", None)
    saved_cache_params = sharding_option.cache_params
    if options is not None:
        is_caching = bool(getattr(options, "caching", False))
        cache_stats = saved_cache_params.stats if saved_cache_params else None
        effective_ratio = _dynamicemb_effective_cache_ratio(
            sharding_option.cache_load_factor,
            caching=is_caching,
            stats=cache_stats,
        )
        if saved_cache_params is None:
            sharding_option.cache_params = CacheParams(load_factor=effective_ratio)
        else:
            sharding_option.cache_params = dataclasses.replace(
                saved_cache_params, load_factor=effective_ratio
            )
    try:
        return _orig_build_shard_perf_contexts(
            cls,
            config,
            shard_sizes,
            sharding_option,
            topology,
            constraints,
            sharder,
            *args,
            **kwargs,
        )
    finally:
        # Runs on both success and failure, so the boosted clone never
        # outlives this call on the shared ShardingOption.
        sharding_option.cache_params = saved_cache_params

# Install the dynamicemb-aware wrapper in place of torchrec's classmethod;
# re-wrapping with ``classmethod`` keeps ``cls`` bound at call time.
# pyre-ignore [9]
ShardPerfContext.build_shard_perf_contexts = classmethod(
_dynamicemb_aware_build_shard_perf_contexts
)

def _calculate_dynamicemb_storage_specific_sizes(
tensor: torch.Tensor,
shard_sizes: List[List[int]],
optimizer_class: Optional[Type[torch.optim.Optimizer]] = None,
cache_ratio: float = 1.0,
is_inference: bool = False,
bucket_capacity: int = 128,
caching: bool = False,
) -> Tuple[List[int], List[int]]:
"""Calculate storage for dynamicemb."""
optimizer_multipler = 0.0
Expand All @@ -437,6 +568,7 @@ def _calculate_dynamicemb_storage_specific_sizes(
cache_ratio,
is_hbm=True,
bucket_capacity=bucket_capacity,
caching=caching,
)
for size in shard_sizes
]
Expand All @@ -449,6 +581,7 @@ def _calculate_dynamicemb_storage_specific_sizes(
cache_ratio,
is_hbm=False,
bucket_capacity=bucket_capacity,
caching=caching,
)
for size in shard_sizes
]
Expand Down Expand Up @@ -496,7 +629,10 @@ def dynamicemb_calculate_shard_storages(
factors.
num_poolings (List[float]): average number of poolings per sample
(typically 1.0).
caching_ratio (float): ratio of HBM to DDR memory for UVM caching.
caching_ratio (float): cache_load_factor for the dynamicemb table.
In HYBRID mode HBM holds this fraction of values and host
holds the remainder; in CACHING mode HBM is a hot-row cache
of this fraction and host holds the full backing store.
is_pooled (bool): True if embedding output is pooled (ie. `EmbeddingBag`),
False if unpooled/sequential (ie. `Embedding`).
input_data_type_size (int): number of bytes of input data type.
Expand Down Expand Up @@ -535,6 +671,7 @@ def dynamicemb_calculate_shard_storages(
cache_ratio=caching_ratio if caching_ratio else 1.0,
is_inference=is_inference,
bucket_capacity=dynamicemb_options.bucket_capacity,
caching=bool(getattr(dynamicemb_options, "caching", False)),
)
)
counter_hbm_specific_size = 0
Expand Down
Loading
Loading