Nodes/get worker info parallel mapper #1546
Conversation
…orch#1388) Passes num_workers to _apply_udf and sets a WorkerInfo on each worker before the processing loop, so that torch.utils.data.get_worker_info() works in process workers and torchdata.nodes.get_worker_info() (thread-local, hence always correct) works in both thread and process workers. Exports get_worker_info from torchdata.nodes. https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
- Add type: ignore[import] for private torch.utils.data._utils.worker module
- Import WorkerInfo from public torch.utils.data API for proper typing
- Fix get_worker_info() return type to Optional[WorkerInfo]
- Add type: ignore[arg-type] for dataset=None (not a Dataset[Any])
- Add type: ignore[attr-defined] for private _worker_info attribute assignment
https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
- Remove extra blank lines flagged by black formatter
- Move get_worker_info to end of __all__ (lowercase sorts after uppercase)
https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
Remove top-level 'from torch.utils.data import WorkerInfo' — WorkerInfo is not part of the public torch.utils.data API in all supported versions. Construct WorkerInfo through the already-ignored private module and use Optional[Any] as the return type of get_worker_info(). https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
… coverage type: ignore only suppresses errors on the line where it appears — splitting the call across multiple lines left dataset=None on a line the ignore did not cover. https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
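The single-line scope of `# type: ignore` can be sketched as follows; the `make_worker_info` stub is a hypothetical stand-in for the WorkerInfo constructor (not from the PR), so the example runs without torch:

```python
from typing import Any, Optional


def make_worker_info(id: int, dataset: Optional[Any]) -> dict:
    # Hypothetical stand-in for the WorkerInfo constructor; only the
    # comment-scoping behavior matters here.
    return {"id": id, "dataset": dataset}


# A "# type: ignore[...]" comment silences mypy only on its own physical line.
# Split across lines, the ignore on the opening line leaves dataset=None exposed:
info = make_worker_info(  # type: ignore[arg-type]
    id=0,
    dataset=None,  # a mypy error on this line would NOT be suppressed
)

# Keeping the call on one physical line puts every argument under one ignore:
info = make_worker_info(id=0, dataset=None)  # type: ignore[arg-type]
```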
test_thread_workers_each_id_used and test_thread_workers_unique_seeds asserted that all worker IDs appeared across N items, but thread work distribution is non-deterministic — fast workers can drain the queue before slower ones start, so some IDs may never appear. Replace with a subset check that only validates correctness (id in [0, num_workers)). https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
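A minimal sketch of the difference between the old exact-coverage assertion and the new subset check (the observed ids below are illustrative, not taken from the test suite):

```python
num_workers = 4
# Example run where a fast worker drained the queue before worker 3 started:
observed_ids = [0, 0, 1, 0, 2, 1]

# Old, flaky assertion: requires every worker id to have appeared at least once,
# which thread scheduling does not guarantee.
every_id_used = set(observed_ids) == set(range(num_workers))

# New, robust assertion: only validates that each reported id is in range.
all_ids_valid = all(0 <= wid < num_workers for wid in observed_ids)
```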
_thread_local = threading.local()


def get_worker_info() -> Optional[Any]:
This shouldn't live inside _apply_udf. maybe a utils file in nodes/ ?
Good point. Moved get_worker_info(), the thread-local storage, and the _set_worker_info() helper into a dedicated torchdata/nodes/_worker_info.py. _apply_udf.py now just calls _set_worker_info(worker_id, num_workers) and has no worker-info logic of its own.
seed = torch.initial_seed() + worker_id
worker_info = _worker_module.WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=None)  # type: ignore[attr-defined,arg-type]
# Thread-local: always returns the correct info for this worker, regardless of
# whether other workers (threads) have set their own worker info concurrently.
_thread_local.worker_info = worker_info
# Module-level global: correct for process workers (isolated memory); for thread
# workers this may race, so callers should use torchdata.nodes.get_worker_info().
_worker_module._worker_info = worker_info  # type: ignore[attr-defined]
I am not sure of this change. Overall it is safe, but I don't think its very clean to have this inside _apply_udf. Plus we are leaving dataset=None.
Agreed on the placement; addressed in the same commit by moving everything to `_worker_info.py`.
On dataset=None: ParallelMapper maps over arbitrary items rather than a PyTorch Dataset object, so there is no meaningful dataset to pass. We still construct a torch.utils.data.WorkerInfo (rather than a custom dataclass) so that torch.utils.data.get_worker_info() continues to work inside process workers, keeping compatibility with existing IterableDataset code that calls it. The # type: ignore[arg-type] on that line suppresses the mypy complaint. Happy to discuss if there's a cleaner alternative you have in mind.
Per reviewer feedback, get_worker_info() and the WorkerInfo setup logic don't belong in the private _apply_udf.py implementation file.
- New torchdata/nodes/_worker_info.py owns the thread-local storage, get_worker_info(), and the _set_worker_info() helper
- _apply_udf.py calls _set_worker_info() and is otherwise unchanged
- __init__.py imports get_worker_info from _worker_info
- dataset=None is documented: ParallelMapper maps over items, not PyTorch Dataset objects, so there is no dataset to pass
https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
psaikaushik
left a comment
Addressed the comments. Please take a look @divyanshk
Summary
Adds num_workers parameter to _apply_udf and sets a WorkerInfo object on each worker before the processing loop begins, so user-defined functions can identify which worker they are running in.
Introduces torchdata.nodes.get_worker_info() backed by threading.local() — always returns the correct WorkerInfo for the calling worker regardless of thread vs. process mode.
For process workers, torch.utils.data.get_worker_info() also works correctly (each process has isolated memory).
For thread workers, torchdata.nodes.get_worker_info() is the safe API; torch.utils.data.get_worker_info() may race between threads so the thread-local version is preferred.
Exports get_worker_info from torchdata.nodes.
Fixes #1388
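The thread-local vs. module-global distinction above can be demonstrated without torch; this hedged sketch shows why a shared global may race between worker threads while a `threading.local()` slot cannot (names here are illustrative, not from the PR):

```python
import threading

_global_info = None          # stands in for a module-level _worker_info global
_local = threading.local()   # stands in for the thread-local store

results = {}


def worker(worker_id: int, barrier: threading.Barrier) -> None:
    global _global_info
    _global_info = worker_id   # shared slot: last writer wins
    _local.info = worker_id    # private slot: one value per thread
    barrier.wait()             # ensure both threads have written before reading
    results[worker_id] = (_global_info, _local.info)


barrier = threading.Barrier(2)
threads = [threading.Thread(target=worker, args=(i, barrier)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Each thread's thread-local read is always its own id; the global read may
# have been overwritten by the other thread before the barrier released.
```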
Test plan
test_none_outside_worker — returns None when called from outside any worker
test_thread_workers — all items report a valid worker_id ∈ [0, num_workers) and correct num_workers
test_thread_workers_each_id_used — worker ids observed across 40 items with 4 workers form a subset of [0, num_workers); exact coverage of every id is not asserted, since thread scheduling is non-deterministic
test_thread_workers_unique_seeds — each worker that processes items reports one consistent seed, and seeds differ across those workers
test_process_workers_get_worker_info — torch.utils.data.get_worker_info() works in forkserver process workers
test_process_workers_torchdata_get_worker_info — torchdata.nodes.get_worker_info() works in process workers
test_num_workers_zero_no_worker_info — inline (num_workers=0) UDF sees None