Switch torch dependency from ~=2.9.1 to ~=2.10.0 (silent bfloat16 memory regression)#118
Conversation
@claude review this PR 👀 👀 👀
Claude finished @hanaol's task in 2m 33s

PR Review: Switch torch ~=2.9.1 → ~=2.10.0

Summary: The fix is correct and well-motivated. The bfloat16 Conv3d regression in torch 2.9.x is real (pytorch/pytorch#166643), and 2.10.0 cleanly resolves it. The benchmark methodology is solid. A few things worth flagging below.

pyproject.toml

The one-line change is correct:

```diff
-"torch~=2.9.1",
+"torch~=2.10.0",
```

benchmark script (
Summary
This PR updates the torch dependency from ~=2.9.1 to ~=2.10.0 to fix a silent bfloat16 memory regression introduced in torch 2.9.0.
The problem
torch 2.9.0 and 2.9.1 contain a cuDNN regression that inflates the nn.Conv3d bfloat16 forward-pass workspace by 26x -- from ~77 MB to ~2,053 MB -- relative to both the preceding (2.8.0) and following (2.10.0) releases. These numbers were measured on a fixed tensor of shape [1, 32, 64, 64, 64] with a Conv3d(in=32, out=32, k=5, padding=2) layer. float32 memory is completely unaffected (stable at ~123 MB across all versions), confirming the bug is specific to the bfloat16 cuDNN kernel selection path.
This matters because we use (or plan to use) bf16 mixed-precision training. This regression would silently consume an extra ~2 GB per Conv3d layer, directly undermining the memory savings that bf16 is supposed to provide -- without any crash or warning.
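To put the ~2 GB workspace in perspective, here is a back-of-the-envelope calculation of what the tensors in the benchmark configuration actually occupy. This is pure arithmetic (no torch required); the shapes are taken from the description above:

```python
# Benchmark configuration from the PR description:
# input [1, 32, 64, 64, 64], Conv3d(in=32, out=32, kernel=5, padding=2), bfloat16
BYTES_PER_ELEM = 2  # bfloat16 is 2 bytes per element

input_elems = 1 * 32 * 64 * 64 * 64       # 8,388,608 elements
weight_elems = 32 * 32 * 5 * 5 * 5 + 32   # 128,032 (weights + bias)
output_elems = input_elems                 # padding=2 preserves spatial dims

tensors_mb = (input_elems + weight_elems + output_elems) * BYTES_PER_ELEM / 2**20
print(f"input+weight+output: ~{tensors_mb:.0f} MB")       # ~32 MB

# The regressed cuDNN workspace (~2,053 MB) is roughly 64x larger than
# every tensor the convolution actually touches.
print(f"workspace / tensors: ~{2053 / tensors_mb:.0f}x")
```

In other words, the workspace alone is about two orders of magnitude larger than the combined input, weight, and output tensors.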
This issue has been raised in the PyTorch community:
F.conv3d with bfloat16 inputs in PyTorch 2.9.0 (pytorch/pytorch#166643, issue)

Benchmark results (A100-SXM4-80GB, CUDA 12.8, input shape [1, 32, 64, 64, 64])
| torch version | Peak GPU memory — float32 | Peak GPU memory — bfloat16 |
|---|---|---|
| 2.8.0 | ~123 MB | ~77 MB |
| 2.9.0 | ~123 MB | ~2,053 MB |
| 2.9.1 | ~123 MB | ~2,053 MB |
| 2.10.0 | ~123 MB | ~77 MB |

The benchmark script is included at scripts/benchmark_conv3d_memory.py and can be run standalone on any CUDA node.
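For reference, the core of the measurement can be sketched as follows. This is a minimal sketch, not the actual scripts/benchmark_conv3d_memory.py; the function name is illustrative, and it requires a CUDA-capable node. cuDNN workspace is served through the caching allocator, so it shows up in `torch.cuda.max_memory_allocated()`:

```python
import torch

def peak_conv3d_memory_mb(dtype: torch.dtype) -> float:
    """Measure peak GPU memory (MB) for one Conv3d forward pass in `dtype`."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    conv = torch.nn.Conv3d(32, 32, kernel_size=5, padding=2).cuda().to(dtype)
    x = torch.randn(1, 32, 64, 64, 64, device="cuda", dtype=dtype)
    with torch.no_grad():
        conv(x)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

if __name__ == "__main__":
    if torch.cuda.is_available():
        for dtype in (torch.float32, torch.bfloat16):
            print(f"{dtype}: peak ~{peak_conv3d_memory_mb(dtype):.0f} MB")
```

On an affected 2.9.x install the bfloat16 number should be roughly 26x the float32 one; on 2.8.0 or 2.10.0 it should come in well below it.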
Decision: 2.10.0 vs 2.11.0
Both 2.10.0 and 2.11.0 are clean. This PR pins to 2.10.0 for now. Upgrading to 2.11.0 is possible but would introduce a CUDA 13.0 dependency (vs CUDA 12.8 for all prior versions), pulling in a new set of nvidia-*-cu13 libraries that we have not tested against our full stack (lightning, etc.). Once our ecosystem catches up to CUDA 13.0, bumping to ~=2.11.0 is worth revisiting.
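Note that the ~=2.10.0 specifier (a PEP 440 "compatible release") already guards against an accidental jump to 2.11.0: it is equivalent to >=2.10.0, ==2.10.*, so only 2.10.x patch releases can be resolved. A minimal illustration of that semantics (a hand-rolled check for illustration only; real resolvers use the packaging library):

```python
def compatible_release(version: str, spec: str = "2.10.0") -> bool:
    """Check `version` against `~=<spec>` per PEP 440: >= spec and same major.minor."""
    v = tuple(int(p) for p in version.split("."))
    s = tuple(int(p) for p in spec.split("."))
    return v >= s and v[:2] == s[:2]

print(compatible_release("2.10.0"))  # True
print(compatible_release("2.10.3"))  # True  (patch releases are allowed)
print(compatible_release("2.11.0"))  # False (requires an explicit bump)
print(compatible_release("2.9.1"))   # False
```

So the move to 2.11.0, and with it CUDA 13.0, stays a deliberate decision rather than something a routine dependency refresh could pull in.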
Files changed