From ffc03254ff406d0dc21851cf264dd54dec89e467 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 12 May 2026 20:28:44 +0800 Subject: [PATCH 01/15] [feat] dynamicemb planner: mode-aware storage formula (HYBRID vs CACHING) Extend `_calculate_dynamicemb_table_storage_specific_size` and its callers with a `caching: bool` parameter. HYBRID keeps the original split formula (DDR = (1 - cache_ratio) * T); CACHING accounts the full backing store on host (DDR = T) regardless of cache_ratio. HBM accounting and metadata (key + score + digest + bucket header) stay identical between modes. The flag is sourced from `dynamicemb_options.caching` on the ShardingOption in `dynamicemb_calculate_shard_storages`. Subsequent commits wire this into the enumerator (per-table sweep over both modes) and the proposer (2D DP). Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/dynamicemb_util.py | 30 ++++++-- tzrec/utils/dynamicemb_util_test.py | 110 ++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 7 deletions(-) create mode 100644 tzrec/utils/dynamicemb_util_test.py diff --git a/tzrec/utils/dynamicemb_util.py b/tzrec/utils/dynamicemb_util.py index 09e9bf2b..8a3c398c 100644 --- a/tzrec/utils/dynamicemb_util.py +++ b/tzrec/utils/dynamicemb_util.py @@ -258,18 +258,30 @@ def _calculate_dynamicemb_table_storage_specific_size( is_hbm: bool = True, only_values: bool = False, bucket_capacity: int = 128, + caching: bool = False, ) -> int: """Calculate dynamic embedding table storage. 
- total_value_memory = max_capacity x aligned16(embedding+optimizer states) - num_buckets = max_capacity/bucket_capacity - hbm_budget = min(global_hbm_for_values//world_size, total_value_memory) + - max_capacity x (key<8byte> + score<8byte> + digest<1byte>) + - num_buckets x (bucket_size<4byte>) - ddr_budget = max(total_value_memory - global_hbm_for_values//world_size, 0) + HYBRID mode (``caching=False``): values are hash-partitioned across HBM + and host; HBM holds ``cache_ratio`` of values, host holds + ``1 - cache_ratio``. + + CACHING mode (``caching=True``): host holds the full backing store; HBM + holds a ``cache_ratio`` fraction as a cache. + + hbm = align(rows) * (round_up16(dim*element) * cache_ratio + metadata) + ddr = align(rows) * round_up16(dim*element) * value_ratio_ddr + + where ``value_ratio_ddr`` is ``1 - cache_ratio`` for HYBRID and ``1.0`` + for CACHING. Metadata (key + score + digest + bucket) is accounted on + HBM only, matching the existing tzrec convention. """ if cache_ratio is None: cache_ratio = 1.0 + if is_hbm: + value_ratio = cache_ratio + else: + value_ratio = 1.0 if caching else (1.0 - cache_ratio) return math.ceil( align_to_table_size(size[0]) * ( @@ -277,7 +289,7 @@ def _calculate_dynamicemb_table_storage_specific_size( math.ceil(size[1] * (1 + optimizer_multipler) * element_size), 16, ) - * (cache_ratio if is_hbm else 1 - cache_ratio) + * value_ratio + (8 + 8 + 1 + 4 / bucket_capacity) * (is_hbm and not only_values) ) ) @@ -420,6 +432,7 @@ def _calculate_dynamicemb_storage_specific_sizes( cache_ratio: float = 1.0, is_inference: bool = False, bucket_capacity: int = 128, + caching: bool = False, ) -> Tuple[List[int], List[int]]: """Calculate storage for dynamicemb.""" optimizer_multipler = 0.0 @@ -437,6 +450,7 @@ def _calculate_dynamicemb_storage_specific_sizes( cache_ratio, is_hbm=True, bucket_capacity=bucket_capacity, + caching=caching, ) for size in shard_sizes ] @@ -449,6 +463,7 @@ def _calculate_dynamicemb_storage_specific_sizes( 
cache_ratio, is_hbm=False, bucket_capacity=bucket_capacity, + caching=caching, ) for size in shard_sizes ] @@ -535,6 +550,7 @@ def dynamicemb_calculate_shard_storages( cache_ratio=caching_ratio if caching_ratio else 1.0, is_inference=is_inference, bucket_capacity=dynamicemb_options.bucket_capacity, + caching=bool(getattr(dynamicemb_options, "caching", False)), ) ) counter_hbm_specific_size = 0 diff --git a/tzrec/utils/dynamicemb_util_test.py b/tzrec/utils/dynamicemb_util_test.py new file mode 100644 index 00000000..e18818f8 --- /dev/null +++ b/tzrec/utils/dynamicemb_util_test.py @@ -0,0 +1,110 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from parameterized import parameterized + +from tzrec.utils import dynamicemb_util + + +@unittest.skipUnless( + dynamicemb_util.has_dynamicemb, "dynamicemb is not installed; skipping." 
+) +class StorageFormulaTest(unittest.TestCase): + """Mode-aware ``_calculate_dynamicemb_table_storage_specific_size``.""" + + ROWS = 1024 + DIM = 64 + ELEMENT_SIZE = 4 + BUCKET_CAPACITY = 128 + + def _calc(self, *, cache_ratio, is_hbm, caching, only_values=False): + return dynamicemb_util._calculate_dynamicemb_table_storage_specific_size( + size=[self.ROWS, self.DIM], + element_size=self.ELEMENT_SIZE, + cache_ratio=cache_ratio, + is_hbm=is_hbm, + only_values=only_values, + bucket_capacity=self.BUCKET_CAPACITY, + caching=caching, + ) + + @parameterized.expand( + [ + ("ratio_0_0", 0.0), + ("ratio_0_25", 0.25), + ("ratio_0_5", 0.5), + ("ratio_0_75", 0.75), + ("ratio_1_0", 1.0), + ] + ) + def test_hbm_identical_between_modes(self, _name, cache_ratio): + # HBM accounting is the same in HYBRID and CACHING: HBM holds a + # cache_ratio fraction of values plus full-row-count metadata. + hybrid_hbm = self._calc(cache_ratio=cache_ratio, is_hbm=True, caching=False) + caching_hbm = self._calc(cache_ratio=cache_ratio, is_hbm=True, caching=True) + self.assertEqual(hybrid_hbm, caching_hbm) + + @parameterized.expand( + [ + ("ratio_0_0", 0.0), + ("ratio_0_25", 0.25), + ("ratio_0_5", 0.5), + ("ratio_0_75", 0.75), + ("ratio_1_0", 1.0), + ] + ) + def test_ddr_hybrid_complements_cache(self, _name, cache_ratio): + # HYBRID DDR = (1 - cache_ratio) * full-table DDR. + full_ddr = self._calc(cache_ratio=0.0, is_hbm=False, caching=False) + hybrid_ddr = self._calc(cache_ratio=cache_ratio, is_hbm=False, caching=False) + self.assertEqual(hybrid_ddr, round((1.0 - cache_ratio) * full_ddr)) + + @parameterized.expand( + [ + ("ratio_0_0", 0.0), + ("ratio_0_25", 0.25), + ("ratio_0_5", 0.5), + ("ratio_0_75", 0.75), + ("ratio_1_0", 1.0), + ] + ) + def test_ddr_caching_holds_full_table(self, _name, cache_ratio): + # CACHING DDR is the full backing store, independent of cache_ratio. 
+ full_ddr = self._calc(cache_ratio=0.0, is_hbm=False, caching=False) + caching_ddr = self._calc(cache_ratio=cache_ratio, is_hbm=False, caching=True) + self.assertEqual(caching_ddr, full_ddr) + + def test_caching_ddr_strictly_greater_than_hybrid_when_cached(self): + for cache_ratio in (0.1, 0.5, 0.9): + hybrid_ddr = self._calc( + cache_ratio=cache_ratio, is_hbm=False, caching=False + ) + caching_ddr = self._calc( + cache_ratio=cache_ratio, is_hbm=False, caching=True + ) + self.assertGreater(caching_ddr, hybrid_ddr) + + def test_only_values_drops_metadata(self): + # only_values=True strips HBM metadata regardless of mode. + for caching in (False, True): + with_meta = self._calc( + cache_ratio=0.5, is_hbm=True, caching=caching, only_values=False + ) + without_meta = self._calc( + cache_ratio=0.5, is_hbm=True, caching=caching, only_values=True + ) + self.assertGreater(with_meta, without_meta) + + +if __name__ == "__main__": + unittest.main() From ead7df2c269d17e22f6eac80dd327a3f4e24bf7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 12 May 2026 20:33:21 +0800 Subject: [PATCH 02/15] [feat] dynamicemb planner: mode-aware perf model (CACHING bw >= HYBRID) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add `_dynamicemb_effective_cache_ratio(cache_load_factor, caching, stats)`: HYBRID passes the ratio through (hash-partitioned HBM hit probability); CACHING boosts to `min(1.0, x * multiplier)`, default multiplier 10.0 via the `TZREC_CACHING_HIT_RATE_MULTIPLIER` env var. `cache_params.stats`, when provided, takes precedence — `1 - stats.expected_miss_rate(x)` — and is clamped to never fall below the HYBRID ratio. Inject the boost at planning time by monkey-patching `ShardPerfContext.build_shard_perf_contexts` to temporarily replace the sharding option's `cache_params.load_factor` with the effective ratio. 
The original cache_params is restored on return so the storage estimator (which reads cache_load_factor for HBM/DDR sizing) still sees the unboosted value. Net effect: at the same `cache_load_factor`, CACHING is never more expensive in perf (strictly cheaper wherever the effective hit ratio exceeds `cache_load_factor`; the two coincide once the ratio is already 1.0) and never cheaper in DDR (strictly more expensive for any non-zero factor). The DP proposer (next commit) trades one for the other against the topology budget. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/dynamicemb_util.py | 116 +++++++++++++++++++++++++++- tzrec/utils/dynamicemb_util_test.py | 80 +++++++++++++++++++ 2 files changed, 194 insertions(+), 2 deletions(-) diff --git a/tzrec/utils/dynamicemb_util.py b/tzrec/utils/dynamicemb_util.py index 8a3c398c..3ba19d75 100644 --- a/tzrec/utils/dynamicemb_util.py +++ b/tzrec/utils/dynamicemb_util.py @@ -9,9 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import math import os -from typing import List, Optional, Tuple, Type, cast +from typing import Any, List, Optional, Tuple, Type, cast import torch from torch import nn @@ -21,7 +22,10 @@ planners, shard_estimators, ) -from torchrec.distributed.planner.estimator.types import HardwarePerfConfig +from torchrec.distributed.planner.estimator.types import ( + HardwarePerfConfig, + ShardPerfContext, +) from torchrec.distributed.planner.types import ( ParameterConstraints, ShardingOption, @@ -44,6 +48,46 @@ from tzrec.protos import feature_pb2 +DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_ENV = "TZREC_CACHING_HIT_RATE_MULTIPLIER" +DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_DEFAULT = 10.0 + + +def _dynamicemb_effective_cache_ratio( + cache_load_factor: Optional[float], + caching: bool, + stats: Optional[Any] = None, +) -> float: + """Effective HBM-hit ratio for the dynamicemb perf model. + + HYBRID (``caching=False``) is unchanged — every access has probability + ``cache_load_factor`` of hitting HBM since the HBM tier is hash-partitioned. 
+ + CACHING (``caching=True``) pins hot rows in HBM regardless of hash; we + model that as a boosted hit rate. With the default multiplier (10.0) any + cache_load_factor ≥ 0.1 saturates to a full HBM hit. ``stats``, when + provided, overrides the multiplier-based estimate via + ``1 - stats.expected_miss_rate(cache_load_factor)``. + + Invariant: ``CACHING_bw(x) >= HYBRID_bw(x)`` at every ``x`` because the + returned ratio is monotonically ≥ ``cache_load_factor``. + """ + x = float(cache_load_factor) if cache_load_factor is not None else 0.0 + x = max(0.0, min(1.0, x)) + if not caching: + return x + if stats is not None: + miss_rate = float(stats.expected_miss_rate(x)) + return max(x, max(0.0, min(1.0, 1.0 - miss_rate))) + multiplier = float( + os.environ.get( + DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_ENV, + DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_DEFAULT, + ) + ) + multiplier = max(1.0, multiplier) + return min(1.0, x * multiplier) + + has_dynamicemb = False try: import dynamicemb @@ -425,6 +469,74 @@ def _customized_kernel_aware_get_device_bw( # pyre-ignore [9] HardwarePerfConfig.get_device_bw = _customized_kernel_aware_get_device_bw + _orig_build_shard_perf_contexts = ( + ShardPerfContext.build_shard_perf_contexts.__func__ + ) + + def _dynamicemb_aware_build_shard_perf_contexts( + cls, # pyre-ignore [2] + config, # pyre-ignore [2] + shard_sizes, # pyre-ignore [2] + sharding_option, # pyre-ignore [2] + topology, # pyre-ignore [2] + constraints, # pyre-ignore [2] + sharder, # pyre-ignore [2] + *args, # pyre-ignore [2] + **kwargs, # pyre-ignore [2] + ): + """Inject CACHING-mode hit-rate boost into the perf estimator. + + For dynamicemb tables in CACHING mode, we temporarily replace + ``sharding_option.cache_params`` with a clone whose ``load_factor`` is + boosted to the effective HBM-hit ratio. The original cache_params is + restored before the patched method returns, so the (separately invoked) + storage estimator continues to see the un-boosted ratio. 
+ """ + dynamicemb_options = getattr(sharding_option, "dynamicemb_options", None) + caching = bool(getattr(dynamicemb_options, "caching", False)) + if not caching: + return _orig_build_shard_perf_contexts( + cls, + config, + shard_sizes, + sharding_option, + topology, + constraints, + sharder, + *args, + **kwargs, + ) + + original_cache_params = sharding_option.cache_params + stats = original_cache_params.stats if original_cache_params else None + x_eff = _dynamicemb_effective_cache_ratio( + sharding_option.cache_load_factor, caching=True, stats=stats + ) + if original_cache_params is not None: + boosted = dataclasses.replace(original_cache_params, load_factor=x_eff) + else: + boosted = CacheParams(load_factor=x_eff) + sharding_option.cache_params = boosted + try: + return _orig_build_shard_perf_contexts( + cls, + config, + shard_sizes, + sharding_option, + topology, + constraints, + sharder, + *args, + **kwargs, + ) + finally: + sharding_option.cache_params = original_cache_params + + # pyre-ignore [9] + ShardPerfContext.build_shard_perf_contexts = classmethod( + _dynamicemb_aware_build_shard_perf_contexts + ) + def _calculate_dynamicemb_storage_specific_sizes( tensor: torch.Tensor, shard_sizes: List[List[int]], diff --git a/tzrec/utils/dynamicemb_util_test.py b/tzrec/utils/dynamicemb_util_test.py index e18818f8..4b02ef21 100644 --- a/tzrec/utils/dynamicemb_util_test.py +++ b/tzrec/utils/dynamicemb_util_test.py @@ -9,7 +9,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import unittest +from unittest import mock from parameterized import parameterized @@ -106,5 +108,83 @@ def test_only_values_drops_metadata(self): self.assertGreater(with_meta, without_meta) +class EffectiveCacheRatioTest(unittest.TestCase): + """``_dynamicemb_effective_cache_ratio`` — HYBRID vs CACHING perf ratio.""" + + def test_hybrid_passes_through(self): + for x in (0.0, 0.1, 0.5, 1.0): + self.assertEqual( + dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=False), x + ) + + def test_caching_default_multiplier_saturates(self): + # Default multiplier is 10.0 — any x >= 0.1 saturates to 1.0. + env = { + k: v + for k, v in os.environ.items() + if k != dynamicemb_util.DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_ENV + } + with mock.patch.dict(os.environ, env, clear=True): + self.assertEqual( + dynamicemb_util._dynamicemb_effective_cache_ratio(0.1, caching=True), + 1.0, + ) + self.assertEqual( + dynamicemb_util._dynamicemb_effective_cache_ratio(0.5, caching=True), + 1.0, + ) + self.assertEqual( + dynamicemb_util._dynamicemb_effective_cache_ratio(0.01, caching=True), + 0.1, + ) + + def test_caching_env_override(self): + with mock.patch.dict( + os.environ, + {dynamicemb_util.DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_ENV: "2.0"}, + ): + # multiplier=2: x=0.3 -> 0.6, x=0.5 -> 1.0 (clamped) + self.assertAlmostEqual( + dynamicemb_util._dynamicemb_effective_cache_ratio(0.3, caching=True), + 0.6, + ) + self.assertEqual( + dynamicemb_util._dynamicemb_effective_cache_ratio(0.5, caching=True), + 1.0, + ) + + def test_caching_invariant_monotonic_ge_hybrid(self): + # CACHING ratio >= HYBRID ratio at every cache_load_factor. 
+ for x in (0.0, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0): + hybrid = dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=False) + caching = dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=True) + self.assertGreaterEqual(caching, hybrid) + + def test_stats_override_uses_expected_miss_rate(self): + class _Stats: + expected_lookups = 1000.0 + + def expected_miss_rate(self, ratio): + return 0.05 # 95% hit rate regardless of ratio + + x_eff = dynamicemb_util._dynamicemb_effective_cache_ratio( + 0.2, caching=True, stats=_Stats() + ) + self.assertAlmostEqual(x_eff, 0.95) + + def test_stats_override_never_drops_below_hybrid(self): + class _Stats: + expected_lookups = 1000.0 + + def expected_miss_rate(self, ratio): + return 0.9 # would give x_eff=0.1, below cache_load_factor=0.5 + + x_eff = dynamicemb_util._dynamicemb_effective_cache_ratio( + 0.5, caching=True, stats=_Stats() + ) + # Invariant: CACHING_bw never falls below HYBRID_bw at same ratio. + self.assertGreaterEqual(x_eff, 0.5) + + if __name__ == "__main__": unittest.main() From c88c4f43fbbe17517725c7e5317cc24e3e0ce23a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 12 May 2026 20:36:17 +0800 Subject: [PATCH 03/15] =?UTF-8?q?[feat]=20dynamicemb=20planner:=20enumerat?= =?UTF-8?q?e=20{=20HYBRID,=20CACHING=20}=20=C3=97=20factors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract the per-dynamicemb-table variant emission into `_emit_dynamicemb_variants(base, dynamicemb_options) -> List[ShardingOption]` and sweep both modes per cache_load_factor. When `cache_params` is unset, the helper emits 20 variants (10 factors × 2 modes); when fixed by the caller, it emits 2 (both modes at the fixed factor). Each variant owns a deep-copied `dynamicemb_options` so per-variant `caching` mutations stay isolated. `cache_params.stats` is preserved across all clones for the perf-side miss-rate override. 
The downstream 2D DP proposer (next commit) selects per table the (mode, ratio) pair that fits both HBM and host budgets while minimizing perf. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/plan_util.py | 52 +++++++++++++++++++-------- tzrec/utils/plan_util_test.py | 68 ++++++++++++++++++++++++++++++++++- 2 files changed, 105 insertions(+), 15 deletions(-) diff --git a/tzrec/utils/plan_util.py b/tzrec/utils/plan_util.py index bf5c69e0..e9705b0b 100644 --- a/tzrec/utils/plan_util.py +++ b/tzrec/utils/plan_util.py @@ -755,6 +755,39 @@ def calculate_shard_storages( ) +def _emit_dynamicemb_variants( + base_option: ShardingOption, + dynamicemb_options: Any, +) -> List[ShardingOption]: + """Expand a dynamicemb ShardingOption into HYBRID + CACHING variants. + + Sweeps both placement modes (``caching=False`` and ``caching=True``) and, + when ``base_option.cache_params`` is unset, ten cache_load_factor values + (0.1, 0.2, ..., 1.0). The downstream 2D DP proposer picks per table the + best (mode, ratio) that fits both HBM and host topology budgets. + + Each returned ShardingOption owns a freshly deep-copied + ``dynamicemb_options`` instance so per-option ``caching`` mutations do not + bleed across variants. + """ + if base_option.cache_params is None: + load_factors = [(i + 1) / 10 for i in range(10)] + stats = None + else: + load_factors = [base_option.cache_params.load_factor] + stats = base_option.cache_params.stats + variants: List[ShardingOption] = [] + for caching_mode in (False, True): + for load_factor in load_factors: + opt = copy.deepcopy(base_option) + opt.cache_params = CacheParams(load_factor=load_factor, stats=stats) + # pyre-ignore [16] + opt.dynamicemb_options = copy.deepcopy(dynamicemb_options) + opt.dynamicemb_options.caching = caching_mode + variants.append(opt) + return variants + + class EmbeddingEnumerator(_EmbeddingEnumerator): """Generates embedding sharding options for given `nn.Module` with constraints. 
@@ -934,20 +967,11 @@ def enumerate( sharding_option.use_dynamicemb = use_dynamicemb # pyre-ignore [16] sharding_option.dynamicemb_options = dynamicemb_options - if sharding_option.cache_params is None: - # add cache_load_factor automatic search space - for load_factor_step in range(10): - sharding_option_copy = copy.deepcopy( - sharding_option - ) - sharding_option_copy.cache_params = CacheParams( - load_factor=(load_factor_step + 1) / 10 - ) - sharding_options_per_table.append( - sharding_option_copy - ) - else: - sharding_options_per_table.append(sharding_option) + sharding_options_per_table.extend( + _emit_dynamicemb_variants( + sharding_option, dynamicemb_options + ) + ) else: sharding_options_per_table.append(sharding_option) diff --git a/tzrec/utils/plan_util_test.py b/tzrec/utils/plan_util_test.py index 868b09df..17201dc2 100644 --- a/tzrec/utils/plan_util_test.py +++ b/tzrec/utils/plan_util_test.py @@ -10,6 +10,7 @@ # limitations under the License. import unittest +from types import SimpleNamespace import torch from torchrec.distributed.model_parallel import get_default_sharders @@ -18,9 +19,10 @@ from torchrec.distributed.planner.proposers import GridSearchProposer from torchrec.distributed.planner.types import PlannerError, Topology from torchrec.distributed.test_utils.test_model import TestSparseNN +from torchrec.distributed.types import CacheParams from torchrec.modules.embedding_configs import EmbeddingBagConfig -from tzrec.utils.plan_util import DynamicProgrammingProposer +from tzrec.utils.plan_util import DynamicProgrammingProposer, _emit_dynamicemb_variants class PlanUtilTest(unittest.TestCase): @@ -136,5 +138,69 @@ def test_dp_proposer_with_prune(self) -> None: ) +class EmitDynamicEmbVariantsTest(unittest.TestCase): + """``_emit_dynamicemb_variants`` produces { HYBRID, CACHING } × factors.""" + + def _make_base(self, cache_params=None): + # _emit_dynamicemb_variants only touches cache_params and + # dynamicemb_options on the option, so a 
SimpleNamespace stand-in is + # sufficient and avoids the heavy ShardingOption constructor surface. + return SimpleNamespace( + cache_params=cache_params, + dynamicemb_options=SimpleNamespace(caching=False, bucket_capacity=128), + ) + + def _make_dynamicemb_options(self): + return SimpleNamespace(caching=False, bucket_capacity=128) + + def test_unfixed_factor_emits_twenty_variants(self): + variants = _emit_dynamicemb_variants( + self._make_base(cache_params=None), self._make_dynamicemb_options() + ) + self.assertEqual(len(variants), 20) + cache_modes = sorted({v.dynamicemb_options.caching for v in variants}) + self.assertEqual(cache_modes, [False, True]) + for caching_mode in (False, True): + factors = sorted( + v.cache_params.load_factor + for v in variants + if v.dynamicemb_options.caching is caching_mode + ) + self.assertEqual(factors, [round((i + 1) / 10, 4) for i in range(10)]) + + def test_fixed_factor_emits_two_variants(self): + variants = _emit_dynamicemb_variants( + self._make_base(cache_params=CacheParams(load_factor=0.3)), + self._make_dynamicemb_options(), + ) + self.assertEqual(len(variants), 2) + cache_modes = sorted(v.dynamicemb_options.caching for v in variants) + self.assertEqual(cache_modes, [False, True]) + for v in variants: + self.assertEqual(v.cache_params.load_factor, 0.3) + + def test_variants_own_dynamicemb_options(self): + # Per-variant mutation of caching must not bleed across variants. 
+ opts = self._make_dynamicemb_options() + variants = _emit_dynamicemb_variants(self._make_base(), opts) + for v in variants: + self.assertIsNot(v.dynamicemb_options, opts) + + def test_stats_preserved_on_clone(self): + class _Stats: + expected_lookups = 100.0 + + def expected_miss_rate(self, ratio): + return 0.1 + + stats = _Stats() + variants = _emit_dynamicemb_variants( + self._make_base(cache_params=CacheParams(load_factor=0.2, stats=stats)), + self._make_dynamicemb_options(), + ) + for v in variants: + self.assertIs(v.cache_params.stats, stats) + + if __name__ == "__main__": unittest.main() From 6ce464b05945068ba093c16927466f460173098c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 12 May 2026 20:39:48 +0800 Subject: [PATCH 04/15] =?UTF-8?q?[feat]=20dynamicemb=20planner:=202D=20DP?= =?UTF-8?q?=20over=20HBM=20=C3=97=20DDR?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rewrite `DynamicProgrammingProposer.feedback` to discretize both per-rank HBM and per-rank DDR into bins (default 100 × 50 per device) and pick per table the (mode, cache_load_factor) pair that minimizes total perf while fitting both topology budgets. State: dp[table][hbm_bin][ddr_bin] = (perf, hbm_actual, ddr_actual) backtrack[table][hbm_bin][ddr_bin] = (opt_id, prev_hbm_bin, prev_ddr_bin) Options whose largest shard exceeds either per-device cap are pruned upfront. Backtracking enumerates feasible (hbm_total, ddr_total) cells in decreasing total memory, yielding Pareto-optimal proposals. The legacy `mem_bins_per_device` kwarg is preserved as an alias for `hbm_bins_per_device` so existing callers still work, and HBM-degenerate (CPU-only) or DDR-degenerate topologies collapse the unused axis to a single bin. 
Adds `DynamicProgrammingProposer2DTest`: - generous DDR → CACHING wins (cheaper perf, host can fit T) - tight DDR → CACHING rejected, falls back to HYBRID - high cache_load_factor → modes converge Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/plan_util.py | 326 ++++++++++++++++++++-------------- tzrec/utils/plan_util_test.py | 104 +++++++++++ 2 files changed, 295 insertions(+), 135 deletions(-) diff --git a/tzrec/utils/plan_util.py b/tzrec/utils/plan_util.py index e9705b0b..e8ac4449 100644 --- a/tzrec/utils/plan_util.py +++ b/tzrec/utils/plan_util.py @@ -224,70 +224,70 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]: return sharders -class DynamicProgrammingProposer(Proposer): - r"""Proposes sharding plans in dynamic programming fashion. - - The problem of the Embedding Sharding Plan can be framed as follows: Given - :math:`M` tables and their corresponding :math:`N` Sharding Options, we need to - select one sharding option for each table such that the total performance is - minimized, while keeping the overall memory constraint :math:`K` in check. This can - be abstracted into the following mathematical formulation: - - Given a matrix :math:`A` of dimensions :math:`(M, N)` and another matrix :math:`B` - of the same dimensions, let the elements of matrix :math:`A` be denoted as - :math:`a_{i,j}` and the elements of matrix :math:`B` as :math:`b_{i,j}`. We aim - to find a set of column indices :math:`\{ j_0, j_1, \ldots, j_{M-1} \}` such that - the following conditions are satisfied: +_INF = float("inf") - 1. :math:`\sum_{i=0}^{M-1} a_{i,j_i} \leq K`, where :math:`K` is a float. - 2. :math:`\sum_{i=0}^{M-1} b_{i,j_i}` is minimized. - This problem can be tackled using dynamic programming. First, discretize :math:`K` - into :math:`K_i`, and denote the discretization function as :math:`f`. +class DynamicProgrammingProposer(Proposer): + r"""Proposes sharding plans via 2D (HBM × DDR) dynamic programming. 
- Define the state :math:`dp[i][f(k)]` to represent the minimum value of :math:`B` - when considering the first :math:`i` rows and the total sum of :math:`A` is equal to - the discretized value :math:`k`. + Given :math:`M` tables each with up to :math:`N` ShardingOptions, pick one + option per table to minimize total perf while respecting both per-rank HBM + and per-rank DDR budgets from the topology. - The state transition can then be represented as: + Each axis (HBM, DDR) is discretized into bins; ``dp[table][h][d]`` holds + the minimum perf over the first ``table`` tables using ``≈ h`` HBM bins + and ``≈ d`` DDR bins. The transition is: .. math:: - dp[i][f(k)] = \min_{j=0}^{N-1} \left( dp[i-1][f(k - A[i][j])] + B[i][j] \right) + dp[i][h][d] = \min_{j} dp[i-1][h - A_h[i][j]][d - A_d[i][j]] + + B[i][j] - Since :math:`K` is the sum allocated across all memory, simply satisfying that the - total memory in the plan equals :math:`K` does not guarantee that the allocation - will fit on all cards. Therefore, it is essential to maintain all the states of the - last layer of :math:`dp`. This allows us to propose different plans under varying - total memory constraints. + Backtracking from each feasible ``(h, d)`` at the last table yields + Pareto-optimal proposals across the joint memory budget. The host axis is + load-bearing for dynamicemb CACHING mode (where DDR = full table) vs + HYBRID (DDR = ``(1 - load_factor) · table``). Args: - mem_bins_per_device (int): memory bins for dynamic programming precision. + hbm_bins_per_device: HBM discretization bins per device. Default 100. + (Accepts legacy ``mem_bins_per_device`` as an alias for + backwards compatibility.) + ddr_bins_per_device: DDR discretization bins per device. Default 50 — + DDR budgets dominate embedding demand, so a coarser axis suffices. 
""" - def __init__(self, mem_bins_per_device: int = 100) -> None: + def __init__( + self, + hbm_bins_per_device: int = 100, + ddr_bins_per_device: int = 50, + mem_bins_per_device: Optional[int] = None, + ) -> None: self._inited: bool = False - self._mem_bins_per_device: int = max(mem_bins_per_device, 1) + # back-compat alias: ``mem_bins_per_device`` mapped to the HBM axis. + bins_h = ( + mem_bins_per_device + if mem_bins_per_device is not None + else hbm_bins_per_device + ) + self._hbm_bins_per_device: int = max(bins_h, 1) + self._ddr_bins_per_device: int = max(ddr_bins_per_device, 1) self._sharding_options_by_fqn: OrderedDict[str, List[ShardingOption]] = ( OrderedDict() ) - # list of proposals with different total_mem, a proposal is a list of - # indices of sharding_options + # list of proposals with different total_mem; each proposal is a list + # of indices into self._sharding_options_by_fqn[fqn]. self._proposal_list: List[List[int]] = [] self._current_proposal: int = -1 - self._storage_type = "hbm" - if not torch.cuda.is_available(): - self._storage_type = "ddr" def load( self, search_space: List[ShardingOption], enumerator: Optional[Enumerator] = None, ) -> None: - """Load search space.""" + """Load search space, sorted by total (hbm + ddr) ascending.""" self._reset() - # order the sharding_option by total_storage.hbm from low to high for sharding_option in sorted( - search_space, key=lambda x: getattr(x.total_storage, self._storage_type) + search_space, + key=lambda x: (x.total_storage.hbm or 0) + (x.total_storage.ddr or 0), ): fqn = sharding_option.fqn if fqn not in self._sharding_options_by_fqn: @@ -324,105 +324,161 @@ def feedback( perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None, ) -> None: - """Feedback last proposed plan.""" - if not self._inited: - self._inited = True - table_count = len(self._sharding_options_by_fqn) - option_count = max([len(x) for x in self._sharding_options_by_fqn.values()]) - - assert 
storage_constraint is not None - # are we assuming the table will be evenly sharded on all devices? - max_device_mem = 0 - mem_total = 0 - for x in storage_constraint.devices: - cur_device_mem = getattr(x.storage, self._storage_type) - max_device_mem = max(max_device_mem, cur_device_mem) - mem_total += cur_device_mem - - bin_count = self._mem_bins_per_device * len(storage_constraint.devices) - bin_size = float(mem_total) / bin_count - - dp = [ - [(float("inf"), float("inf"))] * bin_count for _ in range(table_count) - ] # [table_id][mem_bin][perf, mem] - - backtrack = [ - [(-1, -1)] * bin_count for _ in range(table_count) - ] # [table_id][mem_bin][opt_id, prev_mem_bin] - - mem_by_fqn = [ - [float("inf") for _ in range(option_count)] for _ in range(table_count) - ] # memory constraint lookup table: [table_id][sharding_option_id] - perf_by_fqn = [ - [float("inf") for _ in range(option_count)] for _ in range(table_count) - ] # performance metrics lookup table: [table_id][sharding_option_id] - - # populate mem and perf for each sharding option and table: - # A[table_id][sharding_option_id] - for table_id, sharding_options in enumerate( - self._sharding_options_by_fqn.values() - ): - for opt_id, sharding_option in enumerate(sharding_options): - # prune mem of one shard > mem of one device - if ( - max( - [ - getattr(shard.storage, self._storage_type) - for shard in sharding_option.shards - ] - ) - > max_device_mem - ): - continue - mem_by_fqn[table_id][opt_id] = _bytes_to_float_bin( - getattr(sharding_option.total_storage, self._storage_type), - bin_size, - ) - perf_by_fqn[table_id][opt_id] = sharding_option.total_perf - - table_0 = 0 - for opt_j in range(option_count): - if mem_by_fqn[0][opt_j] < bin_count: - mem_i = int(mem_by_fqn[0][opt_j]) - # options are ordered in increasing order of mem, we only want to - # consider a sharding option that has higher mem and better perf - # (the smaller the better) - if dp[table_0][mem_i][0] > perf_by_fqn[table_0][opt_j]: - 
dp[table_0][mem_i] = ( - perf_by_fqn[table_0][opt_j], - mem_by_fqn[table_0][opt_j], - ) - backtrack[table_0][mem_i] = (opt_j, -1) - - # dp: table_count x option_count x bin_count - for table_i in range(1, table_count): - for opt_j in range(option_count): - for mem in range(bin_count): - prev_perf, perv_mem = dp[table_i - 1][mem] - if prev_perf < float("inf"): - new_mem = perv_mem + mem_by_fqn[table_i][opt_j] - if new_mem < bin_count: - new_mem_i = int(new_mem) - new_perf = prev_perf + perf_by_fqn[table_i][opt_j] - if dp[table_i][new_mem_i][0] > new_perf: - dp[table_i][new_mem_i] = (new_perf, new_mem) - backtrack[table_i][new_mem_i] = (opt_j, mem) - self._proposal_list = [] - # fill in all the proposals, starting from highest mem to lowest mem - for c in range(bin_count - 1, -1, -1): - cur_opt_idx, cur_mem_idx = backtrack[table_count - 1][c] - if cur_opt_idx >= 0: - proposal_indices = [-1] * table_count - proposal_indices[table_count - 1] = cur_opt_idx - for i in range(table_count - 2, -1, -1): - proposal_indices[i], cur_mem_idx = backtrack[i][cur_mem_idx] - self._proposal_list.append(proposal_indices) - if len(self._proposal_list) > 0: - self._current_proposal = 0 - else: + """Run 2D DP on first feedback; otherwise advance the proposal cursor.""" + if self._inited: self._current_proposal += 1 if self._current_proposal >= len(self._proposal_list): self._current_proposal = -1 + return + + self._inited = True + assert storage_constraint is not None + table_count = len(self._sharding_options_by_fqn) + if table_count == 0: + return + option_count = max(len(x) for x in self._sharding_options_by_fqn.values()) + + num_devices = len(storage_constraint.devices) + max_device_hbm = 0 + max_device_ddr = 0 + hbm_total = 0 + ddr_total = 0 + for device in storage_constraint.devices: + max_device_hbm = max(max_device_hbm, device.storage.hbm or 0) + max_device_ddr = max(max_device_ddr, device.storage.ddr or 0) + hbm_total += device.storage.hbm or 0 + ddr_total += device.storage.ddr 
or 0 + + hbm_bins = max(self._hbm_bins_per_device * num_devices, 1) + ddr_bins = max(self._ddr_bins_per_device * num_devices, 1) + # Collapse a degenerate axis to a single bin so we don't waste states + # on (e.g.) CPU-only topologies that have hbm == 0 everywhere. + if hbm_total == 0: + hbm_bins = 1 + if ddr_total == 0: + ddr_bins = 1 + hbm_bin_size = float(hbm_total) / hbm_bins if hbm_bins > 0 else 1.0 + ddr_bin_size = float(ddr_total) / ddr_bins if ddr_bins > 0 else 1.0 + + hbm_by_fqn = [[_INF] * option_count for _ in range(table_count)] + ddr_by_fqn = [[_INF] * option_count for _ in range(table_count)] + perf_by_fqn = [[_INF] * option_count for _ in range(table_count)] + + for table_id, sharding_options in enumerate( + self._sharding_options_by_fqn.values() + ): + for opt_id, sharding_option in enumerate(sharding_options): + max_shard_hbm = max( + (shard.storage.hbm or 0) for shard in sharding_option.shards + ) + max_shard_ddr = max( + (shard.storage.ddr or 0) for shard in sharding_option.shards + ) + # Prune options whose largest shard exceeds either per-device + # budget — the partitioner would reject them anyway. 
+ if hbm_total > 0 and max_shard_hbm > max_device_hbm: + continue + if ddr_total > 0 and max_shard_ddr > max_device_ddr: + continue + hbm_by_fqn[table_id][opt_id] = ( + _bytes_to_float_bin( + sharding_option.total_storage.hbm or 0, hbm_bin_size + ) + if hbm_total > 0 + else 0.0 + ) + ddr_by_fqn[table_id][opt_id] = ( + _bytes_to_float_bin( + sharding_option.total_storage.ddr or 0, ddr_bin_size + ) + if ddr_total > 0 + else 0.0 + ) + perf_by_fqn[table_id][opt_id] = sharding_option.total_perf + + # dp[table][hbm_bin][ddr_bin] = (perf, hbm_actual, ddr_actual) + # backtrack[table][hbm_bin][ddr_bin] = (opt_id, prev_hbm_bin, prev_ddr_bin) + empty_state = (_INF, _INF, _INF) + empty_back = (-1, -1, -1) + dp = [ + [[empty_state] * ddr_bins for _ in range(hbm_bins)] + for _ in range(table_count) + ] + backtrack = [ + [[empty_back] * ddr_bins for _ in range(hbm_bins)] + for _ in range(table_count) + ] + + # Seed the first table. + for opt_j in range(option_count): + h = hbm_by_fqn[0][opt_j] + d = ddr_by_fqn[0][opt_j] + if h >= hbm_bins or d >= ddr_bins: + continue + h_i = int(h) + d_i = int(d) + if dp[0][h_i][d_i][0] > perf_by_fqn[0][opt_j]: + dp[0][h_i][d_i] = (perf_by_fqn[0][opt_j], h, d) + backtrack[0][h_i][d_i] = (opt_j, -1, -1) + + for table_i in range(1, table_count): + prev_dp = dp[table_i - 1] + cur_dp = dp[table_i] + cur_back = backtrack[table_i] + hbm_t = hbm_by_fqn[table_i] + ddr_t = ddr_by_fqn[table_i] + perf_t = perf_by_fqn[table_i] + for h in range(hbm_bins): + prev_dp_h = prev_dp[h] + for d in range(ddr_bins): + prev_state = prev_dp_h[d] + prev_perf = prev_state[0] + if prev_perf == _INF: + continue + prev_h_val = prev_state[1] + prev_d_val = prev_state[2] + for opt_j in range(option_count): + delta_perf = perf_t[opt_j] + if delta_perf == _INF: + continue + new_h = prev_h_val + hbm_t[opt_j] + new_d = prev_d_val + ddr_t[opt_j] + if new_h >= hbm_bins or new_d >= ddr_bins: + continue + new_h_i = int(new_h) + new_d_i = int(new_d) + new_perf = prev_perf + delta_perf 
+ if cur_dp[new_h_i][new_d_i][0] > new_perf: + cur_dp[new_h_i][new_d_i] = (new_perf, new_h, new_d) + cur_back[new_h_i][new_d_i] = (opt_j, h, d) + + # Enumerate proposals in decreasing total memory order. Total memory + # is a tie-break heuristic — Pareto-optimal (perf, hbm, ddr) frontier + # tends to live at the high end, so larger-mem cells are explored + # first and small-mem fallbacks come later. + self._proposal_list = [] + last_back = backtrack[table_count - 1] + coords = sorted( + ( + (h, d) + for h in range(hbm_bins) + for d in range(ddr_bins) + if last_back[h][d][0] >= 0 + ), + key=lambda hd: hd[0] + hd[1], + reverse=True, + ) + for h, d in coords: + cur_opt_idx, prev_h, prev_d = last_back[h][d] + if cur_opt_idx < 0: + continue + proposal_indices = [-1] * table_count + proposal_indices[table_count - 1] = cur_opt_idx + for i in range(table_count - 2, -1, -1): + proposal_indices[i], prev_h, prev_d = backtrack[i][prev_h][prev_d] + self._proposal_list.append(proposal_indices) + if self._proposal_list: + self._current_proposal = 0 def _extract_constraints_for_param( diff --git a/tzrec/utils/plan_util_test.py b/tzrec/utils/plan_util_test.py index 17201dc2..03b6434d 100644 --- a/tzrec/utils/plan_util_test.py +++ b/tzrec/utils/plan_util_test.py @@ -202,5 +202,109 @@ def expected_miss_rate(self, ratio): self.assertIs(v.cache_params.stats, stats) +class _FakeStorage: + def __init__(self, hbm, ddr): + self.hbm = hbm + self.ddr = ddr + + +class _FakeShard: + def __init__(self, hbm, ddr): + self.storage = _FakeStorage(hbm, ddr) + + +class _FakeShardingOption: + """Minimal ShardingOption stand-in: only the fields the DP proposer reads.""" + + def __init__(self, fqn, hbm, ddr, perf): + self.fqn = fqn + # Total = single shard for simplicity (single-rank assignment). 
+ self.shards = [_FakeShard(hbm, ddr)] + self.total_storage = _FakeStorage(hbm, ddr) + self.total_perf = perf + + +def _make_topology(num_devices, hbm_per_device, ddr_per_device): + return SimpleNamespace( + devices=[ + SimpleNamespace( + storage=_FakeStorage(hbm=hbm_per_device, ddr=ddr_per_device) + ) + for _ in range(num_devices) + ] + ) + + +class DynamicProgrammingProposer2DTest(unittest.TestCase): + """2D DP across HBM × DDR picks per-table mode under joint budgets.""" + + def _run(self, search_space, topology): + proposer = DynamicProgrammingProposer( + hbm_bins_per_device=20, ddr_bins_per_device=20 + ) + proposer.load(search_space) + # First propose returns the lowest-mem-per-table seed. + proposer.propose() + proposer.feedback(partitionable=True, storage_constraint=topology) + proposals = [] + proposal = proposer.propose() + while proposal: + proposals.append(proposal) + proposer.feedback(partitionable=True, storage_constraint=topology) + proposal = proposer.propose() + return proposals + + def test_caching_preferred_when_ddr_is_generous(self): + # Three options for one table: + # HYBRID @ x=1.0: hbm = T, ddr = 0, perf = high (HBM-only) + # HYBRID @ x=0.1: hbm = .1T, ddr = .9T, perf = high (slow) + # CACHING @ x=0.1: hbm = .1T, ddr = T, perf = low (fast — modeled hits) + opts = [ + _FakeShardingOption("table_a", hbm=1000, ddr=0, perf=50.0), + _FakeShardingOption("table_a", hbm=100, ddr=900, perf=80.0), + _FakeShardingOption("table_a", hbm=100, ddr=1000, perf=10.0), + ] + topology = _make_topology( + num_devices=2, hbm_per_device=2000, ddr_per_device=2000 + ) + proposals = self._run(opts, topology) + # Best plan must be the CACHING option (perf=10). + best = min(proposals, key=lambda p: sum(o.total_perf for o in p)) + self.assertEqual(best[0].total_perf, 10.0) + + def test_caching_rejected_when_ddr_is_tight(self): + # Host budget is only 950 — CACHING (ddr=1000) cannot fit; HYBRID can. 
+ opts = [ + _FakeShardingOption("table_a", hbm=100, ddr=900, perf=80.0), + _FakeShardingOption("table_a", hbm=100, ddr=1000, perf=10.0), + ] + topology = _make_topology( + num_devices=1, hbm_per_device=2000, ddr_per_device=950 + ) + proposals = self._run(opts, topology) + # Every proposed plan must pick the HYBRID option (perf=80). + for p in proposals: + self.assertEqual(p[0].total_perf, 80.0) + + def test_high_factor_collapses_modes(self): + # At x=1.0 HYBRID == CACHING in HBM and CACHING.ddr = T = HYBRID.hbm. + # If we offer just the high-factor options, DP picks one of them. + opts = [ + _FakeShardingOption("table_a", hbm=1000, ddr=0, perf=50.0), # HYBRID x=1.0 + _FakeShardingOption( + "table_a", hbm=1000, ddr=1000, perf=50.0 + ), # CACHING x=1.0 + ] + topology = _make_topology( + num_devices=1, hbm_per_device=1100, ddr_per_device=2000 + ) + proposals = self._run(opts, topology) + # Either option is fine — they're tied. Just verify the proposer + # returned something feasible. + self.assertGreater(len(proposals), 0) + for p in proposals: + self.assertEqual(p[0].total_perf, 50.0) + + if __name__ == "__main__": unittest.main() From 7187c2a6bfa972f091226030b33ba8dd81a5897a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 12 May 2026 20:42:03 +0800 Subject: [PATCH 05/15] [test] dynamicemb planner: end-to-end enumerate + DP integration test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Exercise the full planning pipeline on a tiny TestSparseNN backed by a DynamicEmbParameterConstraints constraint. 
Asserts: - EmbeddingEnumerator yields 20 sharding options (2 modes × 10 factors) - All cache_load_factors and both caching modes are represented - Perf estimator populates a positive value and storage estimator populates non-negative HBM/DDR values for each option - DynamicProgrammingProposer.propose() returns feasible plans whose dynamicemb options carry a valid caching flag Gated on has_dynamicemb + torch.cuda.is_available() since the dynamicemb sharder requires CUDA at planning time. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/plan_util_dynamicemb_e2e_test.py | 135 +++++++++++++++++++ 1 file changed, 135 insertions(+) create mode 100644 tzrec/utils/plan_util_dynamicemb_e2e_test.py diff --git a/tzrec/utils/plan_util_dynamicemb_e2e_test.py b/tzrec/utils/plan_util_dynamicemb_e2e_test.py new file mode 100644 index 00000000..0145eb7b --- /dev/null +++ b/tzrec/utils/plan_util_dynamicemb_e2e_test.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import unittest + +import torch +from torchrec.distributed.embedding_types import EmbeddingComputeKernel +from torchrec.distributed.planner.types import Topology +from torchrec.distributed.test_utils.test_model import TestSparseNN +from torchrec.distributed.types import ShardingType +from torchrec.modules.embedding_configs import EmbeddingBagConfig + +from tzrec.utils.dynamicemb_util import has_dynamicemb + + +@unittest.skipUnless(has_dynamicemb, "dynamicemb is not installed; skipping.") +@unittest.skipUnless(torch.cuda.is_available(), "CUDA is required for dynamicemb.") +class PlanUtilDynamicEmbE2ETest(unittest.TestCase): + """End-to-end exercise of the dynamicemb planner integration.""" + + def _build_constraint(self, max_capacity=4096): + import dynamicemb + from dynamicemb.planner import DynamicEmbParameterConstraints + + opts = dynamicemb.DynamicEmbTableOptions( + max_capacity=max_capacity, + initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.UNIFORM, + lower=-0.01, + upper=0.01, + ), + eval_initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.CONSTANT, value=0.0 + ), + score_strategy=dynamicemb.DynamicEmbScoreStrategy.STEP, + ) + return DynamicEmbParameterConstraints( + use_dynamicemb=True, + sharding_types=[ShardingType.ROW_WISE.value], + compute_kernels=[EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value], + dynamicemb_options=opts, + ) + + def _build_model(self): + table = EmbeddingBagConfig( + num_embeddings=4096, + embedding_dim=32, + name="table_de", + feature_names=["feat_de"], + ) + return TestSparseNN(tables=[table], sparse_device=torch.device("meta")) + + def test_enumerate_yields_both_modes_and_all_factors(self): + from tzrec.utils.plan_util import EmbeddingEnumerator, get_default_sharders + + model = self._build_model() + topology = Topology(world_size=2, compute_device="cuda") + enumerator = EmbeddingEnumerator( + topology=topology, + batch_size=128, + 
fqn_constraints={"sparse.ebc.table_de": self._build_constraint()}, + ) + search_space = enumerator.enumerate( + module=model, sharders=get_default_sharders() + ) + self.assertEqual(len(search_space), 20) + caching_modes = sorted( + { + so.dynamicemb_options.caching + for so in search_space + if getattr(so, "use_dynamicemb", False) + } + ) + self.assertEqual(caching_modes, [False, True]) + load_factors = sorted( + { + round(so.cache_load_factor, 4) + for so in search_space + if getattr(so, "use_dynamicemb", False) + } + ) + self.assertEqual(load_factors, [round((i + 1) / 10, 4) for i in range(10)]) + # Each option must carry a non-zero perf and storage estimate. + for so in search_space: + self.assertGreater(so.total_perf, 0) + self.assertGreaterEqual(so.total_storage.hbm, 0) + self.assertGreaterEqual(so.total_storage.ddr, 0) + + def test_dp_proposer_picks_feasible_dynamicemb_plan(self): + from tzrec.utils.plan_util import ( + DynamicProgrammingProposer, + EmbeddingEnumerator, + get_default_sharders, + ) + + model = self._build_model() + topology = Topology(world_size=2, compute_device="cuda") + enumerator = EmbeddingEnumerator( + topology=topology, + batch_size=128, + fqn_constraints={"sparse.ebc.table_de": self._build_constraint()}, + ) + search_space = enumerator.enumerate( + module=model, sharders=get_default_sharders() + ) + + proposer = DynamicProgrammingProposer() + proposer.load(search_space) + proposal = proposer.propose() + self.assertIsNotNone(proposal) + proposer.feedback(partitionable=True, storage_constraint=topology) + + # At least one further proposal should be generated by the 2D DP. 
+ count = 0 + proposal = proposer.propose() + while proposal is not None and count < 5: + count += 1 + for so in proposal: + if getattr(so, "use_dynamicemb", False): + self.assertIn(so.dynamicemb_options.caching, (False, True)) + proposer.feedback(partitionable=True, storage_constraint=topology) + proposal = proposer.propose() + self.assertGreater(count, 0) + + +if __name__ == "__main__": + unittest.main() From c3bff1999f0a32a99507c72e4a94dc14d1c2d200 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 13 May 2026 16:12:00 +0800 Subject: [PATCH 06/15] [fix] dynamicemb planner: refine storage docstring + simplify perf wrapper (1) Restore the byte-budget framing in `_calculate_dynamicemb_table_storage_specific_size` docstring (`total_value_memory`, `num_buckets`, `hbm_budget`, `ddr_budget`) that was lost in the PR #508 rewrite, extending `ddr_budget` to cover both HYBRID and CACHING modes in one place. (2) Simplify `_dynamicemb_aware_build_shard_perf_contexts`: single return, no try/finally, no early-return branch. The unconditional `sharding_option.cache_params = original_cache_params` line at the end is a no-op when caching is False (we never mutated the field) and restores the original cache_params when caching is True. If the wrapped call raises, planning aborts and the option is discarded -- no need for try/finally. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/dynamicemb_util.py | 93 ++++++++++++++++------------------ 1 file changed, 43 insertions(+), 50 deletions(-) diff --git a/tzrec/utils/dynamicemb_util.py b/tzrec/utils/dynamicemb_util.py index 3ba19d75..4a68ab61 100644 --- a/tzrec/utils/dynamicemb_util.py +++ b/tzrec/utils/dynamicemb_util.py @@ -304,21 +304,27 @@ def _calculate_dynamicemb_table_storage_specific_size( bucket_capacity: int = 128, caching: bool = False, ) -> int: - """Calculate dynamic embedding table storage. + """Per-shard storage size for a dynamicemb table -- HBM or DDR (bytes). 
- HYBRID mode (``caching=False``): values are hash-partitioned across HBM - and host; HBM holds ``cache_ratio`` of values, host holds - ``1 - cache_ratio``. + Byte budget (single shard, rows x dim): - CACHING mode (``caching=True``): host holds the full backing store; HBM - holds an ``cache_ratio`` fraction as a cache. + value_bytes_per_row = round_up16(dim * (1 + opt_mult) * element) + total_value_memory = align(rows) * value_bytes_per_row + num_buckets = align(rows) / bucket_capacity - hbm = align(rows) * (round_up16(dim*element) * cache_ratio + metadata) - ddr = align(rows) * round_up16(dim*element) * value_ratio_ddr + hbm_budget = cache_ratio * total_value_memory # values + + align(rows) * (key<8B> + score<8B> + digest<1B>) # per-row + + num_buckets * bucket_header<4B> # per-bucket - where ``value_ratio_ddr`` is ``1 - cache_ratio`` for HYBRID and ``1.0`` - for CACHING. Metadata (key + score + digest + bucket) is accounted on - HBM only, matching the existing tzrec convention. + ddr_budget = HYBRID (caching=False): (1 - cache_ratio) * total_value_memory + CACHING (caching=True): total_value_memory # full backing + + HYBRID hash-partitions values across HBM and host; ``cache_ratio`` is + HBM's value share. CACHING keeps the full backing store on host and + uses HBM as a hot-row cache of size + ``cache_ratio * total_value_memory``. Hash-table metadata + (key + score + digest + bucket header) is accounted on HBM only -- + matches the existing tzrec convention. """ if cache_ratio is None: cache_ratio = 1.0 @@ -484,53 +490,40 @@ def _dynamicemb_aware_build_shard_perf_contexts( *args, # pyre-ignore [2] **kwargs, # pyre-ignore [2] ): - """Inject CACHING-mode hit-rate boost into the perf estimator. + """Inject the CACHING-mode hit-rate boost into the perf estimator. 
- For dynamicemb tables in CACHING mode, we temporarily replace + For dynamicemb tables in CACHING mode, temporarily replace ``sharding_option.cache_params`` with a clone whose ``load_factor`` is boosted to the effective HBM-hit ratio. The original cache_params is - restored before the patched method returns, so the (separately invoked) - storage estimator continues to see the un-boosted ratio. + restored before returning, so the (separately invoked) storage + estimator still sees the un-boosted ratio. """ dynamicemb_options = getattr(sharding_option, "dynamicemb_options", None) caching = bool(getattr(dynamicemb_options, "caching", False)) - if not caching: - return _orig_build_shard_perf_contexts( - cls, - config, - shard_sizes, - sharding_option, - topology, - constraints, - sharder, - *args, - **kwargs, - ) - original_cache_params = sharding_option.cache_params - stats = original_cache_params.stats if original_cache_params else None - x_eff = _dynamicemb_effective_cache_ratio( - sharding_option.cache_load_factor, caching=True, stats=stats - ) - if original_cache_params is not None: - boosted = dataclasses.replace(original_cache_params, load_factor=x_eff) - else: - boosted = CacheParams(load_factor=x_eff) - sharding_option.cache_params = boosted - try: - return _orig_build_shard_perf_contexts( - cls, - config, - shard_sizes, - sharding_option, - topology, - constraints, - sharder, - *args, - **kwargs, + if caching: + stats = original_cache_params.stats if original_cache_params else None + x_eff = _dynamicemb_effective_cache_ratio( + sharding_option.cache_load_factor, caching=True, stats=stats + ) + sharding_option.cache_params = ( + dataclasses.replace(original_cache_params, load_factor=x_eff) + if original_cache_params is not None + else CacheParams(load_factor=x_eff) ) - finally: - sharding_option.cache_params = original_cache_params + result = _orig_build_shard_perf_contexts( + cls, + config, + shard_sizes, + sharding_option, + topology, + constraints, + 
sharder, + *args, + **kwargs, + ) + sharding_option.cache_params = original_cache_params + return result # pyre-ignore [9] ShardPerfContext.build_shard_perf_contexts = classmethod( From 38bd7ad5769eee51ddf9a1f6e8e061fae052f8fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 13 May 2026 16:13:10 +0800 Subject: [PATCH 07/15] [fix] dynamicemb planner: per-option DDR prune against per-machine cap DDR is host-shared across ranks co-located on one machine, so the per-option fit check in `DynamicProgrammingProposer.feedback` should compare `max_shard_ddr` against the largest machine's total DDR pool, not against any single rank's slice. HBM stays per-device (each GPU has its own HBM; no cross-rank sharing). Compute `max_machine_ddr` by summing per-device DDR within each contiguous `local_world_size`-sized window. Update the test fake-topology fixture to carry `local_world_size` (defaults to num_devices). Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/plan_util.py | 20 +++++++++++++++----- tzrec/utils/plan_util_test.py | 5 +++-- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tzrec/utils/plan_util.py b/tzrec/utils/plan_util.py index e8ac4449..3e1e7449 100644 --- a/tzrec/utils/plan_util.py +++ b/tzrec/utils/plan_util.py @@ -340,14 +340,25 @@ def feedback( num_devices = len(storage_constraint.devices) max_device_hbm = 0 - max_device_ddr = 0 hbm_total = 0 ddr_total = 0 for device in storage_constraint.devices: max_device_hbm = max(max_device_hbm, device.storage.hbm or 0) - max_device_ddr = max(max_device_ddr, device.storage.ddr or 0) hbm_total += device.storage.hbm or 0 ddr_total += device.storage.ddr or 0 + # DDR is host-shared across ranks co-located on one machine, so the + # per-option fit check compares against the largest machine's DDR pool + # -- not per-device. HBM is GPU-local, so its prune stays per-device. 
+ per_host = getattr(storage_constraint, "local_world_size", None) or num_devices + per_host = max(per_host, 1) + max_machine_ddr = 0 + for host_start in range(0, num_devices, per_host): + host_end = min(host_start + per_host, num_devices) + machine_ddr = sum( + (storage_constraint.devices[i].storage.ddr or 0) + for i in range(host_start, host_end) + ) + max_machine_ddr = max(max_machine_ddr, machine_ddr) hbm_bins = max(self._hbm_bins_per_device * num_devices, 1) ddr_bins = max(self._ddr_bins_per_device * num_devices, 1) @@ -374,11 +385,10 @@ def feedback( max_shard_ddr = max( (shard.storage.ddr or 0) for shard in sharding_option.shards ) - # Prune options whose largest shard exceeds either per-device - # budget — the partitioner would reject them anyway. + # HBM is per-device, DDR is per-machine: see comment above. if hbm_total > 0 and max_shard_hbm > max_device_hbm: continue - if ddr_total > 0 and max_shard_ddr > max_device_ddr: + if ddr_total > 0 and max_shard_ddr > max_machine_ddr: continue hbm_by_fqn[table_id][opt_id] = ( _bytes_to_float_bin( diff --git a/tzrec/utils/plan_util_test.py b/tzrec/utils/plan_util_test.py index 03b6434d..e63c8008 100644 --- a/tzrec/utils/plan_util_test.py +++ b/tzrec/utils/plan_util_test.py @@ -224,14 +224,15 @@ def __init__(self, fqn, hbm, ddr, perf): self.total_perf = perf -def _make_topology(num_devices, hbm_per_device, ddr_per_device): +def _make_topology(num_devices, hbm_per_device, ddr_per_device, local_world_size=None): return SimpleNamespace( devices=[ SimpleNamespace( storage=_FakeStorage(hbm=hbm_per_device, ddr=ddr_per_device) ) for _ in range(num_devices) - ] + ], + local_world_size=local_world_size or num_devices, ) From 70b0fcb31d90c38fa39881e99615e9e6f8acad19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 13 May 2026 16:15:14 +0800 Subject: [PATCH 08/15] [refactor] fold dynamicemb e2e tests + rename DP proposer test + bump copyright (5) Move `PlanUtilDynamicEmbE2ETest` from the standalone 
`tzrec/utils/plan_util_dynamicemb_e2e_test.py` into `plan_util_test.py` where the rest of the plan_util tests already live. Delete the standalone file. (6) Rename `DynamicProgrammingProposer2DTest` -> `DynamicProgrammingProposerTest`. The proposer class itself stays `DynamicProgrammingProposer`. (7) Bump copyright `2024 -> 2026` in the new file `tzrec/utils/dynamicemb_util_test.py`. The header in plan_util_test.py stays at 2024 since it was a pre-existing file we only extended. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/dynamicemb_util_test.py | 2 +- tzrec/utils/plan_util_dynamicemb_e2e_test.py | 135 ------------------- tzrec/utils/plan_util_test.py | 122 ++++++++++++++++- 3 files changed, 121 insertions(+), 138 deletions(-) delete mode 100644 tzrec/utils/plan_util_dynamicemb_e2e_test.py diff --git a/tzrec/utils/dynamicemb_util_test.py b/tzrec/utils/dynamicemb_util_test.py index 4b02ef21..15d3b6eb 100644 --- a/tzrec/utils/dynamicemb_util_test.py +++ b/tzrec/utils/dynamicemb_util_test.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024, Alibaba Group; +# Copyright (c) 2026, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/tzrec/utils/plan_util_dynamicemb_e2e_test.py b/tzrec/utils/plan_util_dynamicemb_e2e_test.py deleted file mode 100644 index 0145eb7b..00000000 --- a/tzrec/utils/plan_util_dynamicemb_e2e_test.py +++ /dev/null @@ -1,135 +0,0 @@ -# Copyright (c) 2024, Alibaba Group; -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -import torch -from torchrec.distributed.embedding_types import EmbeddingComputeKernel -from torchrec.distributed.planner.types import Topology -from torchrec.distributed.test_utils.test_model import TestSparseNN -from torchrec.distributed.types import ShardingType -from torchrec.modules.embedding_configs import EmbeddingBagConfig - -from tzrec.utils.dynamicemb_util import has_dynamicemb - - -@unittest.skipUnless(has_dynamicemb, "dynamicemb is not installed; skipping.") -@unittest.skipUnless(torch.cuda.is_available(), "CUDA is required for dynamicemb.") -class PlanUtilDynamicEmbE2ETest(unittest.TestCase): - """End-to-end exercise of the dynamicemb planner integration.""" - - def _build_constraint(self, max_capacity=4096): - import dynamicemb - from dynamicemb.planner import DynamicEmbParameterConstraints - - opts = dynamicemb.DynamicEmbTableOptions( - max_capacity=max_capacity, - initializer_args=dynamicemb.DynamicEmbInitializerArgs( - mode=dynamicemb.DynamicEmbInitializerMode.UNIFORM, - lower=-0.01, - upper=0.01, - ), - eval_initializer_args=dynamicemb.DynamicEmbInitializerArgs( - mode=dynamicemb.DynamicEmbInitializerMode.CONSTANT, value=0.0 - ), - score_strategy=dynamicemb.DynamicEmbScoreStrategy.STEP, - ) - return DynamicEmbParameterConstraints( - use_dynamicemb=True, - sharding_types=[ShardingType.ROW_WISE.value], - compute_kernels=[EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value], - dynamicemb_options=opts, - ) - - def _build_model(self): - table = EmbeddingBagConfig( - num_embeddings=4096, - embedding_dim=32, - name="table_de", - 
feature_names=["feat_de"], - ) - return TestSparseNN(tables=[table], sparse_device=torch.device("meta")) - - def test_enumerate_yields_both_modes_and_all_factors(self): - from tzrec.utils.plan_util import EmbeddingEnumerator, get_default_sharders - - model = self._build_model() - topology = Topology(world_size=2, compute_device="cuda") - enumerator = EmbeddingEnumerator( - topology=topology, - batch_size=128, - fqn_constraints={"sparse.ebc.table_de": self._build_constraint()}, - ) - search_space = enumerator.enumerate( - module=model, sharders=get_default_sharders() - ) - self.assertEqual(len(search_space), 20) - caching_modes = sorted( - { - so.dynamicemb_options.caching - for so in search_space - if getattr(so, "use_dynamicemb", False) - } - ) - self.assertEqual(caching_modes, [False, True]) - load_factors = sorted( - { - round(so.cache_load_factor, 4) - for so in search_space - if getattr(so, "use_dynamicemb", False) - } - ) - self.assertEqual(load_factors, [round((i + 1) / 10, 4) for i in range(10)]) - # Each option must carry a non-zero perf and storage estimate. 
- for so in search_space: - self.assertGreater(so.total_perf, 0) - self.assertGreaterEqual(so.total_storage.hbm, 0) - self.assertGreaterEqual(so.total_storage.ddr, 0) - - def test_dp_proposer_picks_feasible_dynamicemb_plan(self): - from tzrec.utils.plan_util import ( - DynamicProgrammingProposer, - EmbeddingEnumerator, - get_default_sharders, - ) - - model = self._build_model() - topology = Topology(world_size=2, compute_device="cuda") - enumerator = EmbeddingEnumerator( - topology=topology, - batch_size=128, - fqn_constraints={"sparse.ebc.table_de": self._build_constraint()}, - ) - search_space = enumerator.enumerate( - module=model, sharders=get_default_sharders() - ) - - proposer = DynamicProgrammingProposer() - proposer.load(search_space) - proposal = proposer.propose() - self.assertIsNotNone(proposal) - proposer.feedback(partitionable=True, storage_constraint=topology) - - # At least one further proposal should be generated by the 2D DP. - count = 0 - proposal = proposer.propose() - while proposal is not None and count < 5: - count += 1 - for so in proposal: - if getattr(so, "use_dynamicemb", False): - self.assertIn(so.dynamicemb_options.caching, (False, True)) - proposer.feedback(partitionable=True, storage_constraint=topology) - proposal = proposer.propose() - self.assertGreater(count, 0) - - -if __name__ == "__main__": - unittest.main() diff --git a/tzrec/utils/plan_util_test.py b/tzrec/utils/plan_util_test.py index e63c8008..b1cac92e 100644 --- a/tzrec/utils/plan_util_test.py +++ b/tzrec/utils/plan_util_test.py @@ -13,15 +13,17 @@ from types import SimpleNamespace import torch +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.model_parallel import get_default_sharders from torchrec.distributed.planner.enumerators import EmbeddingEnumerator from torchrec.distributed.planner.partitioners import GreedyPerfPartitioner from torchrec.distributed.planner.proposers import GridSearchProposer from 
torchrec.distributed.planner.types import PlannerError, Topology from torchrec.distributed.test_utils.test_model import TestSparseNN -from torchrec.distributed.types import CacheParams +from torchrec.distributed.types import CacheParams, ShardingType from torchrec.modules.embedding_configs import EmbeddingBagConfig +from tzrec.utils.dynamicemb_util import has_dynamicemb from tzrec.utils.plan_util import DynamicProgrammingProposer, _emit_dynamicemb_variants @@ -236,7 +238,7 @@ def _make_topology(num_devices, hbm_per_device, ddr_per_device, local_world_size ) -class DynamicProgrammingProposer2DTest(unittest.TestCase): +class DynamicProgrammingProposerTest(unittest.TestCase): """2D DP across HBM × DDR picks per-table mode under joint budgets.""" def _run(self, search_space, topology): @@ -307,5 +309,121 @@ def test_high_factor_collapses_modes(self): self.assertEqual(p[0].total_perf, 50.0) +@unittest.skipUnless(has_dynamicemb, "dynamicemb is not installed; skipping.") +@unittest.skipUnless(torch.cuda.is_available(), "CUDA is required for dynamicemb.") +class PlanUtilDynamicEmbE2ETest(unittest.TestCase): + """End-to-end exercise of the dynamicemb planner integration.""" + + def _build_constraint(self, max_capacity=4096): + import dynamicemb + from dynamicemb.planner import DynamicEmbParameterConstraints + + opts = dynamicemb.DynamicEmbTableOptions( + max_capacity=max_capacity, + initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.UNIFORM, + lower=-0.01, + upper=0.01, + ), + eval_initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.CONSTANT, value=0.0 + ), + score_strategy=dynamicemb.DynamicEmbScoreStrategy.STEP, + ) + return DynamicEmbParameterConstraints( + use_dynamicemb=True, + sharding_types=[ShardingType.ROW_WISE.value], + compute_kernels=[EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value], + dynamicemb_options=opts, + ) + + def _build_model(self): + table = 
EmbeddingBagConfig( + num_embeddings=4096, + embedding_dim=32, + name="table_de", + feature_names=["feat_de"], + ) + return TestSparseNN(tables=[table], sparse_device=torch.device("meta")) + + def test_enumerate_yields_both_modes_and_all_factors(self): + from tzrec.utils.plan_util import ( + EmbeddingEnumerator as _TzrecEmbeddingEnumerator, + ) + from tzrec.utils.plan_util import ( + get_default_sharders as _tzrec_get_default_sharders, + ) + + model = self._build_model() + topology = Topology(world_size=2, compute_device="cuda") + enumerator = _TzrecEmbeddingEnumerator( + topology=topology, + batch_size=128, + fqn_constraints={"sparse.ebc.table_de": self._build_constraint()}, + ) + search_space = enumerator.enumerate( + module=model, sharders=_tzrec_get_default_sharders() + ) + self.assertEqual(len(search_space), 20) + caching_modes = sorted( + { + so.dynamicemb_options.caching + for so in search_space + if getattr(so, "use_dynamicemb", False) + } + ) + self.assertEqual(caching_modes, [False, True]) + load_factors = sorted( + { + round(so.cache_load_factor, 4) + for so in search_space + if getattr(so, "use_dynamicemb", False) + } + ) + self.assertEqual(load_factors, [round((i + 1) / 10, 4) for i in range(10)]) + # Each option must carry a positive perf and a non-negative storage estimate.
+ for so in search_space: + self.assertGreater(so.total_perf, 0) + self.assertGreaterEqual(so.total_storage.hbm, 0) + self.assertGreaterEqual(so.total_storage.ddr, 0) + + def test_dp_proposer_picks_feasible_dynamicemb_plan(self): + from tzrec.utils.plan_util import ( + EmbeddingEnumerator as _TzrecEmbeddingEnumerator, + ) + from tzrec.utils.plan_util import ( + get_default_sharders as _tzrec_get_default_sharders, + ) + + model = self._build_model() + topology = Topology(world_size=2, compute_device="cuda") + enumerator = _TzrecEmbeddingEnumerator( + topology=topology, + batch_size=128, + fqn_constraints={"sparse.ebc.table_de": self._build_constraint()}, + ) + search_space = enumerator.enumerate( + module=model, sharders=_tzrec_get_default_sharders() + ) + + proposer = DynamicProgrammingProposer() + proposer.load(search_space) + proposal = proposer.propose() + self.assertIsNotNone(proposal) + proposer.feedback(partitionable=True, storage_constraint=topology) + + # At least one further proposal should be generated by the 2D DP. + count = 0 + proposal = proposer.propose() + while proposal is not None and count < 5: + count += 1 + for so in proposal: + if getattr(so, "use_dynamicemb", False): + self.assertIn(so.dynamicemb_options.caching, (False, True)) + proposer.feedback(partitionable=True, storage_constraint=topology) + proposal = proposer.propose() + self.assertGreater(count, 0) + + if __name__ == "__main__": unittest.main() From c28e2b23b7721d4552da5dfb778f2ec39ef4f06f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 13 May 2026 16:25:30 +0800 Subject: [PATCH 09/15] [fix] dynamicemb planner: empirically fitted CACHING vs HYBRID perf curve Replace the heuristic ``_dynamicemb_effective_cache_ratio`` curve with constants fitted from an on-device perf sweep (``experiments/sweep_20260513-161030/full_a10gpu1.json``: 4M-row table, dim=128, adam, pow-law alpha=1.05, A10 GPU). 
Median fwd+bwd latency clustered into three regimes: HYBRID @ x=1.0: 0.80 ms (HBM-only fast path) CACHING @ x<1.0: 2.63 ms (~3.3x slower) HYBRID @ x<1.0: 5.44 ms (~6.8x slower) Inverting the linear bw model bw = x_eff*HBM + (1-x_eff)*HBM_TO_DDR (torchrec defaults 897 / 32 GB/s) gives x_eff bases 0.28 (CACHING) and 0.11 (HYBRID). At x = 1.0 the runtime drops the host tier in both modes, so x_eff = 1.0 there. A +0.01*x tiebreaker term keeps the DP's ranking strict within each block. Also extend the ``_dynamicemb_aware_build_shard_perf_contexts`` wrapper to apply the empirical x_eff for HYBRID (previously HYBRID used raw ``cache_load_factor`` as x_eff -- inconsistent with the empirical data that shows HYBRID @ 0.9 is much slower than HYBRID @ 1.0 due to the storage backend switch). Drop the now-unused ``TZREC_CACHING_HIT_RATE_MULTIPLIER`` env knob and its related tests; rewrite ``EffectiveCacheRatioTest`` to assert the empirical strict-block ranking. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/dynamicemb_util.py | 76 ++++++++++-------- tzrec/utils/dynamicemb_util_test.py | 116 +++++++++++++++------------- 2 files changed, 108 insertions(+), 84 deletions(-) diff --git a/tzrec/utils/dynamicemb_util.py b/tzrec/utils/dynamicemb_util.py index 4a68ab61..9e68afff 100644 --- a/tzrec/utils/dynamicemb_util.py +++ b/tzrec/utils/dynamicemb_util.py @@ -48,8 +48,22 @@ from tzrec.protos import feature_pb2 -DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_ENV = "TZREC_CACHING_HIT_RATE_MULTIPLIER" -DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_DEFAULT = 10.0 +# Empirical x_eff constants fitted from an on-device dynamicemb sweep +# (4M-row table, dim=128, adam, pow-law alpha=1.05, A10 GPU; see +# experiments/sweep_20260513-161030/full_a10gpu1.json). 
Median fwd+bwd +# latency clustered into three regimes: +# * HYBRID @ x=1.0: 0.80 ms (HBM-only fast path; runtime drops the +# host tier when total_value_memory <= local_hbm) +# * CACHING @ x<1.0: 2.63 ms (~3.3x slower than HBM-only) +# * HYBRID @ x<1.0: 5.44 ms (~6.8x slower than HBM-only) +# Within each <1.0 block the ratio dependence is noise-dominated. +# Inverting the linear bw model bw = x_eff*HBM + (1-x_eff)*HBM_TO_DDR +# (torchrec defaults: HBM=897 GB/s, HBM_TO_DDR=32 GB/s) yields the +# constants below. The +0.01*x term is a tiebreaker so the DP can produce +# strictly ordered proposals within each block. +_DYNAMICEMB_CACHING_X_EFF_BASE = 0.28 +_DYNAMICEMB_HYBRID_X_EFF_BASE = 0.11 +_DYNAMICEMB_X_EFF_TIEBREAK = 0.01 def _dynamicemb_effective_cache_ratio( @@ -59,33 +73,33 @@ def _dynamicemb_effective_cache_ratio( ) -> float: """Effective HBM-hit ratio for the dynamicemb perf model. - HYBRID (``caching=False``) is unchanged — every access has probability - ``cache_load_factor`` of hitting HBM since the HBM tier is hash-partitioned. + Returns the value passed to torchrec's perf bandwidth formula + ``bw = x_eff*hbm + (1-x_eff)*hbm_to_ddr_bw``. Larger value = faster path. - CACHING (``caching=True``) pins hot rows in HBM regardless of hash; we - model that as a boosted hit rate. With the default multiplier (10.0) any - cache_load_factor ≥ 0.1 saturates to a full HBM hit. ``stats``, when - provided, overrides the multiplier-based estimate via - ``1 - stats.expected_miss_rate(cache_load_factor)``. + The ratio is derived from an on-device perf sweep, not a heuristic. + Empirical pattern (alpha=1.05 pow-law access on A10): - Invariant: ``CACHING_bw(x) >= HYBRID_bw(x)`` at every ``x`` because the - returned ratio is monotonically ≥ ``cache_load_factor``. + * ``x == 1.0``: the runtime drops the host tier (HBM-only); both + modes hit the fastest path. Return ``1.0``. + * ``caching=True``, ``x < 1.0``: 3.3x slower than HBM-only -> base 0.28. 
+ * ``caching=False``, ``x < 1.0``: 6.8x slower than HBM-only -> base 0.11. + + Within each ``x < 1.0`` block the perf is roughly flat in ratio, but we + add a tiny monotonic perturbation so the DP can break ties. + + If ``stats`` is provided, ``1 - stats.expected_miss_rate(x)`` overrides + the heuristic verbatim (clamped to ``[0, 1]``); the caller opts in to + their own measurement. """ x = float(cache_load_factor) if cache_load_factor is not None else 0.0 x = max(0.0, min(1.0, x)) - if not caching: - return x if stats is not None: miss_rate = float(stats.expected_miss_rate(x)) - return max(x, max(0.0, min(1.0, 1.0 - miss_rate))) - multiplier = float( - os.environ.get( - DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_ENV, - DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_DEFAULT, - ) - ) - multiplier = max(1.0, multiplier) - return min(1.0, x * multiplier) + return max(0.0, min(1.0, 1.0 - miss_rate)) + if x >= 1.0: + return 1.0 + base = _DYNAMICEMB_CACHING_X_EFF_BASE if caching else _DYNAMICEMB_HYBRID_X_EFF_BASE + return base + _DYNAMICEMB_X_EFF_TIEBREAK * x has_dynamicemb = False @@ -490,21 +504,21 @@ def _dynamicemb_aware_build_shard_perf_contexts( *args, # pyre-ignore [2] **kwargs, # pyre-ignore [2] ): - """Inject the CACHING-mode hit-rate boost into the perf estimator. + """Inject the empirical x_eff into the perf estimator for both modes. - For dynamicemb tables in CACHING mode, temporarily replace - ``sharding_option.cache_params`` with a clone whose ``load_factor`` is - boosted to the effective HBM-hit ratio. The original cache_params is - restored before returning, so the (separately invoked) storage - estimator still sees the un-boosted ratio. + Temporarily replace ``sharding_option.cache_params`` with a clone + whose ``load_factor`` is the empirically-fitted x_eff for the + (mode, cache_load_factor) combination. Restored before returning so + the (separately invoked) storage estimator still sees the un-boosted + ratio. 
""" dynamicemb_options = getattr(sharding_option, "dynamicemb_options", None) - caching = bool(getattr(dynamicemb_options, "caching", False)) original_cache_params = sharding_option.cache_params - if caching: + if dynamicemb_options is not None: + caching = bool(getattr(dynamicemb_options, "caching", False)) stats = original_cache_params.stats if original_cache_params else None x_eff = _dynamicemb_effective_cache_ratio( - sharding_option.cache_load_factor, caching=True, stats=stats + sharding_option.cache_load_factor, caching=caching, stats=stats ) sharding_option.cache_params = ( dataclasses.replace(original_cache_params, load_factor=x_eff) diff --git a/tzrec/utils/dynamicemb_util_test.py b/tzrec/utils/dynamicemb_util_test.py index 15d3b6eb..d086a30f 100644 --- a/tzrec/utils/dynamicemb_util_test.py +++ b/tzrec/utils/dynamicemb_util_test.py @@ -9,9 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import unittest -from unittest import mock from parameterized import parameterized @@ -109,56 +107,67 @@ def test_only_values_drops_metadata(self): class EffectiveCacheRatioTest(unittest.TestCase): - """``_dynamicemb_effective_cache_ratio`` — HYBRID vs CACHING perf ratio.""" - - def test_hybrid_passes_through(self): - for x in (0.0, 0.1, 0.5, 1.0): - self.assertEqual( - dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=False), x - ) - - def test_caching_default_multiplier_saturates(self): - # Default multiplier is 10.0 — any x >= 0.1 saturates to 1.0. 
- env = { - k: v - for k, v in os.environ.items() - if k != dynamicemb_util.DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_ENV - } - with mock.patch.dict(os.environ, env, clear=True): - self.assertEqual( - dynamicemb_util._dynamicemb_effective_cache_ratio(0.1, caching=True), - 1.0, - ) - self.assertEqual( - dynamicemb_util._dynamicemb_effective_cache_ratio(0.5, caching=True), - 1.0, - ) - self.assertEqual( - dynamicemb_util._dynamicemb_effective_cache_ratio(0.01, caching=True), - 0.1, - ) - - def test_caching_env_override(self): - with mock.patch.dict( - os.environ, - {dynamicemb_util.DYNAMICEMB_CACHING_HIT_RATE_MULTIPLIER_ENV: "2.0"}, - ): - # multiplier=2: x=0.3 -> 0.6, x=0.5 -> 1.0 (clamped) - self.assertAlmostEqual( - dynamicemb_util._dynamicemb_effective_cache_ratio(0.3, caching=True), - 0.6, - ) - self.assertEqual( - dynamicemb_util._dynamicemb_effective_cache_ratio(0.5, caching=True), - 1.0, - ) + """``_dynamicemb_effective_cache_ratio`` -- empirical fit from on-device sweep. + + The formula is fitted to the medians of an A10 benchmark sweep recorded + in experiments/sweep_20260513-161030/full_a10gpu1.json. Reproducing the + three-regime empirical pattern: + * x == 1.0: HBM-only fast path, x_eff = 1.0 + * caching=True, x<1.0: x_eff base 0.28 + * caching=False, x<1.0: x_eff base 0.11 + """ + + def test_x_eq_1_returns_1_in_both_modes(self): + # HBM-only fast path: total fits, runtime drops the host tier. + self.assertEqual( + dynamicemb_util._dynamicemb_effective_cache_ratio(1.0, caching=False), 1.0 + ) + self.assertEqual( + dynamicemb_util._dynamicemb_effective_cache_ratio(1.0, caching=True), 1.0 + ) - def test_caching_invariant_monotonic_ge_hybrid(self): - # CACHING ratio >= HYBRID ratio at every cache_load_factor. - for x in (0.0, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0): + def test_caching_above_hybrid_for_x_less_than_1(self): + # Empirically CACHING is ~2x faster than HYBRID at the same ratio. 
+ for x in (0.1, 0.3, 0.5, 0.7, 0.9): hybrid = dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=False) caching = dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=True) - self.assertGreaterEqual(caching, hybrid) + self.assertGreater(caching, hybrid) + + def test_monotonic_within_block(self): + # Within each mode, higher cache_load_factor -> higher x_eff (the + # +0.01*x tiebreaker term). + for caching in (False, True): + prev = None + for i in range(1, 10): # x = 0.1 .. 0.9 + x = i / 10 + cur = dynamicemb_util._dynamicemb_effective_cache_ratio( + x, caching=caching + ) + if prev is not None: + self.assertGreater(cur, prev) + prev = cur + + def test_strict_block_ranking_matches_empirical(self): + # HYBRID@1.0 > CACHING@anything < 1.0 > HYBRID@anything < 1.0 + ratios = [i / 10 for i in range(1, 10)] # 0.1 .. 0.9 + hybrid_at_1 = dynamicemb_util._dynamicemb_effective_cache_ratio( + 1.0, caching=False + ) + caching_block = { + x: dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=True) + for x in ratios + } + hybrid_block = { + x: dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=False) + for x in ratios + } + # Every CACHING@x<1.0 sits strictly below HYBRID@1.0. + for x_eff in caching_block.values(): + self.assertGreater(hybrid_at_1, x_eff) + # Every CACHING@x<1.0 sits strictly above every HYBRID@x<1.0. + for c in caching_block.values(): + for h in hybrid_block.values(): + self.assertGreater(c, h) def test_stats_override_uses_expected_miss_rate(self): class _Stats: @@ -172,18 +181,19 @@ def expected_miss_rate(self, ratio): ) self.assertAlmostEqual(x_eff, 0.95) - def test_stats_override_never_drops_below_hybrid(self): + def test_stats_override_honored_verbatim(self): + # Stats reflect measured behavior; trust them even if they override + # the empirical heuristic's preferred ordering. 
class _Stats: expected_lookups = 1000.0 def expected_miss_rate(self, ratio): - return 0.9 # would give x_eff=0.1, below cache_load_factor=0.5 + return 0.9 # x_eff = 0.1 even though CACHING base = 0.28 x_eff = dynamicemb_util._dynamicemb_effective_cache_ratio( 0.5, caching=True, stats=_Stats() ) - # Invariant: CACHING_bw never falls below HYBRID_bw at same ratio. - self.assertGreaterEqual(x_eff, 0.5) + self.assertAlmostEqual(x_eff, 0.1) if __name__ == "__main__": From 5280156cf0f44b371629cacdfc09cc745f13328c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 13 May 2026 16:39:35 +0800 Subject: [PATCH 10/15] [fix] drop duplicate empirical-fit comment above the constants The same provenance is already covered in ``_dynamicemb_effective_cache_ratio``'s docstring below; the block comment above the constants was redundant. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/dynamicemb_util.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tzrec/utils/dynamicemb_util.py b/tzrec/utils/dynamicemb_util.py index 9e68afff..fd9db3b6 100644 --- a/tzrec/utils/dynamicemb_util.py +++ b/tzrec/utils/dynamicemb_util.py @@ -48,19 +48,6 @@ from tzrec.protos import feature_pb2 -# Empirical x_eff constants fitted from an on-device dynamicemb sweep -# (4M-row table, dim=128, adam, pow-law alpha=1.05, A10 GPU; see -# experiments/sweep_20260513-161030/full_a10gpu1.json). Median fwd+bwd -# latency clustered into three regimes: -# * HYBRID @ x=1.0: 0.80 ms (HBM-only fast path; runtime drops the -# host tier when total_value_memory <= local_hbm) -# * CACHING @ x<1.0: 2.63 ms (~3.3x slower than HBM-only) -# * HYBRID @ x<1.0: 5.44 ms (~6.8x slower than HBM-only) -# Within each <1.0 block the ratio dependence is noise-dominated. -# Inverting the linear bw model bw = x_eff*HBM + (1-x_eff)*HBM_TO_DDR -# (torchrec defaults: HBM=897 GB/s, HBM_TO_DDR=32 GB/s) yields the -# constants below. 
The +0.01*x term is a tiebreaker so the DP can produce -# strictly ordered proposals within each block. _DYNAMICEMB_CACHING_X_EFF_BASE = 0.28 _DYNAMICEMB_HYBRID_X_EFF_BASE = 0.11 _DYNAMICEMB_X_EFF_TIEBREAK = 0.01 From 625e1b77c7f426278857767e5a5f8e1dabd02131 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 13 May 2026 17:02:40 +0800 Subject: [PATCH 11/15] [fix] dynamicemb planner: drop redundant deepcopy in variant emission MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `copy.deepcopy(base_option)` on the first line of each variant body already copies `dynamicemb_options` because the caller in `enumerate()` attaches it onto `base_option` before invoking `_emit_dynamicemb_variants`. The second `copy.deepcopy(dynamicemb_options)` was wasted work — at 20 variants per dynamicemb table it doubled a non-trivial copy. Drop it and mutate the already-fresh per-variant `dynamicemb_options.caching` directly. PR #508 review R3 (github-actions, 2026-05-13). Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/plan_util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tzrec/utils/plan_util.py b/tzrec/utils/plan_util.py index 3e1e7449..dca6cdc0 100644 --- a/tzrec/utils/plan_util.py +++ b/tzrec/utils/plan_util.py @@ -847,9 +847,8 @@ def _emit_dynamicemb_variants( for load_factor in load_factors: opt = copy.deepcopy(base_option) opt.cache_params = CacheParams(load_factor=load_factor, stats=stats) - # pyre-ignore [16] - opt.dynamicemb_options = copy.deepcopy(dynamicemb_options) - opt.dynamicemb_options.caching = caching_mode + # deepcopy(base_option) already produced a fresh dynamicemb_options. 
+ opt.dynamicemb_options.caching = caching_mode # pyre-ignore [16] variants.append(opt) return variants From 3ff8be67ef47ad7476346773d1b3da2728fc50ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 13 May 2026 17:03:24 +0800 Subject: [PATCH 12/15] [fix] dynamicemb planner: cap 2D DP per-axis bins to avoid multi-host blowup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The 2D DP discretizes each memory axis into `bins_per_device * num_devices` bins. On a 16-GPU world that becomes 1600 hbm bins × 800 ddr bins = 1.28M cells per table; with 100 dynamicemb tables and 20 options each, the inner Python loop is a few-billion-op DP and dp + backtrack tuples consume hundreds of MB. The HBM and DDR budgets are scalar, so multiplying bins by num_devices stops adding meaningful precision past a point. Clamp the per-axis bin count at `_DP_AXIS_BIN_CAP` (1024) to keep the multi-host case tractable. NumPy vectorization of the inner DP loop is a bigger refactor; leaving a TODO referencing PR #508 review R5. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/plan_util.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tzrec/utils/plan_util.py b/tzrec/utils/plan_util.py index dca6cdc0..b2bfc620 100644 --- a/tzrec/utils/plan_util.py +++ b/tzrec/utils/plan_util.py @@ -225,6 +225,13 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]: _INF = float("inf") +# Cap per-axis bin count for the 2D DP. Without this, `bins_per_device * +# num_devices` blows up on multi-host worlds (e.g. 16 GPUs at 100 hbm +# bins/dev = 1600 bins, * 800 ddr bins = 1.28M cells per table). The +# budgets are scalar (total HBM, total DDR), so multiplying bins by +# num_devices stops adding precision past this cap. NumPy vectorization +# of the inner loop is a follow-up (PR #508 review R5). 
+_DP_AXIS_BIN_CAP = 1024 class DynamicProgrammingProposer(Proposer): @@ -360,8 +367,12 @@ def feedback( ) max_machine_ddr = max(max_machine_ddr, machine_ddr) - hbm_bins = max(self._hbm_bins_per_device * num_devices, 1) - ddr_bins = max(self._ddr_bins_per_device * num_devices, 1) + hbm_bins = max( + min(self._hbm_bins_per_device * num_devices, _DP_AXIS_BIN_CAP), 1 + ) + ddr_bins = max( + min(self._ddr_bins_per_device * num_devices, _DP_AXIS_BIN_CAP), 1 + ) # Collapse a degenerate axis to a single bin so we don't waste states # on (e.g.) CPU-only topologies that have hbm == 0 everywhere. if hbm_total == 0: From 984d0df3ff6cca26dc894b6d44b6d645de12c4cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 13 May 2026 17:04:36 +0800 Subject: [PATCH 13/15] [doc] dynamicemb planner: tighten three docstrings to match implementation * DynamicProgrammingProposer class docstring: "per-rank DDR" -> "per-rank HBM and per-machine DDR" so it matches the machine-DDR prune introduced in commit 38bd7ad (PR #508 review R4). * _dynamicemb_effective_cache_ratio: call out that the discontinuity at x=1.0 is a deliberate kernel switch (HBM-only DynamicEmbStorage vs dual-tier HybridStorage/DynamicEmbCache), not a smoothing artifact. Note a future refactor could lift this to a discrete `mode` axis on the enumerator (PR #508 review R2). * dynamicemb_calculate_shard_storages caching_ratio Args: rewrite the stale "ratio of HBM to DDR memory for UVM caching" framing -- under the dynamicemb modes, host always holds the full backing in CACHING, and caching_ratio sizes the HBM hot-row cache (top-level docs note). 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/dynamicemb_util.py | 16 +++++++++++++--- tzrec/utils/plan_util.py | 7 +++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tzrec/utils/dynamicemb_util.py b/tzrec/utils/dynamicemb_util.py index fd9db3b6..f1716a49 100644 --- a/tzrec/utils/dynamicemb_util.py +++ b/tzrec/utils/dynamicemb_util.py @@ -66,8 +66,15 @@ def _dynamicemb_effective_cache_ratio( The ratio is derived from an on-device perf sweep, not a heuristic. Empirical pattern (alpha=1.05 pow-law access on A10): - * ``x == 1.0``: the runtime drops the host tier (HBM-only); both - modes hit the fastest path. Return ``1.0``. + * ``x == 1.0``: the runtime *switches kernels* — when + ``total_value_memory <= local_hbm_for_values`` the dual-tier + ``HybridStorage`` / ``DynamicEmbCache`` paths are dropped in favor + of the HBM-only ``DynamicEmbStorage`` kernel + (``batched_dynamicemb_tables.py:502-604``). The ~8x jump in ``x_eff`` + between ``x=0.9`` and ``x=1.0`` is intentional and matches measured + latency, not a smoothing artifact. (A future refactor could lift + this to a discrete ``mode={HBM_ONLY, HYBRID, CACHING}`` axis on the + enumerator side rather than packing the discontinuity into ``x``.) * ``caching=True``, ``x < 1.0``: 3.3x slower than HBM-only -> base 0.28. * ``caching=False``, ``x < 1.0``: 6.8x slower than HBM-only -> base 0.11. @@ -617,7 +624,10 @@ def dynamicemb_calculate_shard_storages( factors. num_poolings (List[float]): average number of poolings per sample (typically 1.0). - caching_ratio (float): ratio of HBM to DDR memory for UVM caching. + caching_ratio (float): cache_load_factor for the dynamicemb table. + In HYBRID mode HBM holds this fraction of values and host + holds the remainder; in CACHING mode HBM is a hot-row cache + of this fraction and host holds the full backing store. is_pooled (bool): True if embedding output is pooled (ie. `EmbeddingBag`), False if unpooled/sequential (ie. `Embedding`). 
input_data_type_size (int): number of bytes of input data type. diff --git a/tzrec/utils/plan_util.py b/tzrec/utils/plan_util.py index b2bfc620..ddbcf5dd 100644 --- a/tzrec/utils/plan_util.py +++ b/tzrec/utils/plan_util.py @@ -238,8 +238,11 @@ class DynamicProgrammingProposer(Proposer): r"""Proposes sharding plans via 2D (HBM × DDR) dynamic programming. Given :math:`M` tables each with up to :math:`N` ShardingOptions, pick one - option per table to minimize total perf while respecting both per-rank HBM - and per-rank DDR budgets from the topology. + option per table to minimize total perf while respecting both per-rank + HBM and per-machine DDR budgets from the topology. (HBM is GPU-local, so + each device gets its own quota; DDR is host-shared across ranks + co-located on the same machine, so the prune threshold is the sum over + a ``local_world_size``-sized window.) Each axis (HBM, DDR) is discretized into bins; ``dp[table][h][d]`` holds the minimum perf over the first ``table`` tables using ``≈ h`` HBM bins From 08563bb531ff1e583d26a9f9eff0feb0d347e2f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 13 May 2026 17:06:19 +0800 Subject: [PATCH 14/15] [test] dynamicemb planner: multi-table DP + legacy alias + empty search space Add three DynamicProgrammingProposerTest methods that the original PR #508 review (R6) flagged as missing: test_two_tables_pick_mixed_modes_under_joint_budget Two tables under a topology where both-HYBRID is HBM-infeasible and both-CACHING is DDR-infeasible. The DP must pick one of each. This exercises the cross-table transition at plan_util.py table_i == 1 which the existing single-table cases never reach. test_legacy_mem_bins_per_device_alias Verifies the back-compat kwarg `mem_bins_per_device` still wins over the new `hbm_bins_per_device` when both are passed. test_empty_search_space_returns_empty_proposal Edge case where `load([])` short-circuits via `table_count == 0`. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/plan_util_test.py | 52 +++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/tzrec/utils/plan_util_test.py b/tzrec/utils/plan_util_test.py index b1cac92e..2c578876 100644 --- a/tzrec/utils/plan_util_test.py +++ b/tzrec/utils/plan_util_test.py @@ -308,6 +308,58 @@ def test_high_factor_collapses_modes(self): for p in proposals: self.assertEqual(p[0].total_perf, 50.0) + def test_two_tables_pick_mixed_modes_under_joint_budget(self): + # Two tables, each with HYBRID@1.0 (all-HBM, no DDR) and CACHING@0.1 + # (small HBM, full-T DDR). Topology HBM=2000 admits exactly one + # full HYBRID + one small CACHING shard (1500+100), and host DDR + # admits exactly one full-T CACHING backing (1500). Both-HYBRID is + # HBM-infeasible (3000>2000), both-CACHING is DDR-infeasible + # (3000>2000). Only the mixed plan fits. Exercises the + # cross-table DP transition at plan_util.py table_i==1. + opts = [ + _FakeShardingOption("table_a", hbm=1500, ddr=0, perf=50.0), # HYBRID@1.0 + _FakeShardingOption("table_a", hbm=100, ddr=1500, perf=40.0), # CACHING@0.1 + _FakeShardingOption("table_b", hbm=1500, ddr=0, perf=50.0), # HYBRID@1.0 + _FakeShardingOption("table_b", hbm=100, ddr=1500, perf=40.0), # CACHING@0.1 + ] + topology = _make_topology( + num_devices=1, hbm_per_device=2000, ddr_per_device=2000 + ) + proposals = self._run(opts, topology) + self.assertGreater(len(proposals), 0) + best = min(proposals, key=lambda p: sum(o.total_perf for o in p)) + styles = sorted( + "hybrid" if o.shards[0].storage.ddr == 0 else "caching" for o in best + ) + self.assertEqual(styles, ["caching", "hybrid"]) + + def test_legacy_mem_bins_per_device_alias(self): + # The legacy kwarg `mem_bins_per_device` is preserved as an alias + # for `hbm_bins_per_device`. 
+ proposer = DynamicProgrammingProposer(mem_bins_per_device=77) + self.assertEqual(proposer._hbm_bins_per_device, 77) + # The legacy mem_bins_per_device takes precedence: when both are + # provided, mem_bins_per_device wins (it's the legacy override). + proposer = DynamicProgrammingProposer( + mem_bins_per_device=77, hbm_bins_per_device=99 + ) + self.assertEqual(proposer._hbm_bins_per_device, 77) + + def test_empty_search_space_returns_empty_proposal(self): + proposer = DynamicProgrammingProposer() + proposer.load([]) + # Seed proposal is the per-table first option; with no tables the + # list is empty. + self.assertEqual(proposer.propose(), []) + # Feedback must not raise on an empty proposer (it short-circuits + # via `table_count == 0`). + topology = _make_topology( + num_devices=1, hbm_per_device=1000, ddr_per_device=1000 + ) + proposer.feedback(partitionable=True, storage_constraint=topology) + # After feedback, no proposals are available. + self.assertIsNone(proposer.propose()) + @unittest.skipUnless(has_dynamicemb, "dynamicemb is not installed; skipping.") @unittest.skipUnless(torch.cuda.is_available(), "CUDA is required for dynamicemb.") From 0c0d3d7b15a0c67f07b4a59bde05c53122d019a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Wed, 13 May 2026 17:09:47 +0800 Subject: [PATCH 15/15] [fix] dynamicemb planner: restore cache_params in finally + direct wrapper tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Combine PR #508 review R1 + R7. The new ``BuildShardPerfContextsWrapperTest.test_cache_params_restored_on_exception`` exists to pin the bug R1 describes: when ``_orig_build_shard_perf_contexts`` raises, the boosted ``x_eff`` clone of ``sharding_option.cache_params`` must not leak to the downstream storage estimator's view of the same option. The corresponding ``try/finally`` around the wrapped call closes the leak. Reverts the previous "single return, no try/finally" simplification.
Also adds: * ``BuildShardPerfContextsWrapperTest`` — direct unit tests for the ShardPerfContext.build_shard_perf_contexts monkey-patch covering: boost applied for caching, boost applied for hybrid, no boost when the option has no dynamicemb_options, cache_params restored on success, and the restore-on-exception case driving R1. * ``DynamicEmbCalcShardStoragesTest.test_admission_counter_increases_hbm`` — direct cover for the admission_counter HBM accounting branch in ``dynamicemb_calculate_shard_storages`` that the CUDA-gated E2E test doesn't exercise on a CPU dev box. Co-Authored-By: Claude Opus 4.7 (1M context) --- tzrec/utils/dynamicemb_util.py | 29 ++-- tzrec/utils/dynamicemb_util_test.py | 199 ++++++++++++++++++++++++++++ 2 files changed, 216 insertions(+), 12 deletions(-) diff --git a/tzrec/utils/dynamicemb_util.py b/tzrec/utils/dynamicemb_util.py index f1716a49..bdfebc00 100644 --- a/tzrec/utils/dynamicemb_util.py +++ b/tzrec/utils/dynamicemb_util.py @@ -519,18 +519,23 @@ def _dynamicemb_aware_build_shard_perf_contexts( if original_cache_params is not None else CacheParams(load_factor=x_eff) ) - result = _orig_build_shard_perf_contexts( - cls, - config, - shard_sizes, - sharding_option, - topology, - constraints, - sharder, - *args, - **kwargs, - ) - sharding_option.cache_params = original_cache_params + # try/finally so an estimator exception cannot leak the boosted + # cache_params clone into the storage estimator's view of the + # same ShardingOption. 
+ try: + result = _orig_build_shard_perf_contexts( + cls, + config, + shard_sizes, + sharding_option, + topology, + constraints, + sharder, + *args, + **kwargs, + ) + finally: + sharding_option.cache_params = original_cache_params return result # pyre-ignore [9] diff --git a/tzrec/utils/dynamicemb_util_test.py b/tzrec/utils/dynamicemb_util_test.py index d086a30f..1d2d165c 100644 --- a/tzrec/utils/dynamicemb_util_test.py +++ b/tzrec/utils/dynamicemb_util_test.py @@ -10,8 +10,12 @@ # limitations under the License. import unittest +from types import SimpleNamespace +from unittest import mock +import torch from parameterized import parameterized +from torchrec.distributed.types import CacheParams from tzrec.utils import dynamicemb_util @@ -196,5 +200,200 @@ def expected_miss_rate(self, ratio): self.assertAlmostEqual(x_eff, 0.1) +@unittest.skipUnless( + dynamicemb_util.has_dynamicemb, "dynamicemb is not installed; skipping." +) +class BuildShardPerfContextsWrapperTest(unittest.TestCase): + """``_dynamicemb_aware_build_shard_perf_contexts`` swap + restore. + + Mocks ``_orig_build_shard_perf_contexts`` so we can drive the wrapper + without the heavy upstream ShardPerfContext machinery, and verify the + boost is applied, the cache_params is restored after the call, and the + restore still runs when the wrapped call raises. + """ + + def _call(self, sharding_option): + # ShardPerfContext.build_shard_perf_contexts is the patched + # classmethod after dynamicemb_util import. The classmethod + # descriptor auto-injects ``cls``, so we pass 6 positional args + # (config, shard_sizes, sharding_option, topology, constraints, + # sharder). 
+ from torchrec.distributed.planner.estimator.types import ShardPerfContext + + return ShardPerfContext.build_shard_perf_contexts( + None, None, sharding_option, None, None, None + ) + + def _spy_recording_cache_params(self, recorder): + def _spy(cls, config, shard_sizes, sharding_option, *args, **kwargs): + recorder.append(sharding_option.cache_params.load_factor) + return [] + + return _spy + + def test_boost_applied_for_caching(self): + seen = [] + with mock.patch.object( + dynamicemb_util, + "_orig_build_shard_perf_contexts", + self._spy_recording_cache_params(seen), + ): + so = SimpleNamespace( + dynamicemb_options=SimpleNamespace(caching=True), + cache_params=CacheParams(load_factor=0.5), + cache_load_factor=0.5, + ) + self._call(so) + expected = dynamicemb_util._dynamicemb_effective_cache_ratio(0.5, caching=True) + self.assertEqual(len(seen), 1) + self.assertAlmostEqual(seen[0], expected) + + def test_boost_applied_for_hybrid(self): + seen = [] + with mock.patch.object( + dynamicemb_util, + "_orig_build_shard_perf_contexts", + self._spy_recording_cache_params(seen), + ): + so = SimpleNamespace( + dynamicemb_options=SimpleNamespace(caching=False), + cache_params=CacheParams(load_factor=0.5), + cache_load_factor=0.5, + ) + self._call(so) + expected = dynamicemb_util._dynamicemb_effective_cache_ratio(0.5, caching=False) + self.assertAlmostEqual(seen[0], expected) + + def test_no_boost_when_no_dynamicemb_options(self): + # Non-dynamicemb ShardingOption has no `dynamicemb_options` + # attribute; the wrapper must leave cache_params untouched. 
+ seen = [] + with mock.patch.object( + dynamicemb_util, + "_orig_build_shard_perf_contexts", + self._spy_recording_cache_params(seen), + ): + so = SimpleNamespace( + cache_params=CacheParams(load_factor=0.7), + cache_load_factor=0.7, + ) + self._call(so) + self.assertEqual(seen[0], 0.7) + + def test_cache_params_restored_on_success(self): + original = CacheParams(load_factor=0.3) + with mock.patch.object( + dynamicemb_util, + "_orig_build_shard_perf_contexts", + lambda cls, c, s, so, *a, **kw: [], + ): + so = SimpleNamespace( + dynamicemb_options=SimpleNamespace(caching=True), + cache_params=original, + cache_load_factor=0.3, + ) + self._call(so) + # Same identity, not just same load_factor. + self.assertIs(so.cache_params, original) + + def test_cache_params_restored_on_exception(self): + """Restore must run even when the wrapped call raises (R1). + + Without try/finally this test fails -- the boosted cache_params + leaks out and corrupts downstream consumers. + """ + original = CacheParams(load_factor=0.3) + + class _Boom(RuntimeError): + pass + + def raiser(cls, c, s, so, *a, **kw): + raise _Boom("simulated estimator failure") + + with mock.patch.object( + dynamicemb_util, "_orig_build_shard_perf_contexts", raiser + ): + so = SimpleNamespace( + dynamicemb_options=SimpleNamespace(caching=True), + cache_params=original, + cache_load_factor=0.3, + ) + with self.assertRaises(_Boom): + self._call(so) + self.assertIs(so.cache_params, original) + + +@unittest.skipUnless( + dynamicemb_util.has_dynamicemb, "dynamicemb is not installed; skipping." 
+) +class DynamicEmbCalcShardStoragesTest(unittest.TestCase): + """Direct test of ``dynamicemb_calculate_shard_storages``.""" + + def _build_options(self, *, with_admission_counter=False): + import dynamicemb + + kwargs = dict( + max_capacity=1024, + initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.UNIFORM, + lower=-0.01, + upper=0.01, + ), + eval_initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.CONSTANT, value=0.0 + ), + score_strategy=dynamicemb.DynamicEmbScoreStrategy.STEP, + ) + if with_admission_counter: + kwargs["admission_counter"] = dynamicemb.KVCounter( + capacity=1024, bucket_capacity=128 + ) + kwargs["admit_strategy"] = dynamicemb.FrequencyAdmissionStrategy( + threshold=1, + initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.CONSTANT, value=0.0 + ), + ) + return dynamicemb.DynamicEmbTableOptions(**kwargs) + + def _base_kwargs( + self, dynamicemb_options, *, compute_device="cuda", is_inference=False + ): + from torchrec.distributed.embedding_types import EmbeddingComputeKernel + from torchrec.distributed.types import ShardingType + + return dict( + sharder=None, + sharding_type=ShardingType.ROW_WISE.value, + tensor=torch.empty(1024, 64, dtype=torch.float32), + compute_device=compute_device, + compute_kernel=EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value, + shard_sizes=[[512, 64], [512, 64]], + batch_sizes=[32, 32], + world_size=2, + local_world_size=2, + input_lengths=[1.0], + num_poolings=[1.0], + caching_ratio=0.5, + is_pooled=True, + input_data_type_size=8.0, + output_data_type_size=4.0, + is_inference=is_inference, + dynamicemb_options=dynamicemb_options, + ) + + def test_admission_counter_increases_hbm(self): + baseline = dynamicemb_util.dynamicemb_calculate_shard_storages( + **self._base_kwargs(self._build_options(with_admission_counter=False)) + ) + with_counter = 
dynamicemb_util.dynamicemb_calculate_shard_storages( + **self._base_kwargs(self._build_options(with_admission_counter=True)) + ) + # Counter is HBM-side only — DDR matches, HBM grows. + for base, w in zip(baseline, with_counter): + self.assertGreater(w.hbm, base.hbm) + self.assertEqual(w.ddr, base.ddr) + + if __name__ == "__main__": unittest.main()