103 changes: 103 additions & 0 deletions test/nodes/test_get_worker_info.py
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Any, Dict

import torch
from torch.testing._internal.common_utils import IS_WINDOWS, TestCase
from torchdata.nodes import get_worker_info, IterableWrapper, ParallelMapper

from .utils import MockSource


def _capture_worker_info(item: Dict[str, Any]) -> Dict[str, Any]:
"""UDF that augments the item with the current worker's WorkerInfo fields."""
info = get_worker_info()
item = dict(item)
if info is None:
item["worker_id"] = None
item["num_workers"] = None
else:
item["worker_id"] = info.id
item["num_workers"] = info.num_workers
return item


def _capture_torch_worker_info(item: Dict[str, Any]) -> Dict[str, Any]:
"""UDF that reads worker info via torch.utils.data.get_worker_info() (process workers only)."""
info = torch.utils.data.get_worker_info()
item = dict(item)
item["worker_id"] = info.id if info is not None else None
item["num_workers"] = info.num_workers if info is not None else None
return item


class TestGetWorkerInfo(TestCase):
def test_none_outside_worker(self) -> None:
self.assertIsNone(get_worker_info())

def test_thread_workers(self) -> None:
num_workers = 4
src = MockSource(num_samples=40)
node = ParallelMapper(src, _capture_worker_info, num_workers=num_workers, method="thread", in_order=False)

results = list(node)
self.assertEqual(len(results), 40)

# All reported worker ids must be in [0, num_workers)
for r in results:
self.assertIn(r["worker_id"], set(range(num_workers)))
self.assertEqual(r["num_workers"], num_workers)

@unittest.skipIf(IS_WINDOWS, "forkserver not supported on Windows")
def test_process_workers_get_worker_info(self) -> None:
"""torch.utils.data.get_worker_info() works correctly in process workers."""
num_workers = 2
src = MockSource(num_samples=8)
node = ParallelMapper(
src,
_capture_torch_worker_info,
num_workers=num_workers,
method="process",
multiprocessing_context="forkserver",
in_order=False,
)

results = list(node)
self.assertEqual(len(results), 8)
for r in results:
self.assertIn(r["worker_id"], set(range(num_workers)))
self.assertEqual(r["num_workers"], num_workers)

@unittest.skipIf(IS_WINDOWS, "forkserver not supported on Windows")
def test_process_workers_torchdata_get_worker_info(self) -> None:
"""torchdata.nodes.get_worker_info() works correctly in process workers."""
num_workers = 2
src = MockSource(num_samples=8)
node = ParallelMapper(
src,
_capture_worker_info,
num_workers=num_workers,
method="process",
multiprocessing_context="forkserver",
in_order=False,
)

results = list(node)
self.assertEqual(len(results), 8)
for r in results:
self.assertIn(r["worker_id"], set(range(num_workers)))
self.assertEqual(r["num_workers"], num_workers)

def test_num_workers_zero_no_worker_info(self) -> None:
"""With num_workers=0 (inline), get_worker_info() returns None inside UDF."""
src = MockSource(num_samples=5)
node = ParallelMapper(src, _capture_worker_info, num_workers=0)
results = list(node)
for r in results:
self.assertIsNone(r["worker_id"])
self.assertIsNone(r["num_workers"])
3 changes: 2 additions & 1 deletion torchdata/nodes/__init__.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from ._worker_info import get_worker_info
from .adapters import IterableWrapper, MapStyleWrapper, SamplerWrapper
from .base_node import BaseNode, T
from .batch import Batcher, Unbatcher
@@ -19,7 +20,6 @@
from .shuffler import Shuffler
from .types import Stateful


__all__ = [
"BaseNode",
"Batcher",
@@ -40,6 +40,7 @@
"StopCriteria",
"T",
"Unbatcher",
"get_worker_info",
]

assert sorted(__all__) == __all__
5 changes: 4 additions & 1 deletion torchdata/nodes/_apply_udf.py
@@ -11,9 +11,9 @@

import torch
import torch.multiprocessing as mp

from torch._utils import ExceptionWrapper

from ._worker_info import _set_worker_info
from .constants import QUEUE_TIMEOUT


@@ -23,6 +23,7 @@ def _apply_udf(
out_q: Union[queue.Queue, mp.Queue],
udf: Callable,
stop_event: Union[threading.Event, python_mp_synchronize.Event],
num_workers: int = 1,
):
"""_apply_udf assumes in_q emits tuples of (x, idx) where x is the
payload, idx is the index of the result, potentially used for maintaining
@@ -31,6 +32,8 @@
StopIteration from in_q).
"""
torch.set_num_threads(1)
_set_worker_info(worker_id, num_workers)

while True:
if stop_event.is_set() and in_q.empty():
break
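For context, a minimal sketch of the (x, idx) queue protocol that the _apply_udf docstring describes. This is illustrative only — the function name and shutdown handling are assumptions, not the actual torchdata worker loop:

import queue

def _worker_loop_sketch(in_q: queue.Queue, out_q: queue.Queue, udf) -> None:
    # in_q emits (x, idx) tuples; apply the UDF to the payload x and forward
    # (result, idx) so a downstream consumer can restore ordering if needed.
    while True:
        x, idx = in_q.get()
        if isinstance(x, StopIteration):
            # Pass the sentinel through so the consumer knows this lane is done.
            out_q.put((x, idx))
            break
        out_q.put((udf(x), idx))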
55 changes: 55 additions & 0 deletions torchdata/nodes/_worker_info.py
@@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import threading
from typing import Any, Optional

import torch
import torch.utils.data._utils.worker as _worker_module # type: ignore[import]

_thread_local = threading.local()


def get_worker_info() -> Optional[Any]:
"""Return a :class:`~torch.utils.data.WorkerInfo` for the current
:class:`~torchdata.nodes.ParallelMapper` worker, or ``None`` if called
from outside a worker context.

Unlike :func:`torch.utils.data.get_worker_info`, this function uses
thread-local storage and is therefore correct for both thread-based and
process-based :class:`~torchdata.nodes.ParallelMapper` workers.

The returned object has the following attributes:

* ``id`` (int): the worker index (0 to num_workers - 1)
* ``num_workers`` (int): total number of workers
* ``seed`` (int): per-worker seed derived from the initial RNG seed
* ``dataset``: always ``None`` — :class:`~torchdata.nodes.ParallelMapper`
operates on items, not datasets

Returns:
A ``WorkerInfo`` object, or ``None`` when called from outside a worker.
"""
return getattr(_thread_local, "worker_info", None)


def _set_worker_info(worker_id: int, num_workers: int) -> None:
"""Set up WorkerInfo for the current thread/process so that
:func:`get_worker_info` and :func:`torch.utils.data.get_worker_info`
return valid info from inside a UDF.

``dataset`` is ``None`` because :class:`~torchdata.nodes.ParallelMapper`
maps over arbitrary items rather than a PyTorch Dataset object.
"""
seed = torch.initial_seed() + worker_id
worker_info = _worker_module.WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=None) # type: ignore[attr-defined,arg-type]
# Thread-local storage: always returns the correct info for this worker
# even when multiple thread-workers run concurrently in the same process.
_thread_local.worker_info = worker_info
# Module-level global: works correctly for process workers (each process
# has isolated memory). For thread workers this may race between workers,
# so thread-based callers should use torchdata.nodes.get_worker_info().
_worker_module._worker_info = worker_info # type: ignore[attr-defined]
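
Taken together, a minimal usage sketch of the new API from inside a ParallelMapper UDF. The jitter function and its fallback seed are hypothetical, shown only to exercise the ``seed`` attribute documented above; the import surface matches this PR:

import random

from torchdata.nodes import get_worker_info, IterableWrapper, ParallelMapper

def jitter(item: int) -> float:
    # Hypothetical UDF: perturb each item using the per-worker seed. With
    # num_workers=0 the UDF runs inline, get_worker_info() returns None,
    # and we fall back to a fixed seed.
    info = get_worker_info()
    seed = info.seed if info is not None else 0
    return item + random.Random(seed + item).random()

node = ParallelMapper(IterableWrapper(range(8)), jitter, num_workers=2, method="thread")
print(sorted(node))

Because the info is stored thread-locally, this works the same way for method="thread" and method="process" workers.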
2 changes: 2 additions & 0 deletions torchdata/nodes/map.py
@@ -202,6 +202,7 @@ def __init__(
self._intermed_q,
self.map_fn,
self._stop,
self.num_workers,
)

elif self.method == "process":
@@ -213,6 +214,7 @@
self._intermed_q,
self.map_fn,
self._mp_stop,
self.num_workers,
)
self._workers.append(mp_context.Process(target=_apply_udf, args=_args, daemon=True))
for t in self._workers: