diff --git a/test/nodes/test_get_worker_info.py b/test/nodes/test_get_worker_info.py
new file mode 100644
index 000000000..8fa91a8c7
--- /dev/null
+++ b/test/nodes/test_get_worker_info.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from typing import Any, Dict
+
+import torch
+from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
+from torchdata.nodes import get_worker_info, ParallelMapper
+
+from .utils import MockSource
+
+
+def _capture_worker_info(item: Dict[str, Any]) -> Dict[str, Any]:
+    """UDF that augments the item with the current worker's WorkerInfo fields."""
+    info = get_worker_info()
+    item = dict(item)
+    if info is None:
+        item["worker_id"] = None
+        item["num_workers"] = None
+    else:
+        item["worker_id"] = info.id
+        item["num_workers"] = info.num_workers
+    return item
+
+
+def _capture_torch_worker_info(item: Dict[str, Any]) -> Dict[str, Any]:
+    """UDF that reads worker info via torch.utils.data.get_worker_info() (process workers only)."""
+    info = torch.utils.data.get_worker_info()
+    item = dict(item)
+    item["worker_id"] = info.id if info is not None else None
+    item["num_workers"] = info.num_workers if info is not None else None
+    return item
+
+
+class TestGetWorkerInfo(TestCase):
+    def test_none_outside_worker(self) -> None:
+        self.assertIsNone(get_worker_info())
+
+    def test_thread_workers(self) -> None:
+        num_workers = 4
+        src = MockSource(num_samples=40)
+        node = ParallelMapper(src, _capture_worker_info, num_workers=num_workers, method="thread", in_order=False)
+
+        results = list(node)
+        self.assertEqual(len(results), 40)
+
+        # All reported worker ids must be in [0, num_workers)
+        for r in results:
+            self.assertIn(r["worker_id"], set(range(num_workers)))
+            self.assertEqual(r["num_workers"], num_workers)
+
+    @unittest.skipIf(IS_WINDOWS, "forkserver not supported on Windows")
+    def test_process_workers_get_worker_info(self) -> None:
+        """torch.utils.data.get_worker_info() works correctly in process workers."""
+        num_workers = 2
+        src = MockSource(num_samples=8)
+        node = ParallelMapper(
+            src,
+            _capture_torch_worker_info,
+            num_workers=num_workers,
+            method="process",
+            multiprocessing_context="forkserver",
+            in_order=False,
+        )
+
+        results = list(node)
+        self.assertEqual(len(results), 8)
+        for r in results:
+            self.assertIn(r["worker_id"], set(range(num_workers)))
+            self.assertEqual(r["num_workers"], num_workers)
+
+    @unittest.skipIf(IS_WINDOWS, "forkserver not supported on Windows")
+    def test_process_workers_torchdata_get_worker_info(self) -> None:
+        """torchdata.nodes.get_worker_info() works correctly in process workers."""
+        num_workers = 2
+        src = MockSource(num_samples=8)
+        node = ParallelMapper(
+            src,
+            _capture_worker_info,
+            num_workers=num_workers,
+            method="process",
+            multiprocessing_context="forkserver",
+            in_order=False,
+        )
+
+        results = list(node)
+        self.assertEqual(len(results), 8)
+        for r in results:
+            self.assertIn(r["worker_id"], set(range(num_workers)))
+            self.assertEqual(r["num_workers"], num_workers)
+
+    def test_num_workers_zero_no_worker_info(self) -> None:
+        """With num_workers=0 (inline), get_worker_info() returns None inside UDF."""
+        src = MockSource(num_samples=5)
+        node = ParallelMapper(src, _capture_worker_info, num_workers=0)
+        results = list(node)
+        for r in results:
+            self.assertIsNone(r["worker_id"])
+            self.assertIsNone(r["num_workers"])
diff --git a/torchdata/nodes/__init__.py b/torchdata/nodes/__init__.py
index daee61d9d..4645b9c98 100644
--- a/torchdata/nodes/__init__.py
+++ b/torchdata/nodes/__init__.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from ._worker_info import get_worker_info
 from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper
 from .base_node import BaseNode, T
 from .batch import Batcher, Unbatcher
@@ -19,7 +20,6 @@
 from .shuffler import Shuffler
 from .types import Stateful
 
-
 __all__ = [
     "BaseNode",
     "Batcher",
@@ -40,6 +40,7 @@
     "StopCriteria",
     "T",
     "Unbatcher",
+    "get_worker_info",
 ]
 
 assert sorted(__all__) == __all__
diff --git a/torchdata/nodes/_apply_udf.py b/torchdata/nodes/_apply_udf.py
index ae272b4db..c5f0e990e 100644
--- a/torchdata/nodes/_apply_udf.py
+++ b/torchdata/nodes/_apply_udf.py
@@ -11,9 +11,9 @@
 
 import torch
 import torch.multiprocessing as mp
-
 from torch._utils import ExceptionWrapper
 
+from ._worker_info import _set_worker_info
 from .constants import QUEUE_TIMEOUT
 
 
@@ -23,6 +23,7 @@ def _apply_udf(
     out_q: Union[queue.Queue, mp.Queue],
     udf: Callable,
     stop_event: Union[threading.Event, python_mp_synchronize.Event],
+    num_workers: int = 1,
 ):
     """_apply_udf assumes in_q emits tuples of (x, idx) where x is the
     payload, idx is the index of the result, potentially used for maintaining
@@ -31,6 +32,8 @@
     StopIteration from in_q).
     """
     torch.set_num_threads(1)
+    _set_worker_info(worker_id, num_workers)
+
     while True:
         if stop_event.is_set() and in_q.empty():
             break
diff --git a/torchdata/nodes/_worker_info.py b/torchdata/nodes/_worker_info.py
new file mode 100644
index 000000000..3a26a9f89
--- /dev/null
+++ b/torchdata/nodes/_worker_info.py
@@ -0,0 +1,55 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import threading
+from typing import Any, Optional
+
+import torch
+import torch.utils.data._utils.worker as _worker_module  # type: ignore[import]
+
+_thread_local = threading.local()
+
+
+def get_worker_info() -> Optional[Any]:
+    """Return a :class:`~torch.utils.data.WorkerInfo` for the current
+    :class:`~torchdata.nodes.ParallelMapper` worker, or ``None`` if called
+    from outside a worker context.
+
+    Unlike :func:`torch.utils.data.get_worker_info`, this function uses
+    thread-local storage and is therefore correct for both thread-based and
+    process-based :class:`~torchdata.nodes.ParallelMapper` workers.
+
+    The returned object has the following attributes:
+
+    * ``id`` (int): the worker index (0 to num_workers - 1)
+    * ``num_workers`` (int): total number of workers
+    * ``seed`` (int): per-worker seed derived from the initial RNG seed
+    * ``dataset``: always ``None`` — :class:`~torchdata.nodes.ParallelMapper`
+      operates on items, not datasets
+
+    Returns:
+        A ``WorkerInfo`` object, or ``None`` when called from outside a worker.
+    """
+    return getattr(_thread_local, "worker_info", None)
+
+
+def _set_worker_info(worker_id: int, num_workers: int) -> None:
+    """Set up WorkerInfo for the current thread/process so that
+    :func:`get_worker_info` and :func:`torch.utils.data.get_worker_info`
+    return valid info from inside a UDF.
+
+    ``dataset`` is ``None`` because :class:`~torchdata.nodes.ParallelMapper`
+    maps over arbitrary items rather than a PyTorch Dataset object.
+    """
+    seed = torch.initial_seed() + worker_id
+    worker_info = _worker_module.WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=None)  # type: ignore[attr-defined,arg-type]
+    # Thread-local storage: always returns the correct info for this worker
+    # even when multiple thread-workers run concurrently in the same process.
+    _thread_local.worker_info = worker_info
+    # Module-level global: works correctly for process workers (each process
+    # has isolated memory). For thread workers this may race between workers,
+    # so thread-based callers should use torchdata.nodes.get_worker_info().
+    _worker_module._worker_info = worker_info  # type: ignore[attr-defined]
diff --git a/torchdata/nodes/map.py b/torchdata/nodes/map.py
index a73346f2e..79fd78bc4 100644
--- a/torchdata/nodes/map.py
+++ b/torchdata/nodes/map.py
@@ -202,6 +202,7 @@ def __init__(
                 self._intermed_q,
                 self.map_fn,
                 self._stop,
+                self.num_workers,
             )
 
         elif self.method == "process":
@@ -213,6 +214,7 @@
                 self._intermed_q,
                 self.map_fn,
                 self._mp_stop,
+                self.num_workers,
             )
             self._workers.append(mp_context.Process(target=_apply_udf, args=_args, daemon=True))
         for t in self._workers: