From 2b7cbac94cf84c4103b824edda1e61dcc12eae21 Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Wed, 29 Apr 2026 17:18:29 -0700 Subject: [PATCH] Support `output_spec` for custom batch functions. PiperOrigin-RevId: 907857993 --- .../python/dataset/transformations/batch.py | 24 +++++++++++++- .../dataset/transformations/batch_test.py | 33 ++++++++++++++++++- 2 files changed, 55 insertions(+), 2 deletions(-) diff --git a/grain/_src/python/dataset/transformations/batch.py b/grain/_src/python/dataset/transformations/batch.py index da29eb8af..4a3665375 100644 --- a/grain/_src/python/dataset/transformations/batch.py +++ b/grain/_src/python/dataset/transformations/batch.py @@ -21,7 +21,7 @@ import math import pprint import sys -from typing import Any, Callable, TypeVar, cast +from typing import Any, Callable, Generic, Protocol, TypeVar, cast, runtime_checkable from grain._src.core import tree_lib from grain._src.python.dataset import base @@ -306,6 +306,24 @@ def __str__(self) -> str: ) +@runtime_checkable +class BatchFn(Protocol, Generic[S, T]): + """Custom batch function that support element spec inference. + + If you need a custom batch function with `ds.batch(batch_fn=...)`, you can + implement this protocol to allow `batch` to infer the element spec of the + batched dataset. If not implemented, the output element spec will be unknown. + """ + + def __call__(self, elements: Sequence[S]) -> T: + """Batches elements.""" + + def output_spec( + self, input_spec: Any, batch_size: int, drop_remainder: bool + ) -> Any: + """Returns the element spec for batches produced by this function.""" + + def _get_batch_element_spec( input_spec: Any, batch_size: int, @@ -317,6 +335,10 @@ def _get_batch_element_spec( wrapped_batch_fn = batch_fn if isinstance(batch_fn, functools.partial): wrapped_batch_fn = batch_fn.func + + if isinstance(wrapped_batch_fn, BatchFn): + return wrapped_batch_fn.output_spec(input_spec, batch_size, drop_remainder) + if wrapped_batch_fn is not make_batch and not isinstance( wrapped_batch_fn, _MakeBatchParallel ): diff --git a/grain/_src/python/dataset/transformations/batch_test.py b/grain/_src/python/dataset/transformations/batch_test.py index 08d7fcffe..b5622ea31 100644 --- a/grain/_src/python/dataset/transformations/batch_test.py +++ b/grain/_src/python/dataset/transformations/batch_test.py @@ -17,7 +17,7 @@ import functools import importlib import sys -from typing import Any +from typing import Any, Sequence from unittest import mock from absl.testing import absltest @@ -57,6 +57,25 @@ def output_spec(self, input_spec: Any) -> Any: } +class CustomBatchFn(batch.BatchFn): + + def __call__(self, elements: Sequence[Any]) -> dict[str, Any]: + return {"batch": batch.make_batch(elements)} + + def output_spec( + self, input_spec: Any, batch_size: int, drop_remainder: bool + ) -> Any: + batch_dim = batch_size if drop_remainder else None + return { + "batch": tree_lib.map_structure( + lambda s: base.ShapeDtypeStruct( + shape=(batch_dim,) + s.shape, dtype=s.dtype + ), + input_spec, + ) + } + + class MakeBatchTest(absltest.TestCase): def test_batch_zero_values_error(self): @@ -554,6 +573,18 @@ def test_element_spec( self.assertEqual(spec.shape, expected_shape) self.assertEqual(spec.dtype, np.int64) + def test_element_spec_custom_batch_fn(self): + ds = dataset.MapDataset.range(0, 10) + batch_size = 3 + ds = batch.BatchMapDataset( + ds, batch_size, drop_remainder=True, batch_fn=CustomBatchFn() + ) + spec = dataset.get_element_spec(ds) + self.assertEqual( + spec, + {"batch": base.ShapeDtypeStruct(shape=(batch_size,), dtype=np.int64)}, + ) + class BatchIterDatasetTest(parameterized.TestCase):