From 20abd3a3d0b0db36bc9cdf4edffd709a72c58259 Mon Sep 17 00:00:00 2001 From: Yan Xiong Date: Tue, 2 Sep 2025 14:38:30 -0700 Subject: [PATCH 1/2] updated version of device-with-speclist' bench (#4648) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/4648 X-link: https://github.com/facebookresearch/FBGEMM/pull/1682 Add TBE/VBE bench (-- device-with-speclist) that takes dim lists for args under tbe_training. #### Summary This diff introduces a new feature to the TBE (Table Batched Embedding) bench, which allows for benchmarking with a list of devices and speclist. The changes include: * **tbe_data_config_loader.py:** Added a new class `TBEDataListConfig` to load a list of `TBEDataConfig` objects. Modified the `TBEDataConfigLoader` class to load a list of `TBEDataConfig` objects. * **bench_runs.py:** Added a new function `bench_warmup_with_spec` to perform warm-up runs with a specified list of devices and batch sizes. * **tbe_data_config_param_models.py:** Added a new class `BatchListParams` to represent a list of batch sizes and their corresponding standard deviations. * **__init__.py:** Imported the new `TBEDataListConfig` class and the `benchmark_requests_with_spec` function. * **tbe_training_benchmark.py:** Added a new function `benchmark_requests_with_spec` to benchmark requests with a specified list of devices and batch sizes. These changes enable the TBE bench to support benchmarking with a list of devices and speclist, making it more flexible and efficient for testing and optimizing TBE performance. Reviewed By: spcyppt Differential Revision: D78390216 --- .../bench/tbe/tbe_training_benchmark.py | 344 ++++++++++++++++++ fbgemm_gpu/fbgemm_gpu/tbe/bench/__init__.py | 1 + fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py | 149 +++++++- .../fbgemm_gpu/tbe/bench/tbe_data_config.py | 27 +- .../tbe/bench/tbe_data_config_bench_helper.py | 94 ++++- .../tbe/bench/tbe_data_config_loader.py | 50 ++- .../tbe/bench/tbe_data_config_param_models.py | 9 +- 7 files changed, 664 insertions(+), 10 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py b/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py index b481f769cd..19e8cdbfa5 100644 --- a/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py +++ b/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py @@ -32,10 +32,12 @@ from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( ComputeDevice, DenseTableBatchedEmbeddingBagsCodegen, + get_available_compute_device, SplitTableBatchedEmbeddingBagsCodegen, ) from fbgemm_gpu.tbe.bench import ( benchmark_requests, + benchmark_requests_with_spec, EmbeddingOpsCommonConfigLoader, TBEBenchmarkingConfigLoader, TBEDataConfigLoader, @@ -44,6 +46,7 @@ generate_embedding_dims, generate_feature_requires_grad, generate_requests, + generate_requests_with_Llist, ) from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags from fbgemm_gpu.tbe.utils import get_device @@ -53,6 +56,13 @@ logger.setLevel(logging.DEBUG) logging.basicConfig(level=logging.DEBUG) +try: + import mtia.host_runtime.torch_mtia.dynamic_library # pyright: ignore # noqa: F401 # pyre-ignore[21] + + torch.mtia.init() +except Exception: + pass + @click.group() def cli() -> None: @@ -358,5 +368,339 @@ def _context_factory( ) +@cli.command() +@click.option( + "--emb-op-type", + default="split", + type=click.Choice(["split", "dense", "ssd"], case_sensitive=False), + help="The type of the embedding op to benchmark", +) +@click.option( + "--row-wise/--no-row-wise", + default=True, + help="Whether to use row-wise adagrad optimzier or not", +) +@click.option( + "--weighted-num-requires-grad", + type=int, + default=None, + help="The number of weighted tables that require gradient", +) +@click.option( + "--ssd-prefix", + type=str, + default="/tmp/ssd_benchmark", + help="SSD directory prefix", +) +@click.option( + "--pooling-list", + type=str, + default=None, + help="override pooling list", +) +@click.option("--cache-load-factor", default=0.2) +@TBEBenchmarkingConfigLoader.options +@TBEDataConfigLoader.options +@EmbeddingOpsCommonConfigLoader.options +@click.pass_context +def device_with_speclist( # noqa C901 + context: click.Context, + emb_op_type: click.Choice, + row_wise: bool, + weighted_num_requires_grad: Optional[int], + cache_load_factor: float, + # SSD params + ssd_prefix: str, + pooling_list: Optional[str], + # pyre-ignore[2] + **kwargs, +) -> None: + """ + A TBE benchmark supporting TBE param list and EEG params as input arguments. This allows for more flexible and customizable benchmarking. + Args: + uses optional arguments from TBEDataConfigLoader to take in TBE param list and EEG params as input arguments: + --tbe-num-embeddings-list: the list of embedding table sizes + --tbe-embedding-dim-list: the list of embedding dimensions + --tbe-batch-sizes-list: the list of batch sizes + --pooling-list: the list of pooling factors + Example: + buck2 run @mode/opt -c fbcode.nvcc_arch=h100 -c fbcode.platform=platform010 //deeplearning/fbgemm/fbgemm_gpu/bench:tbe_training -- device-with-speclist --bench-warmup-iterations 2 \ + --bench-iterations 10 --emb-pooling-mode sum --row-wise --tbe-num-tables 5 --tbe-num-embeddings-list 169694,66932,3717056,335,101083 --tbe-embedding-dim-list 128,128,128,128,128 \ + --tbe-batch-sizes-list 245760,245760,245760,245760,245760 \ + --pooling-list 4.454203287760417,8.075313313802083,1.5521280924479166,9.099202473958334,37.089603678385416\ + --tbe-indices-zipf 2.75 0.8900000000000006 --tbe-indices-hitters 0.0032447561639423338,0.002346034168270899,0.002270828999570933,0.0021225501825015603,0.0021215337630846732,0.0019088356649139518,0.001890480906511913,0.0018895829048911682,0.001865188838885878,0.001863886243128314,0.0018611428975176868,0.0018586561237987009,0.001858429156356095,0.0018583502111586669,0.0018583206067096312,0.001705492572638463,0.0017048511429093595,0.00170478206586161,0.0017045847028680395,0.0017042393176292915 \ + --tbe-indices-dtype 64 --tbe-offsets-dtype 64 --tbe-pooling-size 21 --tbe-pooling-vl-sigma 35 --tbe-pooling-vl-dist normal --emb-cache-dtype fp16 --emb-weights-dtype fp16 --bench-export-trace --emb-stochastic-rounding \ + """ + + # Initialize random seeds + np.random.seed(42) + torch.manual_seed(42) + + # Load general TBE benchmarking configuration from cli arguments + benchconfig = TBEBenchmarkingConfigLoader.load(context) + + # Load TBE data configuration from cli arguments + tbeconfig = TBEDataConfigLoader.load(context) + + # Load common embedding op configuration from cli arguments + embconfig = EmbeddingOpsCommonConfigLoader.load(context) + assert tbeconfig.Es is not None, "E list is not provided" + assert tbeconfig.Ds is not None, "D list is not provided" + # Generate feature_requires_grad + feature_requires_grad = ( + generate_feature_requires_grad(tbeconfig, weighted_num_requires_grad) + if weighted_num_requires_grad + else None + ) + + # Determine the optimizer + optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD + + # Construct the common split arguments for the embedding op + common_split_args: Dict[str, Any] = embconfig.split_args() | { + "optimizer": optimizer, + "learning_rate": 0.1, + "eps": 0.1, + "feature_table_map": list(range(tbeconfig.T)), + } + assert tbeconfig.batch_params.Bs is not None, "B list is not provided" + + batch_size_per_feature_per_rank = None + if tbeconfig.batch_params.sigma_B is not None: + batch_size_per_feature_per_rank = [] + for b in tbeconfig.batch_params.Bs: + batch_size_per_feature_per_rank.append([b]) + + managed_option = ( + EmbeddingLocation.DEVICE + if get_available_compute_device() == ComputeDevice.CUDA + else EmbeddingLocation.HOST + ) + + if emb_op_type == "dense": + embedding_op = DenseTableBatchedEmbeddingBagsCodegen( + [ + ( + e, + d, + ) + for e, d in zip(tbeconfig.Es, tbeconfig.Ds) + ], + pooling_mode=embconfig.pooling_mode, + use_cpu=not torch.cuda.is_available(), + ) + elif emb_op_type == "ssd": + assert ( + torch.cuda.is_available() + ), "SSDTableBatchedEmbeddingBags only supports GPU execution" + cache_set = max(sum(tbeconfig.batch_params.Bs), 1) + tempdir = tempfile.mkdtemp(prefix=ssd_prefix) + embedding_op = SSDTableBatchedEmbeddingBags( + embedding_specs=[(e, d) for e, d in zip(tbeconfig.Es, tbeconfig.Ds)], + cache_sets=cache_set, + ssd_storage_directory=tempdir, + ssd_cache_location=EmbeddingLocation.DEVICE, + ssd_rocksdb_shards=8, + **common_split_args, + ) + else: + embedding_op = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + e, + d, + managed_option, + get_available_compute_device(), + ) + for e, d in zip(tbeconfig.Es, tbeconfig.Ds) + ], + cache_precision=( + embconfig.weights_dtype + if embconfig.cache_dtype is None + else embconfig.cache_dtype + ), + cache_algorithm=CacheAlgorithm.LRU, + cache_load_factor=cache_load_factor, + device=get_device(), + **common_split_args, + ).to(get_device()) + embedding_op = embedding_op.to(get_device()) + + if embconfig.weights_dtype == SparseType.INT8: + # pyre-fixme[29]: `Union[(self: DenseTableBatchedEmbeddingBagsCodegen, + # min_val: float, max_val: float) -> None, (self: + # SplitTableBatchedEmbeddingBagsCodegen, min_val: float, max_val: float) -> + # None, Tensor, Module]` is not a function. + embedding_op.init_embedding_weights_uniform(-0.0003, 0.0003) + + avg_B = int(np.average(tbeconfig.batch_params.Bs)) + + nparams = sum(d * e for e, d in zip(tbeconfig.Es, tbeconfig.Ds)) + param_size_multiplier = embconfig.weights_dtype.bit_rate() / 8.0 + output_size_multiplier = embconfig.output_dtype.bit_rate() / 8.0 + if embconfig.pooling_mode.do_pooling(): + read_write_bytes = ( + output_size_multiplier * avg_B * sum(tbeconfig.Ds) + + param_size_multiplier + * avg_B + * sum(tbeconfig.Ds) + * tbeconfig.pooling_params.L + ) + else: + read_write_bytes = ( + output_size_multiplier + * avg_B + * sum(tbeconfig.Ds) + * tbeconfig.pooling_params.L + + param_size_multiplier + * avg_B + * sum(tbeconfig.Ds) + * tbeconfig.pooling_params.L + ) + + logging.info(f"Managed option: {embconfig.embedding_location}") + logging.info( + f"Embedding parameters: {nparams / 1.0e9: .2f} GParam, " + f"{nparams * param_size_multiplier / 1.0e9: .2f} GB" + ) + logging.info( + f"Accessed weights per batch: {avg_B * sum(tbeconfig.Ds) * tbeconfig.pooling_params.L * param_size_multiplier / 1.0e9: .2f} GB" + ) + + if pooling_list is not None: + pooling_list_extracted = [float(x) for x in pooling_list.split(",")] + tensor_pooling_list = torch.tensor(pooling_list_extracted) + requests = generate_requests_with_Llist( + tbeconfig, + tensor_pooling_list, + benchconfig.num_requests, + batch_size_per_feature_per_rank, + ) + else: + requests = generate_requests( + tbeconfig, benchconfig.num_requests, batch_size_per_feature_per_rank + ) + + # pyre-ignore[53] + def _kineto_trace_handler(p: profile, phase: str) -> None: + p.export_chrome_trace( + benchconfig.trace_url.format( + emb_op_type=emb_op_type, phase=phase, ospid=os.getpid() + ) + ) + + # pyre-ignore[3,53] + def _context_factory(on_trace_ready: Callable[[profile], None]): + return ( + profile(on_trace_ready=on_trace_ready, with_stack=True, record_shapes=True) + if benchconfig.export_trace + else nullcontext() + ) + + # to add batch_size_per_feature_per_rank, Yan's edit + + if torch.cuda.is_available(): + with _context_factory(lambda p: _kineto_trace_handler(p, "fwd")): + # forward + time_per_iter = benchmark_requests_with_spec( + requests, + lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op.forward( + indices.to(dtype=tbeconfig.indices_params.index_dtype), + offsets.to(dtype=tbeconfig.indices_params.offset_dtype), + per_sample_weights, + feature_requires_grad=feature_requires_grad, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ), + flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb, + num_warmups=benchconfig.warmup_iterations, + iters=benchconfig.iterations, + ) + else: + time_per_iter = benchmark_requests_with_spec( + requests, + lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op.forward( + indices.to(dtype=tbeconfig.indices_params.index_dtype), + offsets.to(dtype=tbeconfig.indices_params.offset_dtype), + per_sample_weights, + feature_requires_grad=feature_requires_grad, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ), + flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb, + num_warmups=benchconfig.warmup_iterations, + iters=benchconfig.iterations, + ) + + avg_E = int(np.average(tbeconfig.E)) + avg_D = int(np.average(tbeconfig.D)) + logging.info( + f"Forward, B: {avg_B}, " + f"E: {avg_E}, T: {tbeconfig.T}, D: {avg_D}, L: {tbeconfig.pooling_params.L}, W: {tbeconfig.weighted}, " + f"BW: {read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"T: {time_per_iter * 1.0e6:.0f}us" + ) + + if embconfig.output_dtype == SparseType.INT8: + # backward bench not representative + return + + if embconfig.pooling_mode.do_pooling(): + if batch_size_per_feature_per_rank is None: + grad_output = torch.randn(avg_B, sum(tbeconfig.Ds)).to(get_device()) + else: + output_size = sum( + [b * d for (b, d) in zip(tbeconfig.batch_params.Bs, tbeconfig.Ds)] + ) + grad_output = torch.randn(output_size).to(get_device()) + + else: + grad_output = torch.randn( + avg_B * tbeconfig.T * tbeconfig.pooling_params.L, + avg_D, + ).to(get_device()) + assert ( + batch_size_per_feature_per_rank is None or grad_output.dim() == 1 + ), f"VBE expects 1D grad_output but got {grad_output.shape}" + if torch.cuda.is_available(): + with _context_factory(lambda p: _kineto_trace_handler(p, "fwd_bwd")): + # backward + time_per_iter = benchmark_requests_with_spec( + requests, + lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op( + indices.to(dtype=tbeconfig.indices_params.index_dtype), + offsets.to(dtype=tbeconfig.indices_params.offset_dtype), + per_sample_weights, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb, + bwd_only=True, + grad=grad_output, + num_warmups=benchconfig.warmup_iterations, + iters=benchconfig.iterations, + ) + else: + time_per_iter = benchmark_requests_with_spec( + requests, + lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: embedding_op( + indices.to(dtype=tbeconfig.indices_params.index_dtype), + offsets.to(dtype=tbeconfig.indices_params.offset_dtype), + per_sample_weights, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + feature_requires_grad=feature_requires_grad, + ), + flush_gpu_cache_size_mb=benchconfig.flush_gpu_cache_size_mb, + bwd_only=True, + grad=grad_output, + num_warmups=benchconfig.warmup_iterations, + iters=benchconfig.iterations, + ) + + logging.info( + f"Backward, B: {avg_B}, E: {avg_E}, T: {tbeconfig.T}, D: {avg_D}, L: {tbeconfig.pooling_params.L}, " + f"BW: {2 * read_write_bytes / time_per_iter / 1.0e9: .2f} GB/s, " + f"T: {time_per_iter * 1.0e6:.0f}us" + ) + + if __name__ == "__main__": cli() diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/__init__.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/__init__.py index 52fc2b2b12..bed20c2daf 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/__init__.py @@ -21,6 +21,7 @@ benchmark_pipelined_requests, benchmark_requests, benchmark_requests_refer, + benchmark_requests_with_spec, benchmark_vbe, ) from .benchmark_click_interface import TbeBenchClickInterface # noqa F401 diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py index 00bf30d230..a0f0c518d8 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/bench_runs.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # @@ -16,7 +17,7 @@ import torch from fbgemm_gpu.tbe.utils import b_indices, TBERequest - +from fbgemm_gpu.tbe.utils.common import get_device logging.basicConfig(level=logging.DEBUG) @@ -43,6 +44,31 @@ def bench_warmup( out.backward(grad) +def bench_warmup_with_spec( + request: TBERequest, + warmup_ms: int, + warmup_runs: int, + func: Callable[ + [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[List[List[int]]]], + torch.Tensor, + ], + bwd_only: bool = False, + grad: Optional[torch.Tensor] = None, +) -> None: + indices, offsets, weights, batch_size_per_feature_per_rank = request.unpack_4() + if warmup_ms: + start_time_ms = time.time() * 1000 + while time.time() * 1000 - start_time_ms < warmup_ms: + out = func(indices, offsets, weights, batch_size_per_feature_per_rank) + if bwd_only: + out.backward(grad) + else: + for _ in range(warmup_runs): + out = func(indices, offsets, weights, batch_size_per_feature_per_rank) + if bwd_only: + out.backward(grad) + + class BMBarrier: def __init__(self) -> None: @@ -266,7 +292,7 @@ def benchmark_requests( # noqa: C901 _ = torch.rand( flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float, - device="cuda", + device=get_device(), ) start_events[it].record() @@ -308,6 +334,121 @@ def benchmark_requests( # noqa: C901 return median_time if check_median else avg_time +def benchmark_requests_with_spec( # noqa: C901 + requests: List[TBERequest], + func: Callable[ + [torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[List[List[int]]]], + torch.Tensor, + ], + flush_gpu_cache_size_mb: int = 0, + check_median: bool = False, + num_warmups: int = 0, + bwd_only: bool = False, + grad: Optional[torch.Tensor] = None, + # Used to label benchmark iterations differently in nsys profile result + # so that we can compare performance of two different models for example. + # If empty string is provided, it won't have any effect. + nvtx_range: str = "", + # Can be used to clear model's stats after warmup for example. + callback_after_warmup: Optional[Callable[[], None]] = None, + periodic_logs: bool = False, + warmup_ms: Optional[int] = None, + iters: int = -1, +) -> float: + times = [] + # Run at least one warmup iteration to avoid the long cudaLaunchKernel time + # for the first kernel if warmup_ms > 0 + # warmup_ms is prioritized over num_warmups + + if warmup_ms is None: + num_warmups = num_warmups + 1 if num_warmups >= 0 else 1 + + # warm-up the GPU before profiling + bench_warmup_with_spec( + requests[0], + # pyre-ignore[6] + warmup_ms, + num_warmups, + lambda indices, offsets, per_sample_weights, batch_size_per_feature_per_rank: func( + indices, offsets, per_sample_weights, batch_size_per_feature_per_rank + ), + bwd_only=bwd_only, + grad=grad, + ) + + if callback_after_warmup is not None: + callback_after_warmup() + + num_reqs = len(requests) + iters = num_reqs if iters == -1 else iters + + if torch.cuda.is_available(): + torch.cuda.synchronize() + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + else: + start_events = [] + end_events = [] + + for it in range(iters): + req = requests[it % num_reqs] + + indices, offsets, weights, batch_size_per_feature_per_rank = req.unpack_4() + # logging.info( + # f"[Benchmark Request] batch_size_per_feature_per_rank {batch_size_per_feature_per_rank} {indices.device}" + # ) + + if bwd_only: + # Run forward before profiling if does backward only + out = func(indices, offsets, weights, batch_size_per_feature_per_rank) + start_time = time.time() + if torch.cuda.is_available(): + if flush_gpu_cache_size_mb: + _ = torch.rand( + flush_gpu_cache_size_mb * 1024 * 1024 // 4, + dtype=torch.float, + device=get_device(), + ) + start_events[it].record() + + if nvtx_range: + torch.cuda.nvtx.range_push(f"{nvtx_range}-{it}") + + if bwd_only: + out.backward(grad) + else: + func(indices, offsets, weights, batch_size_per_feature_per_rank) + + if nvtx_range: + torch.cuda.nvtx.range_pop() + + if torch.cuda.is_available(): + end_events[it].record() + else: + it_time = time.time() - start_time + times.append(it_time) + + if torch.cuda.is_available(): + torch.cuda.synchronize() + times = [ + start.elapsed_time(end) * 1.0e-3 + for start, end in zip(start_events, end_events) + ] + + if periodic_logs: + for it in range(100, iters + 1, 100): + times_ = times[0:it] + avg_time = sum(times_) / len(times_) * 1.0e6 + last_100_avg = sum(times_[-100:]) / 100 * 1.0e6 + logging.info( + f"Iteration [{it}/{len(requests)}]: Last 100: {last_100_avg:.2f} us, Running avg: {avg_time:.2f} us" + ) + + avg_time = sum(times) / iters + median_time = statistics.median(times) + return median_time if check_median else avg_time + + def benchmark_requests_refer( requests: List[TBERequest], T: int, @@ -348,7 +489,7 @@ def benchmark_requests_refer( _ = torch.rand( flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float, - device="cuda", + device=get_device(), ) torch.cuda.synchronize() start_event.record() @@ -422,7 +563,7 @@ def benchmark_pipelined_requests( _ = torch.rand( flush_gpu_cache_size_mb * 1024 * 1024 // 4, dtype=torch.float, - device="cuda", + device=get_device(), ) torch.cuda.synchronize() start_event[0].record() diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py index 156b903256..8ad4c7c3f1 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py @@ -9,7 +9,7 @@ import dataclasses import json -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import torch @@ -45,6 +45,19 @@ class TBEDataConfig: pooling_params: PoolingParams # Force generated tensors to be on CPU use_cpu: bool = False + # Number of embeddings in each embedding features (number of rows) + Es: Optional[List[int]] = None + # Target embedding dimension for each features (number of columns) + Ds: Optional[List[int]] = None + # Maximum number of indices + max_indices: Optional[int] = None # Maximum number of indices + + def __post_init__(self) -> None: + if isinstance(self.D, list): + object.__setattr__(self, "mixed_dim", len(set(self.D)) > 1) + if isinstance(self.E, list) and self.max_indices is None: + object.__setattr__(self, "max_indices", sum(self.E) - 1) + self.validate() @staticmethod def complex_fields() -> Dict[str, Any]: @@ -81,7 +94,19 @@ def validate(self): # NOTE: Add validation logic here assert self.T > 0, "T must be positive" assert self.E > 0, "E must be positive" + if self.Es is not None: + assert all(e > 0 for e in self.Es), "All elements in Es must be positive" assert self.D > 0, "D must be positive" + if self.Ds is not None: + assert all(d > 0 for d in self.Ds), "All elements in Ds must be positive" + if isinstance(self.E, list) and isinstance(self.D, list): + assert ( + len(self.E) == len(self.D) == self.T + ), "Lengths of Es, Lengths of Ds, and T must be equal" + if self.max_indices is not None: + assert self.max_indices == ( + sum(self.Es) - 1 + ), "max_indices must be equal to sum(Es) - 1" self.batch_params.validate() self.indices_params.validate() self.pooling_params.validate() diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py index 70f520d568..d554e4a268 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py @@ -9,6 +9,7 @@ from typing import List, Optional, Tuple +import numpy as np import torch from fbgemm_gpu.tbe.bench.tbe_data_config import TBEDataConfig @@ -174,9 +175,16 @@ def _build_requests_dense( def generate_requests( tbe_data_config: TBEDataConfig, iters: int = 1, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> List[TBERequest]: + # Generate batch sizes - Bs, Bs_feature_rank = _generate_batch_sizes(tbe_data_config) + if batch_size_per_feature_per_rank: + Bs = tbe_data_config.batch_params.Bs + else: + Bs, _ = _generate_batch_sizes(tbe_data_config) + + assert Bs is not None, "Batch sizes (Bs) must be set" # Generate pooling info L_offsets = _generate_pooling_info(tbe_data_config, iters, Bs) @@ -184,10 +192,92 @@ def generate_requests( # Generate indices all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets) + # Build TBE requests + if tbe_data_config.variable_B() or tbe_data_config.variable_L(): + if batch_size_per_feature_per_rank: + return _build_requests_jagged( + tbe_data_config, + iters, + Bs, + batch_size_per_feature_per_rank, + L_offsets, + all_indices, + ) + else: + return _build_requests_jagged( + tbe_data_config, + iters, + Bs, + batch_size_per_feature_per_rank, + L_offsets, + all_indices, + ) + else: + return _build_requests_dense(tbe_data_config, iters, all_indices) + + +def generate_requests_with_Llist( + tbe_data_config: TBEDataConfig, + L_list: torch.Tensor, + iters: int = 1, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, +) -> List[TBERequest]: + """ + Generate a list of TBERequest objects based on the provided TBE data configuration and L_list + This function generates batch sizes and pooling information from the input L_list, + simulates L distributions with Gaussian noise, and creates indices for embedding lookups. + It supports both variable batch sizes and sequence lengths, building either jagged or dense requests accordingly. + Args: + tbe_data_config (TBEDataConfig): Configuration object containing batch parameters and pooling parameters. + L_list (torch.Tensor): Tensor of base sequence lengths for each batch. + iters (int, optional): Number of iterations to repeat the generated requests. Defaults to 1. + batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Optional batch size specification per feature per rank. Defaults to None. + Returns: + List[TBERequest]: A list of TBERequest objects constructed according to the configuration and input parameters. + Raises: + AssertionError: If batch sizes (Bs) are not set in the tbe_data_config. + Example: + >>> requests = generate_requests_with_Llist(tbe_data_config, L_list=torch.tensor([10, 20]), iters=2) + >>> len(requests) + 2 + """ + + # Generate batch sizes + Bs = tbe_data_config.batch_params.Bs + assert ( + Bs is not None + ), "Batch sizes (Bs) must be set for generate_requests_with_Llist" + + # Generate pooling info from L list + Ls_list = [] + for i in range(len(Bs)): + L = L_list[i] + B = Bs[i] + Ls_iter = np.random.normal( + loc=L, scale=tbe_data_config.pooling_params.sigma_L, size=B + ).astype(int) + Ls_list.append(Ls_iter) + Ls = np.concatenate(Ls_list) + Ls[Ls < 0] = 0 + # Use the same L distribution across iters + Ls = np.tile(Ls, iters) + L = Ls.max() + # Make it exclusive cumsum + L_offsets = torch.from_numpy(np.insert(Ls.cumsum(), 0, 0)).to(torch.long) + + # Generate indices + all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets) + all_indices = all_indices.to(get_device()) + # Build TBE requests if tbe_data_config.variable_B() or tbe_data_config.variable_L(): return _build_requests_jagged( - tbe_data_config, iters, Bs, Bs_feature_rank, L_offsets, all_indices + tbe_data_config, + iters, + Bs, + batch_size_per_feature_per_rank, + L_offsets, + all_indices, ) else: return _build_requests_dense(tbe_data_config, iters, all_indices) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py index 222c14f6dd..56cdf55f18 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py @@ -77,12 +77,26 @@ def options(cls, func) -> click.Command: default=int(1e5), help=TBEDataConfigHelperText.TBE_NUM_EMBEDDINGS.value, ), + click.option( + "--tbe-num-embeddings-list", + type=str, + required=False, + default=None, + help="Comma-separated list of number of embeddings (Es)", + ), click.option( "--tbe-embedding-dim", type=int, default=128, help=TBEDataConfigHelperText.TBE_EMBEDDING_DIM.value, ), + click.option( + "--tbe-embedding-dim-list", + type=str, + required=False, + default=None, + help="Comma-separated list of number of Embedding dimensions (D)", + ), click.option( "--tbe-mixed-dim", is_flag=True, @@ -95,6 +109,13 @@ def options(cls, func) -> click.Command: default=False, help=TBEDataConfigHelperText.TBE_WEIGHTED.value, ), + click.option( + "--tbe-max-indices", + type=int, + required=False, + default=None, + help="(Optional) Maximum number of indices, will be calculated if not provided", + ), # Batch Parameters click.option( "--tbe-batch-size", @@ -102,6 +123,13 @@ def options(cls, func) -> click.Command: default=512, help=TBEDataConfigHelperText.TBE_BATCH_SIZE.value, ), + click.option( + "--tbe-batch-sizes-list", + type=str, + required=False, + default=None, + help="List Batch sizes per feature (Bs)", + ), click.option( "--tbe-batch-vbe-sigma", type=int, @@ -186,16 +214,33 @@ def load_from_context(cls, context: click.Context) -> TBEDataConfig: # Read table parameters T = params["tbe_num_tables"] E = params["tbe_num_embeddings"] + if params["tbe_num_embeddings_list"] is not None: + Es = [int(x) for x in params["tbe_num_embeddings_list"].split(",")] + else: + Es = None D = params["tbe_embedding_dim"] + if params["tbe_embedding_dim_list"] is not None: + Ds = [int(x) for x in params["tbe_embedding_dim_list"].split(",")] + else: + Ds = None + mixed_dim = params["tbe_mixed_dim"] weighted = params["tbe_weighted"] + if params["tbe_max_indices"] is not None: + max_indices = params["tbe_max_indices"] + else: + max_indices = None # Read batch parameters B = params["tbe_batch_size"] sigma_B = params["tbe_batch_vbe_sigma"] vbe_distribution = params["tbe_batch_vbe_dist"] vbe_num_ranks = params["tbe_batch_vbe_ranks"] - batch_params = BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks) + if params["tbe_batch_sizes_list"] is not None: + Bs = [int(x) for x in params["tbe_batch_sizes_list"].split(",")] + else: + Bs = None + batch_params = BatchParams(B, sigma_B, vbe_distribution, vbe_num_ranks, Bs) # Read indices parameters heavy_hitters = ( @@ -230,6 +275,9 @@ def load_from_context(cls, context: click.Context) -> TBEDataConfig: indices_params, pooling_params, not torch.cuda.is_available(), + Es, + Ds, + max_indices, ).validate() @classmethod diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py index 4df3d5596c..1519788ef2 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_param_models.py @@ -9,7 +9,7 @@ import dataclasses import json -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import torch @@ -98,6 +98,8 @@ class BatchParams: vbe_distribution: Optional[str] = "normal" # Number of ranks for variable batch size generation vbe_num_ranks: Optional[int] = None + # List of target batch sizes, i.e. number of batch lookups per table + Bs: Optional[List[int]] = None @classmethod # pyre-ignore [3] @@ -117,7 +119,10 @@ def json(self, format: bool = False) -> str: # pyre-ignore [3] def validate(self): - assert self.B > 0, "B must be positive" + if self.Bs is not None: + assert all(b > 0 for b in self.Bs), "All elements in Bs must be positive" + else: + assert self.B > 0, "B must be positive" assert not self.sigma_B or self.sigma_B > 0, "sigma_B must be positive" assert ( self.vbe_num_ranks is None or self.vbe_num_ranks > 0 From 91600441d30431371fc40700f41bf4e9f62def4a Mon Sep 17 00:00:00 2001 From: Shreyashri Biswas Date: Mon, 22 Sep 2025 16:09:05 +0000 Subject: [PATCH 2/2] workgroup tuning and loop unrolled --- .../forward/embedding_forward_split_kernel_template.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) mode change 100644 => 100755 fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu diff --git a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu old mode 100644 new mode 100755 index 0122cfcee9..5b13aefef8 --- a/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu +++ b/fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu @@ -469,10 +469,10 @@ using namespace fbgemm_gpu; {%- endif %} {%- if is_rocm %} - for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + kThreadGroupSize > L && l_start + j < L; ++j) { + for(auto j = L % kThreadGroupSize - L % kManualUnrollLength; l_start + (kThreadGroupSize/32) > L && l_start + j < L; ++j) { {%- else %} // Iterate over kThreadGroupSize indices - for (auto j = 0; j < kThreadGroupSize && l_start + j < L; ++j) { + for (auto j = 0; j < (kThreadGroupSize/32) && l_start + j < L; ++j) { {%- endif %} {%- if dense or lxu_miss_rate != "cache_conflict_miss_rate::zero" %} // Load index from thread j in the group @@ -641,7 +641,7 @@ batch_index_select_dim0_codegen_forward_kernel( {%- endif %} {%- if is_rocm %} // Unroll factor for ROCm devices - constexpr int kManualUnrollLength = 4; + constexpr int kManualUnrollLength = 8; {%- endif %} // Determine the linearized warp ID, and exit early if needed