
Nodes/get worker info parallel mapper #1546

Open
psaikaushik wants to merge 12 commits into meta-pytorch:main from psaikaushik:nodes/get-worker-info-parallel-mapper

Conversation

@psaikaushik
Contributor

Summary

- Adds a num_workers parameter to _apply_udf and sets a WorkerInfo object on each worker before the processing loop begins, so user-defined functions can identify which worker they are running in.
- Introduces torchdata.nodes.get_worker_info(), backed by threading.local() — it always returns the correct WorkerInfo for the calling worker, regardless of thread vs. process mode.
- For process workers, torch.utils.data.get_worker_info() also works correctly (each process has isolated memory).
- For thread workers, torchdata.nodes.get_worker_info() is the safe API; torch.utils.data.get_worker_info() may race between threads, so the thread-local version is preferred.
- Exports get_worker_info from torchdata.nodes.

Fixes #1388
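A minimal sketch of the thread-local mechanism the summary describes. Only get_worker_info matches the PR's exported name; _set_worker_info and the dict payload are illustrative (the actual PR stores a torch.utils.data.WorkerInfo object):

```python
import threading
from typing import Any, Optional

# Each thread (including the sole thread of a worker process) sees only
# the attributes it set on this object itself.
_thread_local = threading.local()


def _set_worker_info(worker_id: int, num_workers: int) -> None:
    # Illustrative payload; the PR constructs a torch.utils.data.WorkerInfo here.
    _thread_local.worker_info = {"id": worker_id, "num_workers": num_workers}


def get_worker_info() -> Optional[Any]:
    # None when called outside any worker, i.e. nothing was set on this thread.
    return getattr(_thread_local, "worker_info", None)


results = {}


def worker(worker_id: int) -> None:
    _set_worker_info(worker_id, num_workers=2)
    # Reads back this thread's own info even while other workers run.
    results[worker_id] = get_worker_info()["id"]


threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert get_worker_info() is None  # main thread never called _set_worker_info
assert results == {0: 0, 1: 1}    # each worker saw its own info
```

Because the storage is per-thread, concurrent workers cannot observe each other's WorkerInfo, which is exactly the race the module-level torch global cannot avoid in thread mode.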

Test plan

- test_none_outside_worker — returns None when called from outside any worker
- test_thread_workers — all items report a valid worker_id ∈ [0, num_workers) and the correct num_workers
- test_thread_workers_each_id_used — all worker ids appear across 40 items with 4 workers
- test_thread_workers_unique_seeds — each worker reports one consistent seed; seeds differ across workers
- test_process_workers_get_worker_info — torch.utils.data.get_worker_info() works in forkserver process workers
- test_process_workers_torchdata_get_worker_info — torchdata.nodes.get_worker_info() works in process workers
- test_num_workers_zero_no_worker_info — an inline (num_workers=0) UDF sees None
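The thread-worker assertion pattern can be mimicked without torchdata. A sketch using a ThreadPoolExecutor initializer to hand each pool thread a worker id (all names here are illustrative, not torchdata APIs):

```python
import threading
from concurrent.futures import ThreadPoolExecutor

NUM_WORKERS = 4
_local = threading.local()
_next_id = iter(range(NUM_WORKERS))
_id_lock = threading.Lock()


def _init_worker() -> None:
    # Runs once per pool thread; hands out a unique worker id.
    with _id_lock:
        _local.worker_id = next(_next_id)


def process(item: int) -> tuple:
    # Each processed item reports which worker handled it.
    return item, _local.worker_id


with ThreadPoolExecutor(max_workers=NUM_WORKERS, initializer=_init_worker) as pool:
    reported = list(pool.map(process, range(40)))

# Mirrors test_thread_workers: every reported id is in [0, NUM_WORKERS),
# and map() preserves item order.
assert all(0 <= wid < NUM_WORKERS for _, wid in reported)
assert [item for item, _ in reported] == list(range(40))
```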

…orch#1388)

Passes num_workers to _apply_udf and sets a WorkerInfo on each worker
before the processing loop so torch.utils.data.get_worker_info() works
in process workers and torchdata.nodes.get_worker_info() (thread-local,
always correct) works in both thread and process workers. Exports
get_worker_info from torchdata.nodes.

https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
@meta-cla bot added the CLA Signed label (managed by the Facebook bot; authors must sign the CLA before a PR can be reviewed) on Apr 16, 2026
@psaikaushik self-assigned this Apr 16, 2026
- Add type: ignore[import] for private torch.utils.data._utils.worker module
- Import WorkerInfo from public torch.utils.data API for proper typing
- Fix get_worker_info() return type to Optional[WorkerInfo]
- Add type: ignore[arg-type] for dataset=None (not a Dataset[Any])
- Add type: ignore[attr-defined] for private _worker_info attribute assignment

https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
@divyanshk self-requested a review April 16, 2026 21:59
- Remove extra blank lines flagged by black formatter
- Move get_worker_info to end of __all__ (lowercase sorts after uppercase)

https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
Remove top-level 'from torch.utils.data import WorkerInfo' — WorkerInfo
is not part of the public torch.utils.data API in all supported versions.
Construct WorkerInfo through the already-ignored private module and use
Optional[Any] as the return type of get_worker_info().

https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
… coverage

A type: ignore comment only suppresses errors on the line it appears on — splitting the call across lines left dataset=None on an uncovered line.

https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
@psaikaushik marked this pull request as draft April 16, 2026 22:21
test_thread_workers_each_id_used and test_thread_workers_unique_seeds
asserted that all worker IDs appeared across N items, but thread work
distribution is non-deterministic — fast workers can drain the queue
before slower ones start, so some IDs may never appear. Replace with
a subset check that only validates correctness (id in [0, num_workers)).

https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
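The flakiness described in this commit can be reproduced deterministically: a fast worker can drain the shared queue before a slow worker starts, so only a subset of ids ever appears. A sketch (illustrative names, with an artificial delay standing in for scheduling jitter):

```python
import queue
import threading
import time

q = queue.Queue()
for item in range(40):
    q.put(item)

seen_ids = set()
_seen_lock = threading.Lock()


def worker(worker_id: int, delay: float) -> None:
    time.sleep(delay)  # a "slow" worker starts late
    while True:
        try:
            q.get_nowait()
        except queue.Empty:
            return
        with _seen_lock:
            seen_ids.add(worker_id)


threads = [
    threading.Thread(target=worker, args=(0, 0.0)),
    threading.Thread(target=worker, args=(1, 0.5)),  # wakes after the queue is drained
]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Worker 0 drains all 40 items before worker 1 wakes, so asserting that
# every id appears would fail. The robust check is subset membership:
assert 0 in seen_ids
assert seen_ids <= {0, 1}
```

This is why the tests were changed to validate only that each observed id lies in [0, num_workers), not that every id is observed.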
@psaikaushik marked this pull request as ready for review April 17, 2026 00:54
Comment thread on torchdata/nodes/_apply_udf.py (outdated)
_thread_local = threading.local()


def get_worker_info() -> Optional[Any]:
Contributor

This shouldn't live inside _apply_udf. Maybe a utils file in nodes/?

Contributor Author

Good point. Moved get_worker_info(), the thread-local storage, and the _set_worker_info() helper into a dedicated torchdata/nodes/_worker_info.py. _apply_udf.py now just calls _set_worker_info(worker_id, num_workers) and has no worker-info logic of its own.

Comment thread on torchdata/nodes/_apply_udf.py (outdated)
Comment on lines +66 to +73
seed = torch.initial_seed() + worker_id
worker_info = _worker_module.WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=None) # type: ignore[attr-defined,arg-type]
# Thread-local: always returns the correct info for this worker, regardless of
# whether other workers (threads) have set their own worker info concurrently.
_thread_local.worker_info = worker_info
# Module-level global: correct for process workers (isolated memory); for thread
# workers this may race, so callers should use torchdata.nodes.get_worker_info().
_worker_module._worker_info = worker_info # type: ignore[attr-defined]
Contributor

I am not sure about this change. Overall it is safe, but I don't think it's very clean to have this inside _apply_udf. Plus we are leaving dataset=None.

Contributor Author

Agreed on the placement; addressed in the same commit by moving everything to `_worker_info.py`.

On dataset=None: ParallelMapper maps over arbitrary items rather than a PyTorch Dataset object, so there is no meaningful dataset to pass. We still construct a torch.utils.data.WorkerInfo (rather than a custom dataclass) so that torch.utils.data.get_worker_info() continues to work inside process workers, keeping compatibility with existing IterableDataset code that calls it. The # type: ignore[arg-type] on that line suppresses the mypy complaint. Happy to discuss if there's a cleaner alternative you have in mind.
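The thread-vs-process distinction in the quoted code above can be demonstrated with a plain module-level global standing in for torch's _worker_info (all names are illustrative): concurrent thread writes clobber the global, while thread-local reads stay correct.

```python
import threading

_global_info = None          # stands in for torch's module-level _worker_info
_local = threading.local()   # stands in for torchdata's thread-local storage

barrier = threading.Barrier(2)
reads = {}


def worker(worker_id: int) -> None:
    global _global_info
    _local.info = worker_id
    _global_info = worker_id
    barrier.wait()  # ensure both workers have written the global by now
    # Thread-local read is always this worker's own id; the global read is
    # whichever worker happened to write last.
    reads[worker_id] = (_local.info, _global_info)


threads = [threading.Thread(target=worker, args=(i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Thread-local storage is correct for both workers.
assert reads[0][0] == 0 and reads[1][0] == 1
# The single global holds one value, so exactly one worker saw a wrong id.
assert reads[0][1] == reads[1][1]
```

In separate processes the "global" is per-process memory, so the same module-level write is safe there — which is why torch.utils.data.get_worker_info() remains reliable for process workers only.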

Per reviewer feedback, get_worker_info() and the WorkerInfo setup
logic don't belong in the private _apply_udf.py implementation file.

- New torchdata/nodes/_worker_info.py owns the thread-local storage,
  get_worker_info(), and _set_worker_info() helper
- _apply_udf.py calls _set_worker_info() and is otherwise unchanged
- __init__.py imports get_worker_info from _worker_info
- dataset=None is documented: ParallelMapper maps over items, not
  PyTorch Dataset objects, so there is no dataset to pass

https://claude.ai/code/session_01SjAXQKfwGwDdLAzXdevKwt
@psaikaushik requested a review from divyanshk April 21, 2026 10:17
Contributor Author

@psaikaushik left a comment


Addressed the comments. Please take a look @divyanshk


Labels

CLA Signed


Development

Successfully merging this pull request may close these issues.

[Nodes] Ensure get_worker_info works correctly in ParallelMapper

2 participants