diff --git a/tzrec/utils/dynamicemb_util.py b/tzrec/utils/dynamicemb_util.py index 09e9bf2b..bdfebc00 100644 --- a/tzrec/utils/dynamicemb_util.py +++ b/tzrec/utils/dynamicemb_util.py @@ -9,9 +9,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import math import os -from typing import List, Optional, Tuple, Type, cast +from typing import Any, List, Optional, Tuple, Type, cast import torch from torch import nn @@ -21,7 +22,10 @@ planners, shard_estimators, ) -from torchrec.distributed.planner.estimator.types import HardwarePerfConfig +from torchrec.distributed.planner.estimator.types import ( + HardwarePerfConfig, + ShardPerfContext, +) from torchrec.distributed.planner.types import ( ParameterConstraints, ShardingOption, @@ -44,6 +48,54 @@ from tzrec.protos import feature_pb2 +_DYNAMICEMB_CACHING_X_EFF_BASE = 0.28 +_DYNAMICEMB_HYBRID_X_EFF_BASE = 0.11 +_DYNAMICEMB_X_EFF_TIEBREAK = 0.01 + + +def _dynamicemb_effective_cache_ratio( + cache_load_factor: Optional[float], + caching: bool, + stats: Optional[Any] = None, +) -> float: + """Effective HBM-hit ratio for the dynamicemb perf model. + + Returns the value passed to torchrec's perf bandwidth formula + ``bw = x_eff*hbm + (1-x_eff)*hbm_to_ddr_bw``. Larger value = faster path. + + The ratio is derived from an on-device perf sweep, not a heuristic. + Empirical pattern (alpha=1.05 pow-law access on A10): + + * ``x == 1.0``: the runtime *switches kernels* — when + ``total_value_memory <= local_hbm_for_values`` the dual-tier + ``HybridStorage`` / ``DynamicEmbCache`` paths are dropped in favor + of the HBM-only ``DynamicEmbStorage`` kernel + (``batched_dynamicemb_tables.py:502-604``). The ~8x jump in ``x_eff`` + between ``x=0.9`` and ``x=1.0`` is intentional and matches measured + latency, not a smoothing artifact. (A future refactor could lift + this to a discrete ``mode={HBM_ONLY, HYBRID, CACHING}`` axis on the + enumerator side rather than packing the discontinuity into ``x``.) + * ``caching=True``, ``x < 1.0``: 3.3x slower than HBM-only -> base 0.28. + * ``caching=False``, ``x < 1.0``: 6.8x slower than HBM-only -> base 0.11. + + Within each ``x < 1.0`` block the perf is roughly flat in ratio, but we + add a tiny monotonic perturbation so the DP can break ties. + + If ``stats`` is provided, ``1 - stats.expected_miss_rate(x)`` overrides + the heuristic verbatim (clamped to ``[0, 1]``); the caller opts in to + their own measurement. + """ + x = float(cache_load_factor) if cache_load_factor is not None else 0.0 + x = max(0.0, min(1.0, x)) + if stats is not None: + miss_rate = float(stats.expected_miss_rate(x)) + return max(0.0, min(1.0, 1.0 - miss_rate)) + if x >= 1.0: + return 1.0 + base = _DYNAMICEMB_CACHING_X_EFF_BASE if caching else _DYNAMICEMB_HYBRID_X_EFF_BASE + return base + _DYNAMICEMB_X_EFF_TIEBREAK * x + + has_dynamicemb = False try: import dynamicemb @@ -258,18 +310,36 @@ def _calculate_dynamicemb_table_storage_specific_size( is_hbm: bool = True, only_values: bool = False, bucket_capacity: int = 128, + caching: bool = False, ) -> int: - """Calculate dynamic embedding table storage. - - total_value_memory = max_capacity x aligned16(embedding+optimizer states) - num_buckets = max_capacity/bucket_capacity - hbm_budget = min(global_hbm_for_values//world_size, total_value_memory) + - max_capacity x (key<8byte> + score<8byte> + digest<1byte>) + - num_buckets x (bucket_size<4byte>) - ddr_budget = max(total_value_memory - global_hbm_for_values//world_size, 0) + """Per-shard storage size for a dynamicemb table -- HBM or DDR (bytes). + + Byte budget (single shard, rows x dim): + + value_bytes_per_row = round_up16(dim * (1 + opt_mult) * element) + total_value_memory = align(rows) * value_bytes_per_row + num_buckets = align(rows) / bucket_capacity + + hbm_budget = cache_ratio * total_value_memory # values + + align(rows) * (key<8B> + score<8B> + digest<1B>) # per-row + + num_buckets * bucket_header<4B> # per-bucket + + ddr_budget = HYBRID (caching=False): (1 - cache_ratio) * total_value_memory + CACHING (caching=True): total_value_memory # full backing + + HYBRID hash-partitions values across HBM and host; ``cache_ratio`` is + HBM's value share. CACHING keeps the full backing store on host and + uses HBM as a hot-row cache of size + ``cache_ratio * total_value_memory``. Hash-table metadata + (key + score + digest + bucket header) is accounted on HBM only -- + matches the existing tzrec convention. """ if cache_ratio is None: cache_ratio = 1.0 + if is_hbm: + value_ratio = cache_ratio + else: + value_ratio = 1.0 if caching else (1.0 - cache_ratio) return math.ceil( align_to_table_size(size[0]) * ( @@ -277,7 +347,7 @@ def _calculate_dynamicemb_table_storage_specific_size( math.ceil(size[1] * (1 + optimizer_multipler) * element_size), 16, ) - * (cache_ratio if is_hbm else 1 - cache_ratio) + * value_ratio + (8 + 8 + 1 + 4 / bucket_capacity) * (is_hbm and not only_values) ) ) @@ -413,6 +483,66 @@ def _customized_kernel_aware_get_device_bw( # pyre-ignore [9] HardwarePerfConfig.get_device_bw = _customized_kernel_aware_get_device_bw + _orig_build_shard_perf_contexts = ( + ShardPerfContext.build_shard_perf_contexts.__func__ + ) + + def _dynamicemb_aware_build_shard_perf_contexts( + cls, # pyre-ignore [2] + config, # pyre-ignore [2] + shard_sizes, # pyre-ignore [2] + sharding_option, # pyre-ignore [2] + topology, # pyre-ignore [2] + constraints, # pyre-ignore [2] + sharder, # pyre-ignore [2] + *args, # pyre-ignore [2] + **kwargs, # pyre-ignore [2] + ): + """Inject the empirical x_eff into the perf estimator for both modes. + + Temporarily replace ``sharding_option.cache_params`` with a clone + whose ``load_factor`` is the empirically-fitted x_eff for the + (mode, cache_load_factor) combination. Restored before returning so + the (separately invoked) storage estimator still sees the un-boosted + ratio. + """ + dynamicemb_options = getattr(sharding_option, "dynamicemb_options", None) + original_cache_params = sharding_option.cache_params + if dynamicemb_options is not None: + caching = bool(getattr(dynamicemb_options, "caching", False)) + stats = original_cache_params.stats if original_cache_params else None + x_eff = _dynamicemb_effective_cache_ratio( + sharding_option.cache_load_factor, caching=caching, stats=stats + ) + sharding_option.cache_params = ( + dataclasses.replace(original_cache_params, load_factor=x_eff) + if original_cache_params is not None + else CacheParams(load_factor=x_eff) + ) + # try/finally so an estimator exception cannot leak the boosted + # cache_params clone into the storage estimator's view of the + # same ShardingOption. + try: + result = _orig_build_shard_perf_contexts( + cls, + config, + shard_sizes, + sharding_option, + topology, + constraints, + sharder, + *args, + **kwargs, + ) + finally: + sharding_option.cache_params = original_cache_params + return result + + # pyre-ignore [9] + ShardPerfContext.build_shard_perf_contexts = classmethod( + _dynamicemb_aware_build_shard_perf_contexts + ) + def _calculate_dynamicemb_storage_specific_sizes( tensor: torch.Tensor, shard_sizes: List[List[int]], @@ -420,6 +550,7 @@ def _calculate_dynamicemb_storage_specific_sizes( cache_ratio: float = 1.0, is_inference: bool = False, bucket_capacity: int = 128, + caching: bool = False, ) -> Tuple[List[int], List[int]]: """Calculate storage for dynamicemb.""" optimizer_multipler = 0.0 @@ -437,6 +568,7 @@ def _calculate_dynamicemb_storage_specific_sizes( cache_ratio, is_hbm=True, bucket_capacity=bucket_capacity, + caching=caching, ) for size in shard_sizes ] @@ -449,6 +581,7 @@ def _calculate_dynamicemb_storage_specific_sizes( cache_ratio, is_hbm=False, bucket_capacity=bucket_capacity, + caching=caching, ) for size in shard_sizes ] @@ -496,7 +629,10 @@ def dynamicemb_calculate_shard_storages( factors. num_poolings (List[float]): average number of poolings per sample (typically 1.0). - caching_ratio (float): ratio of HBM to DDR memory for UVM caching. + caching_ratio (float): cache_load_factor for the dynamicemb table. + In HYBRID mode HBM holds this fraction of values and host + holds the remainder; in CACHING mode HBM is a hot-row cache + of this fraction and host holds the full backing store. is_pooled (bool): True if embedding output is pooled (ie. `EmbeddingBag`), False if unpooled/sequential (ie. `Embedding`). input_data_type_size (int): number of bytes of input data type. @@ -535,6 +671,7 @@ def dynamicemb_calculate_shard_storages( cache_ratio=caching_ratio if caching_ratio else 1.0, is_inference=is_inference, bucket_capacity=dynamicemb_options.bucket_capacity, + caching=bool(getattr(dynamicemb_options, "caching", False)), ) ) counter_hbm_specific_size = 0 diff --git a/tzrec/utils/dynamicemb_util_test.py b/tzrec/utils/dynamicemb_util_test.py new file mode 100644 index 00000000..1d2d165c --- /dev/null +++ b/tzrec/utils/dynamicemb_util_test.py @@ -0,0 +1,399 @@ +# Copyright (c) 2026, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from types import SimpleNamespace +from unittest import mock + +import torch +from parameterized import parameterized +from torchrec.distributed.types import CacheParams + +from tzrec.utils import dynamicemb_util + + +@unittest.skipUnless( + dynamicemb_util.has_dynamicemb, "dynamicemb is not installed; skipping." +) +class StorageFormulaTest(unittest.TestCase): + """Mode-aware ``_calculate_dynamicemb_table_storage_specific_size``.""" + + ROWS = 1024 + DIM = 64 + ELEMENT_SIZE = 4 + BUCKET_CAPACITY = 128 + + def _calc(self, *, cache_ratio, is_hbm, caching, only_values=False): + return dynamicemb_util._calculate_dynamicemb_table_storage_specific_size( + size=[self.ROWS, self.DIM], + element_size=self.ELEMENT_SIZE, + cache_ratio=cache_ratio, + is_hbm=is_hbm, + only_values=only_values, + bucket_capacity=self.BUCKET_CAPACITY, + caching=caching, + ) + + @parameterized.expand( + [ + ("ratio_0_0", 0.0), + ("ratio_0_25", 0.25), + ("ratio_0_5", 0.5), + ("ratio_0_75", 0.75), + ("ratio_1_0", 1.0), + ] + ) + def test_hbm_identical_between_modes(self, _name, cache_ratio): + # HBM accounting is the same in HYBRID and CACHING: HBM holds a + # cache_ratio fraction of values plus full-row-count metadata. + hybrid_hbm = self._calc(cache_ratio=cache_ratio, is_hbm=True, caching=False) + caching_hbm = self._calc(cache_ratio=cache_ratio, is_hbm=True, caching=True) + self.assertEqual(hybrid_hbm, caching_hbm) + + @parameterized.expand( + [ + ("ratio_0_0", 0.0), + ("ratio_0_25", 0.25), + ("ratio_0_5", 0.5), + ("ratio_0_75", 0.75), + ("ratio_1_0", 1.0), + ] + ) + def test_ddr_hybrid_complements_cache(self, _name, cache_ratio): + # HYBRID DDR = (1 - cache_ratio) * full-table DDR. + full_ddr = self._calc(cache_ratio=0.0, is_hbm=False, caching=False) + hybrid_ddr = self._calc(cache_ratio=cache_ratio, is_hbm=False, caching=False) + self.assertEqual(hybrid_ddr, round((1.0 - cache_ratio) * full_ddr)) + + @parameterized.expand( + [ + ("ratio_0_0", 0.0), + ("ratio_0_25", 0.25), + ("ratio_0_5", 0.5), + ("ratio_0_75", 0.75), + ("ratio_1_0", 1.0), + ] + ) + def test_ddr_caching_holds_full_table(self, _name, cache_ratio): + # CACHING DDR is the full backing store, independent of cache_ratio. + full_ddr = self._calc(cache_ratio=0.0, is_hbm=False, caching=False) + caching_ddr = self._calc(cache_ratio=cache_ratio, is_hbm=False, caching=True) + self.assertEqual(caching_ddr, full_ddr) + + def test_caching_ddr_strictly_greater_than_hybrid_when_cached(self): + for cache_ratio in (0.1, 0.5, 0.9): + hybrid_ddr = self._calc( + cache_ratio=cache_ratio, is_hbm=False, caching=False + ) + caching_ddr = self._calc( + cache_ratio=cache_ratio, is_hbm=False, caching=True + ) + self.assertGreater(caching_ddr, hybrid_ddr) + + def test_only_values_drops_metadata(self): + # only_values=True strips HBM metadata regardless of mode. + for caching in (False, True): + with_meta = self._calc( + cache_ratio=0.5, is_hbm=True, caching=caching, only_values=False + ) + without_meta = self._calc( + cache_ratio=0.5, is_hbm=True, caching=caching, only_values=True + ) + self.assertGreater(with_meta, without_meta) + + +class EffectiveCacheRatioTest(unittest.TestCase): + """``_dynamicemb_effective_cache_ratio`` -- empirical fit from on-device sweep. + + The formula is fitted to the medians of an A10 benchmark sweep recorded + in experiments/sweep_20260513-161030/full_a10gpu1.json. Reproducing the + three-regime empirical pattern: + * x == 1.0: HBM-only fast path, x_eff = 1.0 + * caching=True, x<1.0: x_eff base 0.28 + * caching=False, x<1.0: x_eff base 0.11 + """ + + def test_x_eq_1_returns_1_in_both_modes(self): + # HBM-only fast path: total fits, runtime drops the host tier. + self.assertEqual( + dynamicemb_util._dynamicemb_effective_cache_ratio(1.0, caching=False), 1.0 + ) + self.assertEqual( + dynamicemb_util._dynamicemb_effective_cache_ratio(1.0, caching=True), 1.0 + ) + + def test_caching_above_hybrid_for_x_less_than_1(self): + # Empirically CACHING is ~2x faster than HYBRID at the same ratio. + for x in (0.1, 0.3, 0.5, 0.7, 0.9): + hybrid = dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=False) + caching = dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=True) + self.assertGreater(caching, hybrid) + + def test_monotonic_within_block(self): + # Within each mode, higher cache_load_factor -> higher x_eff (the + # +0.01*x tiebreaker term). + for caching in (False, True): + prev = None + for i in range(1, 10): # x = 0.1 .. 0.9 + x = i / 10 + cur = dynamicemb_util._dynamicemb_effective_cache_ratio( + x, caching=caching + ) + if prev is not None: + self.assertGreater(cur, prev) + prev = cur + + def test_strict_block_ranking_matches_empirical(self): + # HYBRID@1.0 > CACHING@anything < 1.0 > HYBRID@anything < 1.0 + ratios = [i / 10 for i in range(1, 10)] # 0.1 .. 0.9 + hybrid_at_1 = dynamicemb_util._dynamicemb_effective_cache_ratio( + 1.0, caching=False + ) + caching_block = { + x: dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=True) + for x in ratios + } + hybrid_block = { + x: dynamicemb_util._dynamicemb_effective_cache_ratio(x, caching=False) + for x in ratios + } + # Every CACHING@x<1.0 sits strictly below HYBRID@1.0. + for x_eff in caching_block.values(): + self.assertGreater(hybrid_at_1, x_eff) + # Every CACHING@x<1.0 sits strictly above every HYBRID@x<1.0. + for c in caching_block.values(): + for h in hybrid_block.values(): + self.assertGreater(c, h) + + def test_stats_override_uses_expected_miss_rate(self): + class _Stats: + expected_lookups = 1000.0 + + def expected_miss_rate(self, ratio): + return 0.05 # 95% hit rate regardless of ratio + + x_eff = dynamicemb_util._dynamicemb_effective_cache_ratio( + 0.2, caching=True, stats=_Stats() + ) + self.assertAlmostEqual(x_eff, 0.95) + + def test_stats_override_honored_verbatim(self): + # Stats reflect measured behavior; trust them even if they override + # the empirical heuristic's preferred ordering. + class _Stats: + expected_lookups = 1000.0 + + def expected_miss_rate(self, ratio): + return 0.9 # x_eff = 0.1 even though CACHING base = 0.28 + + x_eff = dynamicemb_util._dynamicemb_effective_cache_ratio( + 0.5, caching=True, stats=_Stats() + ) + self.assertAlmostEqual(x_eff, 0.1) + + +@unittest.skipUnless( + dynamicemb_util.has_dynamicemb, "dynamicemb is not installed; skipping." +) +class BuildShardPerfContextsWrapperTest(unittest.TestCase): + """``_dynamicemb_aware_build_shard_perf_contexts`` swap + restore. + + Mocks ``_orig_build_shard_perf_contexts`` so we can drive the wrapper + without the heavy upstream ShardPerfContext machinery, and verify the + boost is applied, the cache_params is restored after the call, and the + restore still runs when the wrapped call raises. + """ + + def _call(self, sharding_option): + # ShardPerfContext.build_shard_perf_contexts is the patched + # classmethod after dynamicemb_util import. The classmethod + # descriptor auto-injects ``cls``, so we pass 6 positional args + # (config, shard_sizes, sharding_option, topology, constraints, + # sharder). + from torchrec.distributed.planner.estimator.types import ShardPerfContext + + return ShardPerfContext.build_shard_perf_contexts( + None, None, sharding_option, None, None, None + ) + + def _spy_recording_cache_params(self, recorder): + def _spy(cls, config, shard_sizes, sharding_option, *args, **kwargs): + recorder.append(sharding_option.cache_params.load_factor) + return [] + + return _spy + + def test_boost_applied_for_caching(self): + seen = [] + with mock.patch.object( + dynamicemb_util, + "_orig_build_shard_perf_contexts", + self._spy_recording_cache_params(seen), + ): + so = SimpleNamespace( + dynamicemb_options=SimpleNamespace(caching=True), + cache_params=CacheParams(load_factor=0.5), + cache_load_factor=0.5, + ) + self._call(so) + expected = dynamicemb_util._dynamicemb_effective_cache_ratio(0.5, caching=True) + self.assertEqual(len(seen), 1) + self.assertAlmostEqual(seen[0], expected) + + def test_boost_applied_for_hybrid(self): + seen = [] + with mock.patch.object( + dynamicemb_util, + "_orig_build_shard_perf_contexts", + self._spy_recording_cache_params(seen), + ): + so = SimpleNamespace( + dynamicemb_options=SimpleNamespace(caching=False), + cache_params=CacheParams(load_factor=0.5), + cache_load_factor=0.5, + ) + self._call(so) + expected = dynamicemb_util._dynamicemb_effective_cache_ratio(0.5, caching=False) + self.assertAlmostEqual(seen[0], expected) + + def test_no_boost_when_no_dynamicemb_options(self): + # Non-dynamicemb ShardingOption has no `dynamicemb_options` + # attribute; the wrapper must leave cache_params untouched. + seen = [] + with mock.patch.object( + dynamicemb_util, + "_orig_build_shard_perf_contexts", + self._spy_recording_cache_params(seen), + ): + so = SimpleNamespace( + cache_params=CacheParams(load_factor=0.7), + cache_load_factor=0.7, + ) + self._call(so) + self.assertEqual(seen[0], 0.7) + + def test_cache_params_restored_on_success(self): + original = CacheParams(load_factor=0.3) + with mock.patch.object( + dynamicemb_util, + "_orig_build_shard_perf_contexts", + lambda cls, c, s, so, *a, **kw: [], + ): + so = SimpleNamespace( + dynamicemb_options=SimpleNamespace(caching=True), + cache_params=original, + cache_load_factor=0.3, + ) + self._call(so) + # Same identity, not just same load_factor. + self.assertIs(so.cache_params, original) + + def test_cache_params_restored_on_exception(self): + """Restore must run even when the wrapped call raises (R1). + + Without try/finally this test fails -- the boosted cache_params + leaks out and corrupts downstream consumers. + """ + original = CacheParams(load_factor=0.3) + + class _Boom(RuntimeError): + pass + + def raiser(cls, c, s, so, *a, **kw): + raise _Boom("simulated estimator failure") + + with mock.patch.object( + dynamicemb_util, "_orig_build_shard_perf_contexts", raiser + ): + so = SimpleNamespace( + dynamicemb_options=SimpleNamespace(caching=True), + cache_params=original, + cache_load_factor=0.3, + ) + with self.assertRaises(_Boom): + self._call(so) + self.assertIs(so.cache_params, original) + + +@unittest.skipUnless( + dynamicemb_util.has_dynamicemb, "dynamicemb is not installed; skipping." +) +class DynamicEmbCalcShardStoragesTest(unittest.TestCase): + """Direct test of ``dynamicemb_calculate_shard_storages``.""" + + def _build_options(self, *, with_admission_counter=False): + import dynamicemb + + kwargs = dict( + max_capacity=1024, + initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.UNIFORM, + lower=-0.01, + upper=0.01, + ), + eval_initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.CONSTANT, value=0.0 + ), + score_strategy=dynamicemb.DynamicEmbScoreStrategy.STEP, + ) + if with_admission_counter: + kwargs["admission_counter"] = dynamicemb.KVCounter( + capacity=1024, bucket_capacity=128 + ) + kwargs["admit_strategy"] = dynamicemb.FrequencyAdmissionStrategy( + threshold=1, + initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.CONSTANT, value=0.0 + ), + ) + return dynamicemb.DynamicEmbTableOptions(**kwargs) + + def _base_kwargs( + self, dynamicemb_options, *, compute_device="cuda", is_inference=False + ): + from torchrec.distributed.embedding_types import EmbeddingComputeKernel + from torchrec.distributed.types import ShardingType + + return dict( + sharder=None, + sharding_type=ShardingType.ROW_WISE.value, + tensor=torch.empty(1024, 64, dtype=torch.float32), + compute_device=compute_device, + compute_kernel=EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value, + shard_sizes=[[512, 64], [512, 64]], + batch_sizes=[32, 32], + world_size=2, + local_world_size=2, + input_lengths=[1.0], + num_poolings=[1.0], + caching_ratio=0.5, + is_pooled=True, + input_data_type_size=8.0, + output_data_type_size=4.0, + is_inference=is_inference, + dynamicemb_options=dynamicemb_options, + ) + + def test_admission_counter_increases_hbm(self): + baseline = dynamicemb_util.dynamicemb_calculate_shard_storages( + **self._base_kwargs(self._build_options(with_admission_counter=False)) + ) + with_counter = dynamicemb_util.dynamicemb_calculate_shard_storages( + **self._base_kwargs(self._build_options(with_admission_counter=True)) + ) + # Counter is HBM-side only — DDR matches, HBM grows. + for base, w in zip(baseline, with_counter): + self.assertGreater(w.hbm, base.hbm) + self.assertEqual(w.ddr, base.ddr) + + +if __name__ == "__main__": + unittest.main() diff --git a/tzrec/utils/plan_util.py b/tzrec/utils/plan_util.py index bf5c69e0..ddbcf5dd 100644 --- a/tzrec/utils/plan_util.py +++ b/tzrec/utils/plan_util.py @@ -224,70 +224,80 @@ def get_default_sharders() -> List[ModuleSharder[nn.Module]]: return sharders -class DynamicProgrammingProposer(Proposer): - r"""Proposes sharding plans in dynamic programming fashion. - - The problem of the Embedding Sharding Plan can be framed as follows: Given - :math:`M` tables and their corresponding :math:`N` Sharding Options, we need to - select one sharding option for each table such that the total performance is - minimized, while keeping the overall memory constraint :math:`K` in check. This can - be abstracted into the following mathematical formulation: - - Given a matrix :math:`A` of dimensions :math:`(M, N)` and another matrix :math:`B` - of the same dimensions, let the elements of matrix :math:`A` be denoted as - :math:`a_{i,j}` and the elements of matrix :math:`B` as :math:`b_{i,j}`. We aim - to find a set of column indices :math:`\{ j_0, j_1, \ldots, j_{M-1} \}` such that - the following conditions are satisfied: +_INF = float("inf") +# Cap per-axis bin count for the 2D DP. Without this, `bins_per_device * +# num_devices` blows up on multi-host worlds (e.g. 16 GPUs at 100 hbm +# bins/dev = 1600 bins, * 800 ddr bins = 1.28M cells per table). The +# budgets are scalar (total HBM, total DDR), so multiplying bins by +# num_devices stops adding precision past this cap. NumPy vectorization +# of the inner loop is a follow-up (PR #508 review R5). +_DP_AXIS_BIN_CAP = 1024 - 1. :math:`\sum_{i=0}^{M-1} a_{i,j_i} \leq K`, where :math:`K` is a float. - 2. :math:`\sum_{i=0}^{M-1} b_{i,j_i}` is minimized. - This problem can be tackled using dynamic programming. First, discretize :math:`K` - into :math:`K_i`, and denote the discretization function as :math:`f`. +class DynamicProgrammingProposer(Proposer): + r"""Proposes sharding plans via 2D (HBM × DDR) dynamic programming. - Define the state :math:`dp[i][f(k)]` to represent the minimum value of :math:`B` - when considering the first :math:`i` rows and the total sum of :math:`A` is equal to - the discretized value :math:`k`. + Given :math:`M` tables each with up to :math:`N` ShardingOptions, pick one + option per table to minimize total perf while respecting both per-rank + HBM and per-machine DDR budgets from the topology. (HBM is GPU-local, so + each device gets its own quota; DDR is host-shared across ranks + co-located on the same machine, so the prune threshold is the sum over + a ``local_world_size``-sized window.) - The state transition can then be represented as: + Each axis (HBM, DDR) is discretized into bins; ``dp[table][h][d]`` holds + the minimum perf over the first ``table`` tables using ``≈ h`` HBM bins + and ``≈ d`` DDR bins. The transition is: .. math:: - dp[i][f(k)] = \min_{j=0}^{N-1} \left( dp[i-1][f(k - A[i][j])] + B[i][j] \right) + dp[i][h][d] = \min_{j} dp[i-1][h - A_h[i][j]][d - A_d[i][j]] + + B[i][j] - Since :math:`K` is the sum allocated across all memory, simply satisfying that the - total memory in the plan equals :math:`K` does not guarantee that the allocation - will fit on all cards. Therefore, it is essential to maintain all the states of the - last layer of :math:`dp`. This allows us to propose different plans under varying - total memory constraints. + Backtracking from each feasible ``(h, d)`` at the last table yields + Pareto-optimal proposals across the joint memory budget. The host axis is + load-bearing for dynamicemb CACHING mode (where DDR = full table) vs + HYBRID (DDR = ``(1 - load_factor) · table``). Args: - mem_bins_per_device (int): memory bins for dynamic programming precision. + hbm_bins_per_device: HBM discretization bins per device. Default 100. + (Accepts legacy ``mem_bins_per_device`` as an alias for + backwards compatibility.) + ddr_bins_per_device: DDR discretization bins per device. Default 50 — + DDR budgets dominate embedding demand, so a coarser axis suffices. """ - def __init__(self, mem_bins_per_device: int = 100) -> None: + def __init__( + self, + hbm_bins_per_device: int = 100, + ddr_bins_per_device: int = 50, + mem_bins_per_device: Optional[int] = None, + ) -> None: self._inited: bool = False - self._mem_bins_per_device: int = max(mem_bins_per_device, 1) + # back-compat alias: ``mem_bins_per_device`` mapped to the HBM axis. + bins_h = ( + mem_bins_per_device + if mem_bins_per_device is not None + else hbm_bins_per_device + ) + self._hbm_bins_per_device: int = max(bins_h, 1) + self._ddr_bins_per_device: int = max(ddr_bins_per_device, 1) self._sharding_options_by_fqn: OrderedDict[str, List[ShardingOption]] = ( OrderedDict() ) - # list of proposals with different total_mem, a proposal is a list of - # indices of sharding_options + # list of proposals with different total_mem; each proposal is a list + # of indices into self._sharding_options_by_fqn[fqn]. self._proposal_list: List[List[int]] = [] self._current_proposal: int = -1 - self._storage_type = "hbm" - if not torch.cuda.is_available(): - self._storage_type = "ddr" def load( self, search_space: List[ShardingOption], enumerator: Optional[Enumerator] = None, ) -> None: - """Load search space.""" + """Load search space, sorted by total (hbm + ddr) ascending.""" self._reset() - # order the sharding_option by total_storage.hbm from low to high for sharding_option in sorted( - search_space, key=lambda x: getattr(x.total_storage, self._storage_type) + search_space, + key=lambda x: (x.total_storage.hbm or 0) + (x.total_storage.ddr or 0), ): fqn = sharding_option.fqn if fqn not in self._sharding_options_by_fqn: @@ -324,105 +334,175 @@ def feedback( perf_rating: Optional[float] = None, storage_constraint: Optional[Topology] = None, ) -> None: - """Feedback last proposed plan.""" - if not self._inited: - self._inited = True - table_count = len(self._sharding_options_by_fqn) - option_count = max([len(x) for x in self._sharding_options_by_fqn.values()]) - - assert storage_constraint is not None - # are we assuming the table will be evenly sharded on all devices? - max_device_mem = 0 - mem_total = 0 - for x in storage_constraint.devices: - cur_device_mem = getattr(x.storage, self._storage_type) - max_device_mem = max(max_device_mem, cur_device_mem) - mem_total += cur_device_mem - - bin_count = self._mem_bins_per_device * len(storage_constraint.devices) - bin_size = float(mem_total) / bin_count - - dp = [ - [(float("inf"), float("inf"))] * bin_count for _ in range(table_count) - ] # [table_id][mem_bin][perf, mem] - - backtrack = [ - [(-1, -1)] * bin_count for _ in range(table_count) - ] # [table_id][mem_bin][opt_id, prev_mem_bin] - - mem_by_fqn = [ - [float("inf") for _ in range(option_count)] for _ in range(table_count) - ] # memory constraint lookup table: [table_id][sharding_option_id] - perf_by_fqn = [ - [float("inf") for _ in range(option_count)] for _ in range(table_count) - ] # performance metrics lookup table: [table_id][sharding_option_id] - - # populate mem and perf for each sharding option and table: - # A[table_id][sharding_option_id] - for table_id, sharding_options in enumerate( - self._sharding_options_by_fqn.values() - ): - for opt_id, sharding_option in enumerate(sharding_options): - # prune mem of one shard > mem of one device - if ( - max( - [ - getattr(shard.storage, self._storage_type) - for shard in sharding_option.shards - ] - ) - > max_device_mem - ): - continue - mem_by_fqn[table_id][opt_id] = _bytes_to_float_bin( - getattr(sharding_option.total_storage, self._storage_type), - bin_size, - ) - perf_by_fqn[table_id][opt_id] = sharding_option.total_perf - - table_0 = 0 - for opt_j in range(option_count): - if mem_by_fqn[0][opt_j] < bin_count: - mem_i = int(mem_by_fqn[0][opt_j]) - # options are ordered in increasing order of mem, we only want to - # consider a sharding option that has higher mem and better perf - # (the smaller the better) - if dp[table_0][mem_i][0] > perf_by_fqn[table_0][opt_j]: - dp[table_0][mem_i] = ( - perf_by_fqn[table_0][opt_j], - mem_by_fqn[table_0][opt_j], - ) - backtrack[table_0][mem_i] = (opt_j, -1) - - # dp: table_count x option_count x bin_count - for table_i in range(1, table_count): - for opt_j in range(option_count): - for mem in range(bin_count): - prev_perf, perv_mem = dp[table_i - 1][mem] - if prev_perf < float("inf"): - new_mem = perv_mem + mem_by_fqn[table_i][opt_j] - if new_mem < bin_count: - new_mem_i = int(new_mem) - new_perf = prev_perf + perf_by_fqn[table_i][opt_j] - if dp[table_i][new_mem_i][0] > new_perf: - dp[table_i][new_mem_i] = (new_perf, new_mem) - backtrack[table_i][new_mem_i] = (opt_j, mem) - self._proposal_list = [] - # fill in all the proposals, starting from highest mem to lowest mem - for c in range(bin_count - 1, -1, -1): - cur_opt_idx, cur_mem_idx = backtrack[table_count - 1][c] - if cur_opt_idx >= 0: - proposal_indices = [-1] * table_count - proposal_indices[table_count - 1] = cur_opt_idx - for i in range(table_count - 2, -1, -1): - proposal_indices[i], cur_mem_idx = backtrack[i][cur_mem_idx] - self._proposal_list.append(proposal_indices) - if len(self._proposal_list) > 0: - self._current_proposal = 0 - else: + """Run 2D DP on first feedback; otherwise advance the proposal cursor.""" + if self._inited: self._current_proposal += 1 if self._current_proposal >= len(self._proposal_list): self._current_proposal = -1 + return + + self._inited = True + assert storage_constraint is not None + table_count = len(self._sharding_options_by_fqn) + if table_count == 0: + return + option_count = max(len(x) for x in self._sharding_options_by_fqn.values()) + + num_devices = len(storage_constraint.devices) + max_device_hbm = 0 + hbm_total = 0 + ddr_total = 0 + for device in storage_constraint.devices: + max_device_hbm = max(max_device_hbm, device.storage.hbm or 0) + hbm_total += device.storage.hbm or 0 + ddr_total += device.storage.ddr or 0 + # DDR is host-shared across ranks co-located on one machine, so the + # per-option fit check compares against the largest machine's DDR pool + # -- not per-device. HBM is GPU-local, so its prune stays per-device. + per_host = getattr(storage_constraint, "local_world_size", None) or num_devices + per_host = max(per_host, 1) + max_machine_ddr = 0 + for host_start in range(0, num_devices, per_host): + host_end = min(host_start + per_host, num_devices) + machine_ddr = sum( + (storage_constraint.devices[i].storage.ddr or 0) + for i in range(host_start, host_end) + ) + max_machine_ddr = max(max_machine_ddr, machine_ddr) + + hbm_bins = max( + min(self._hbm_bins_per_device * num_devices, _DP_AXIS_BIN_CAP), 1 + ) + ddr_bins = max( + min(self._ddr_bins_per_device * num_devices, _DP_AXIS_BIN_CAP), 1 + ) + # Collapse a degenerate axis to a single bin so we don't waste states + # on (e.g.) CPU-only topologies that have hbm == 0 everywhere. + if hbm_total == 0: + hbm_bins = 1 + if ddr_total == 0: + ddr_bins = 1 + hbm_bin_size = float(hbm_total) / hbm_bins if hbm_bins > 0 else 1.0 + ddr_bin_size = float(ddr_total) / ddr_bins if ddr_bins > 0 else 1.0 + + hbm_by_fqn = [[_INF] * option_count for _ in range(table_count)] + ddr_by_fqn = [[_INF] * option_count for _ in range(table_count)] + perf_by_fqn = [[_INF] * option_count for _ in range(table_count)] + + for table_id, sharding_options in enumerate( + self._sharding_options_by_fqn.values() + ): + for opt_id, sharding_option in enumerate(sharding_options): + max_shard_hbm = max( + (shard.storage.hbm or 0) for shard in sharding_option.shards + ) + max_shard_ddr = max( + (shard.storage.ddr or 0) for shard in sharding_option.shards + ) + # HBM is per-device, DDR is per-machine: see comment above. + if hbm_total > 0 and max_shard_hbm > max_device_hbm: + continue + if ddr_total > 0 and max_shard_ddr > max_machine_ddr: + continue + hbm_by_fqn[table_id][opt_id] = ( + _bytes_to_float_bin( + sharding_option.total_storage.hbm or 0, hbm_bin_size + ) + if hbm_total > 0 + else 0.0 + ) + ddr_by_fqn[table_id][opt_id] = ( + _bytes_to_float_bin( + sharding_option.total_storage.ddr or 0, ddr_bin_size + ) + if ddr_total > 0 + else 0.0 + ) + perf_by_fqn[table_id][opt_id] = sharding_option.total_perf + + # dp[table][hbm_bin][ddr_bin] = (perf, hbm_actual, ddr_actual) + # backtrack[table][hbm_bin][ddr_bin] = (opt_id, prev_hbm_bin, prev_ddr_bin) + empty_state = (_INF, _INF, _INF) + empty_back = (-1, -1, -1) + dp = [ + [[empty_state] * ddr_bins for _ in range(hbm_bins)] + for _ in range(table_count) + ] + backtrack = [ + [[empty_back] * ddr_bins for _ in range(hbm_bins)] + for _ in range(table_count) + ] + + # Seed the first table. + for opt_j in range(option_count): + h = hbm_by_fqn[0][opt_j] + d = ddr_by_fqn[0][opt_j] + if h >= hbm_bins or d >= ddr_bins: + continue + h_i = int(h) + d_i = int(d) + if dp[0][h_i][d_i][0] > perf_by_fqn[0][opt_j]: + dp[0][h_i][d_i] = (perf_by_fqn[0][opt_j], h, d) + backtrack[0][h_i][d_i] = (opt_j, -1, -1) + + for table_i in range(1, table_count): + prev_dp = dp[table_i - 1] + cur_dp = dp[table_i] + cur_back = backtrack[table_i] + hbm_t = hbm_by_fqn[table_i] + ddr_t = ddr_by_fqn[table_i] + perf_t = perf_by_fqn[table_i] + for h in range(hbm_bins): + prev_dp_h = prev_dp[h] + for d in range(ddr_bins): + prev_state = prev_dp_h[d] + prev_perf = prev_state[0] + if prev_perf == _INF: + continue + prev_h_val = prev_state[1] + prev_d_val = prev_state[2] + for opt_j in range(option_count): + delta_perf = perf_t[opt_j] + if delta_perf == _INF: + continue + new_h = prev_h_val + hbm_t[opt_j] + new_d = prev_d_val + ddr_t[opt_j] + if new_h >= hbm_bins or new_d >= ddr_bins: + continue + new_h_i = int(new_h) + new_d_i = int(new_d) + new_perf = prev_perf + delta_perf + if cur_dp[new_h_i][new_d_i][0] > new_perf: + cur_dp[new_h_i][new_d_i] = (new_perf, new_h, new_d) + cur_back[new_h_i][new_d_i] = (opt_j, h, d) + + # Enumerate proposals in decreasing total memory order. Total memory + # is a tie-break heuristic — Pareto-optimal (perf, hbm, ddr) frontier + # tends to live at the high end, so larger-mem cells are explored + # first and small-mem fallbacks come later. + self._proposal_list = [] + last_back = backtrack[table_count - 1] + coords = sorted( + ( + (h, d) + for h in range(hbm_bins) + for d in range(ddr_bins) + if last_back[h][d][0] >= 0 + ), + key=lambda hd: hd[0] + hd[1], + reverse=True, + ) + for h, d in coords: + cur_opt_idx, prev_h, prev_d = last_back[h][d] + if cur_opt_idx < 0: + continue + proposal_indices = [-1] * table_count + proposal_indices[table_count - 1] = cur_opt_idx + for i in range(table_count - 2, -1, -1): + proposal_indices[i], prev_h, prev_d = backtrack[i][prev_h][prev_d] + self._proposal_list.append(proposal_indices) + if self._proposal_list: + self._current_proposal = 0 def _extract_constraints_for_param( @@ -755,6 +835,38 @@ def calculate_shard_storages( ) +def _emit_dynamicemb_variants( + base_option: ShardingOption, + dynamicemb_options: Any, +) -> List[ShardingOption]: + """Expand a dynamicemb ShardingOption into HYBRID + CACHING variants. + + Sweeps both placement modes (``caching=False`` and ``caching=True``) and, + when ``base_option.cache_params`` is unset, ten cache_load_factor values + (0.1, 0.2, ..., 1.0). The downstream 2D DP proposer picks per table the + best (mode, ratio) that fits both HBM and host topology budgets. + + Each returned ShardingOption owns a freshly deep-copied + ``dynamicemb_options`` instance so per-option ``caching`` mutations do not + bleed across variants. + """ + if base_option.cache_params is None: + load_factors = [(i + 1) / 10 for i in range(10)] + stats = None + else: + load_factors = [base_option.cache_params.load_factor] + stats = base_option.cache_params.stats + variants: List[ShardingOption] = [] + for caching_mode in (False, True): + for load_factor in load_factors: + opt = copy.deepcopy(base_option) + opt.cache_params = CacheParams(load_factor=load_factor, stats=stats) + # deepcopy(base_option) already produced a fresh dynamicemb_options. + opt.dynamicemb_options.caching = caching_mode # pyre-ignore [16] + variants.append(opt) + return variants + + class EmbeddingEnumerator(_EmbeddingEnumerator): """Generates embedding sharding options for given `nn.Module` with constraints. @@ -934,20 +1046,11 @@ def enumerate( sharding_option.use_dynamicemb = use_dynamicemb # pyre-ignore [16] sharding_option.dynamicemb_options = dynamicemb_options - if sharding_option.cache_params is None: - # add cache_load_factor automatic search space - for load_factor_step in range(10): - sharding_option_copy = copy.deepcopy( - sharding_option - ) - sharding_option_copy.cache_params = CacheParams( - load_factor=(load_factor_step + 1) / 10 - ) - sharding_options_per_table.append( - sharding_option_copy - ) - else: - sharding_options_per_table.append(sharding_option) + sharding_options_per_table.extend( + _emit_dynamicemb_variants( + sharding_option, dynamicemb_options + ) + ) else: sharding_options_per_table.append(sharding_option) diff --git a/tzrec/utils/plan_util_test.py b/tzrec/utils/plan_util_test.py index 868b09df..2c578876 100644 --- a/tzrec/utils/plan_util_test.py +++ b/tzrec/utils/plan_util_test.py @@ -10,17 +10,21 @@ # limitations under the License. import unittest +from types import SimpleNamespace import torch +from torchrec.distributed.embedding_types import EmbeddingComputeKernel from torchrec.distributed.model_parallel import get_default_sharders from torchrec.distributed.planner.enumerators import EmbeddingEnumerator from torchrec.distributed.planner.partitioners import GreedyPerfPartitioner from torchrec.distributed.planner.proposers import GridSearchProposer from torchrec.distributed.planner.types import PlannerError, Topology from torchrec.distributed.test_utils.test_model import TestSparseNN +from torchrec.distributed.types import CacheParams, ShardingType from torchrec.modules.embedding_configs import EmbeddingBagConfig -from tzrec.utils.plan_util import DynamicProgrammingProposer +from tzrec.utils.dynamicemb_util import has_dynamicemb +from tzrec.utils.plan_util import DynamicProgrammingProposer, _emit_dynamicemb_variants class PlanUtilTest(unittest.TestCase): @@ -136,5 +140,342 @@ def test_dp_proposer_with_prune(self) -> None: ) +class EmitDynamicEmbVariantsTest(unittest.TestCase): + """``_emit_dynamicemb_variants`` produces { HYBRID, CACHING } × factors.""" + + def _make_base(self, cache_params=None): + # _emit_dynamicemb_variants only touches cache_params and + # dynamicemb_options on the option, so a SimpleNamespace stand-in is + # sufficient and avoids the heavy ShardingOption constructor surface. + return SimpleNamespace( + cache_params=cache_params, + dynamicemb_options=SimpleNamespace(caching=False, bucket_capacity=128), + ) + + def _make_dynamicemb_options(self): + return SimpleNamespace(caching=False, bucket_capacity=128) + + def test_unfixed_factor_emits_twenty_variants(self): + variants = _emit_dynamicemb_variants( + self._make_base(cache_params=None), self._make_dynamicemb_options() + ) + self.assertEqual(len(variants), 20) + cache_modes = sorted({v.dynamicemb_options.caching for v in variants}) + self.assertEqual(cache_modes, [False, True]) + for caching_mode in (False, True): + factors = sorted( + v.cache_params.load_factor + for v in variants + if v.dynamicemb_options.caching is caching_mode + ) + self.assertEqual(factors, [round((i + 1) / 10, 4) for i in range(10)]) + + def test_fixed_factor_emits_two_variants(self): + variants = _emit_dynamicemb_variants( + self._make_base(cache_params=CacheParams(load_factor=0.3)), + self._make_dynamicemb_options(), + ) + self.assertEqual(len(variants), 2) + cache_modes = sorted(v.dynamicemb_options.caching for v in variants) + self.assertEqual(cache_modes, [False, True]) + for v in variants: + self.assertEqual(v.cache_params.load_factor, 0.3) + + def test_variants_own_dynamicemb_options(self): + # Per-variant mutation of caching must not bleed across variants. + opts = self._make_dynamicemb_options() + variants = _emit_dynamicemb_variants(self._make_base(), opts) + for v in variants: + self.assertIsNot(v.dynamicemb_options, opts) + + def test_stats_preserved_on_clone(self): + class _Stats: + expected_lookups = 100.0 + + def expected_miss_rate(self, ratio): + return 0.1 + + stats = _Stats() + variants = _emit_dynamicemb_variants( + self._make_base(cache_params=CacheParams(load_factor=0.2, stats=stats)), + self._make_dynamicemb_options(), + ) + for v in variants: + self.assertIs(v.cache_params.stats, stats) + + +class _FakeStorage: + def __init__(self, hbm, ddr): + self.hbm = hbm + self.ddr = ddr + + +class _FakeShard: + def __init__(self, hbm, ddr): + self.storage = _FakeStorage(hbm, ddr) + + +class _FakeShardingOption: + """Minimal ShardingOption stand-in: only the fields the DP proposer reads.""" + + def __init__(self, fqn, hbm, ddr, perf): + self.fqn = fqn + # Total = single shard for simplicity (single-rank assignment). + self.shards = [_FakeShard(hbm, ddr)] + self.total_storage = _FakeStorage(hbm, ddr) + self.total_perf = perf + + +def _make_topology(num_devices, hbm_per_device, ddr_per_device, local_world_size=None): + return SimpleNamespace( + devices=[ + SimpleNamespace( + storage=_FakeStorage(hbm=hbm_per_device, ddr=ddr_per_device) + ) + for _ in range(num_devices) + ], + local_world_size=local_world_size or num_devices, + ) + + +class DynamicProgrammingProposerTest(unittest.TestCase): + """2D DP across HBM × DDR picks per-table mode under joint budgets.""" + + def _run(self, search_space, topology): + proposer = DynamicProgrammingProposer( + hbm_bins_per_device=20, ddr_bins_per_device=20 + ) + proposer.load(search_space) + # First propose returns the lowest-mem-per-table seed. + proposer.propose() + proposer.feedback(partitionable=True, storage_constraint=topology) + proposals = [] + proposal = proposer.propose() + while proposal: + proposals.append(proposal) + proposer.feedback(partitionable=True, storage_constraint=topology) + proposal = proposer.propose() + return proposals + + def test_caching_preferred_when_ddr_is_generous(self): + # Three options for one table: + # HYBRID @ x=1.0: hbm = T, ddr = 0, perf = high (HBM-only) + # HYBRID @ x=0.1: hbm = .1T, ddr = .9T, perf = high (slow) + # CACHING @ x=0.1: hbm = .1T, ddr = T, perf = low (fast — modeled hits) + opts = [ + _FakeShardingOption("table_a", hbm=1000, ddr=0, perf=50.0), + _FakeShardingOption("table_a", hbm=100, ddr=900, perf=80.0), + _FakeShardingOption("table_a", hbm=100, ddr=1000, perf=10.0), + ] + topology = _make_topology( + num_devices=2, hbm_per_device=2000, ddr_per_device=2000 + ) + proposals = self._run(opts, topology) + # Best plan must be the CACHING option (perf=10). + best = min(proposals, key=lambda p: sum(o.total_perf for o in p)) + self.assertEqual(best[0].total_perf, 10.0) + + def test_caching_rejected_when_ddr_is_tight(self): + # Host budget is only 950 — CACHING (ddr=1000) cannot fit; HYBRID can. + opts = [ + _FakeShardingOption("table_a", hbm=100, ddr=900, perf=80.0), + _FakeShardingOption("table_a", hbm=100, ddr=1000, perf=10.0), + ] + topology = _make_topology( + num_devices=1, hbm_per_device=2000, ddr_per_device=950 + ) + proposals = self._run(opts, topology) + # Every proposed plan must pick the HYBRID option (perf=80). + for p in proposals: + self.assertEqual(p[0].total_perf, 80.0) + + def test_high_factor_collapses_modes(self): + # At x=1.0 HYBRID == CACHING in HBM and CACHING.ddr = T = HYBRID.hbm. + # If we offer just the high-factor options, DP picks one of them. + opts = [ + _FakeShardingOption("table_a", hbm=1000, ddr=0, perf=50.0), # HYBRID x=1.0 + _FakeShardingOption( + "table_a", hbm=1000, ddr=1000, perf=50.0 + ), # CACHING x=1.0 + ] + topology = _make_topology( + num_devices=1, hbm_per_device=1100, ddr_per_device=2000 + ) + proposals = self._run(opts, topology) + # Either option is fine — they're tied. Just verify the proposer + # returned something feasible. + self.assertGreater(len(proposals), 0) + for p in proposals: + self.assertEqual(p[0].total_perf, 50.0) + + def test_two_tables_pick_mixed_modes_under_joint_budget(self): + # Two tables, each with HYBRID@1.0 (all-HBM, no DDR) and CACHING@0.1 + # (small HBM, full-T DDR). Topology HBM=2000 admits exactly one + # full HYBRID + one small CACHING shard (1500+100), and host DDR + # admits exactly one full-T CACHING backing (1500). Both-HYBRID is + # HBM-infeasible (3000>2000), both-CACHING is DDR-infeasible + # (3000>2000). Only the mixed plan fits. Exercises the + # cross-table DP transition at plan_util.py table_i==1. + opts = [ + _FakeShardingOption("table_a", hbm=1500, ddr=0, perf=50.0), # HYBRID@1.0 + _FakeShardingOption("table_a", hbm=100, ddr=1500, perf=40.0), # CACHING@0.1 + _FakeShardingOption("table_b", hbm=1500, ddr=0, perf=50.0), # HYBRID@1.0 + _FakeShardingOption("table_b", hbm=100, ddr=1500, perf=40.0), # CACHING@0.1 + ] + topology = _make_topology( + num_devices=1, hbm_per_device=2000, ddr_per_device=2000 + ) + proposals = self._run(opts, topology) + self.assertGreater(len(proposals), 0) + best = min(proposals, key=lambda p: sum(o.total_perf for o in p)) + styles = sorted( + "hybrid" if o.shards[0].storage.ddr == 0 else "caching" for o in best + ) + self.assertEqual(styles, ["caching", "hybrid"]) + + def test_legacy_mem_bins_per_device_alias(self): + # The legacy kwarg `mem_bins_per_device` is preserved as an alias + # for `hbm_bins_per_device`. + proposer = DynamicProgrammingProposer(mem_bins_per_device=77) + self.assertEqual(proposer._hbm_bins_per_device, 77) + # Explicit hbm_bins_per_device takes precedence: when both are + # provided, mem_bins_per_device wins (it's the legacy override). + proposer = DynamicProgrammingProposer( + mem_bins_per_device=77, hbm_bins_per_device=99 + ) + self.assertEqual(proposer._hbm_bins_per_device, 77) + + def test_empty_search_space_returns_empty_proposal(self): + proposer = DynamicProgrammingProposer() + proposer.load([]) + # Seed proposal is the per-table first option; with no tables the + # list is empty. + self.assertEqual(proposer.propose(), []) + # Feedback must not raise on an empty proposer (it short-circuits + # via `table_count == 0`). + topology = _make_topology( + num_devices=1, hbm_per_device=1000, ddr_per_device=1000 + ) + proposer.feedback(partitionable=True, storage_constraint=topology) + # After feedback, no proposals are available. + self.assertIsNone(proposer.propose()) + + +@unittest.skipUnless(has_dynamicemb, "dynamicemb is not installed; skipping.") +@unittest.skipUnless(torch.cuda.is_available(), "CUDA is required for dynamicemb.") +class PlanUtilDynamicEmbE2ETest(unittest.TestCase): + """End-to-end exercise of the dynamicemb planner integration.""" + + def _build_constraint(self, max_capacity=4096): + import dynamicemb + from dynamicemb.planner import DynamicEmbParameterConstraints + + opts = dynamicemb.DynamicEmbTableOptions( + max_capacity=max_capacity, + initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.UNIFORM, + lower=-0.01, + upper=0.01, + ), + eval_initializer_args=dynamicemb.DynamicEmbInitializerArgs( + mode=dynamicemb.DynamicEmbInitializerMode.CONSTANT, value=0.0 + ), + score_strategy=dynamicemb.DynamicEmbScoreStrategy.STEP, + ) + return DynamicEmbParameterConstraints( + use_dynamicemb=True, + sharding_types=[ShardingType.ROW_WISE.value], + compute_kernels=[EmbeddingComputeKernel.CUSTOMIZED_KERNEL.value], + dynamicemb_options=opts, + ) + + def _build_model(self): + table = EmbeddingBagConfig( + num_embeddings=4096, + embedding_dim=32, + name="table_de", + feature_names=["feat_de"], + ) + return TestSparseNN(tables=[table], sparse_device=torch.device("meta")) + + def test_enumerate_yields_both_modes_and_all_factors(self): + from tzrec.utils.plan_util import ( + EmbeddingEnumerator as _TzrecEmbeddingEnumerator, + ) + from tzrec.utils.plan_util import ( + get_default_sharders as _tzrec_get_default_sharders, + ) + + model = self._build_model() + topology = Topology(world_size=2, compute_device="cuda") + enumerator = _TzrecEmbeddingEnumerator( + topology=topology, + batch_size=128, + fqn_constraints={"sparse.ebc.table_de": self._build_constraint()}, + ) + search_space = enumerator.enumerate( + module=model, sharders=_tzrec_get_default_sharders() + ) + self.assertEqual(len(search_space), 20) + caching_modes = sorted( + { + so.dynamicemb_options.caching + for so in search_space + if getattr(so, "use_dynamicemb", False) + } + ) + self.assertEqual(caching_modes, [False, True]) + load_factors = sorted( + { + round(so.cache_load_factor, 4) + for so in search_space + if getattr(so, "use_dynamicemb", False) + } + ) + self.assertEqual(load_factors, [round((i + 1) / 10, 4) for i in range(10)]) + # Each option must carry a non-zero perf and storage estimate. + for so in search_space: + self.assertGreater(so.total_perf, 0) + self.assertGreaterEqual(so.total_storage.hbm, 0) + self.assertGreaterEqual(so.total_storage.ddr, 0) + + def test_dp_proposer_picks_feasible_dynamicemb_plan(self): + from tzrec.utils.plan_util import ( + EmbeddingEnumerator as _TzrecEmbeddingEnumerator, + ) + from tzrec.utils.plan_util import ( + get_default_sharders as _tzrec_get_default_sharders, + ) + + model = self._build_model() + topology = Topology(world_size=2, compute_device="cuda") + enumerator = _TzrecEmbeddingEnumerator( + topology=topology, + batch_size=128, + fqn_constraints={"sparse.ebc.table_de": self._build_constraint()}, + ) + search_space = enumerator.enumerate( + module=model, sharders=_tzrec_get_default_sharders() + ) + + proposer = DynamicProgrammingProposer() + proposer.load(search_space) + proposal = proposer.propose() + self.assertIsNotNone(proposal) + proposer.feedback(partitionable=True, storage_constraint=topology) + + # At least one further proposal should be generated by the 2D DP. + count = 0 + proposal = proposer.propose() + while proposal is not None and count < 5: + count += 1 + for so in proposal: + if getattr(so, "use_dynamicemb", False): + self.assertIn(so.dynamicemb_options.caching, (False, True)) + proposer.feedback(partitionable=True, storage_constraint=topology) + proposal = proposer.propose() + self.assertGreater(count, 0) + + if __name__ == "__main__": unittest.main()