perf(autojac): Add optimize_gramian_computation to jac_to_grad (#525)
ValerianRey merged 38 commits into main from
Conversation
Codecov Report: ✅ All modified and coverable lines are covered by tests.
Code review: No issues found. Checked for bugs and CLAUDE.md compliance.
Tbh I don't like it very much (because it's an extra function and some cast is required), but it's the only way to easily test that the correct aggregators use the optimized _gramian_based method. I also tried using a return type hint of TypeGuard[GramianWeightedAggregator] instead of bool for _can_skip_jacobian_combination, but it's not really correct since we also check that the aggregator has no forward hook, so that TypeGuard would be really weird. So in the end we have to use this cast.
See #524 (comment) for a (slightly outdated) performance comparison.
a48b0dd to 9d4c41c (Compare)
With this PR, we rely a lot more on
PierreQuinton left a comment:
You can accept all commits together if you want. I prefer to use TypeGuards; it is prettier and clearer.
I think you missed one of my commit messages:
=> I asked Claude if it was OK to use a TypeGuard here, but it said that it should really only be used if the function only checks that the type of the aggregator is GramianWeightedAggregator. So I'd rather use cast just to be safe.
I disagree with Claude here. The logic of typeguards is "if the check passes, then the type is A", not "the check passes if and only if the type is A". So it is correct here. Otherwise every function returning a typeguard would be exactly:

```python
def check_B(a: A) -> TypeGuard[B]:
    return isinstance(a, B)
```

But of course if this was the only intended use of TypeGuards, then they would have provided exactly this function instead.
You're right. Please read this, it's short and very interesting: https://docs.python.org/3/library/typing.html#typing.TypeGuard Apparently there's also a TypeIs. So I think that:
Should we fire @claude?
Claude finished @ValerianRey's task in 52s.
Analysis of TypeGuard vs TypeIs
You're absolutely right, and I apologize for my earlier incorrect guidance! Let me clarify the semantics:
TypeGuard vs TypeIs Semantics
TypeGuard:
TypeIs (Python 3.13+):
Analysis of Your Functions
1.
Following our discussion where we considered making the Generics
CUDA Traces Comparison
=> This seems too slow on CUDA. I think we can't merge this until we have a way to parallelize the CUDA kernel launches or whatever it is that takes this extra time.
@claude Please review this.
Claude finished @ValerianRey's task in 3m 3s.
Code Review
Summary
The approach is mathematically sound and the CPU memory savings are compelling (up to 54% peak memory reduction for
Issues
1.
| Model | jac_to_grad | forward_backward |
|---|---|---|
| Cifar10Model | +35% | +14% |
| GroupNormMobileNetV3Small | +44% | +14% |
| InstanceNormMobileNetV2 | +201% | +13% |
The root cause is kernel launch overhead from calling compute_gramian (which uses matrix @ matrix.T) hundreds of times with small matrices on CUDA, versus one large tensordot in the original path. ValerianRey already flagged this as blocking — listing it here for completeness.
A potential mitigation would be to batch the gramian computations using torch.bmm: stack the flattened jacobians into a 3D tensor (n_layers, batch, params) and compute torch.bmm(J, J.transpose(1, 2)).sum(dim=0). This would trade a single batched kernel launch for the current loop of individual launches.
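As a sanity check of that suggestion, here is a numpy sketch (`np.matmul` on a stacked 3D array playing the role of `torch.bmm`), with made-up shapes, assuming all flattened per-layer jacobians share a common shape; a real implementation would need padding or shape-grouping otherwise:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical setup: 3 layers, each flattened jacobian of shape (batch=4, params=5).
jacobians = [rng.standard_normal((4, 5)) for _ in range(3)]

# Current path: one small gramian per layer, summed in a Python loop
# (many small kernel launches on CUDA).
loop_sum = sum(J @ J.T for J in jacobians)

# Suggested path: stack to (n_layers, batch, params) and use one batched
# matmul (the torch.bmm analog), then reduce over the layer axis.
stacked = np.stack(jacobians)                                 # (3, 4, 5)
batched_sum = (stacked @ stacked.transpose(0, 2, 1)).sum(axis=0)  # (4, 4)

print(np.allclose(loop_sum, batched_sum))  # True
```

The two paths compute the same quantity; the batched form just replaces the loop of small products with a single batched kernel.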
Minor observations
- The `_utils.py` fix is correct: Removing `jacobians[1:]` (which raises `TypeError` on `deque` since slices aren't supported) and replacing with `jacobians` is necessary and functionally equivalent.
- `_compute_gramian_sum` holds all gramians before summing: The list comprehension `[compute_gramian(m) for m in jacobians]` materializes all gramians before `sum()` accumulates them. Since gramians are n×n (small), this is negligible, but an incremental accumulation would be marginally more memory-efficient.
- TypeGuard usage (`_can_skip_jacobian_combination`): Correct — `TypeGuard` is appropriate here because returning `False` doesn't mean the aggregator is not a `GramianWeightedAggregator` (it could be one with hooks). Well-resolved in the thread.
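The first two observations can be illustrated in one small sketch (numpy in place of torch, made-up shapes): incremental accumulation over a deque, where `popleft()` both avoids the unsupported `jacobians[1:]` slice and releases each jacobian as soon as it has been used:

```python
from collections import deque

import numpy as np


def compute_gramian(J: np.ndarray) -> np.ndarray:
    return J @ J.T


rng = np.random.default_rng(0)
jacobians = deque(rng.standard_normal((4, 5)) for _ in range(3))

# deque has no slicing: `jacobians[1:]` would raise TypeError.
# popleft() drops each jacobian's reference as soon as its gramian is
# accumulated, so at most one jacobian is alive alongside the running sum.
total = compute_gramian(jacobians.popleft())
while jacobians:
    total += compute_gramian(jacobians.popleft())

print(total.shape)  # (4, 4)
```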
Avoid concatenation of the jacobians when the aggregator is gramian-based. Also use a deque to free each jacobian as soon as it is used.
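The identity behind skipping the concatenation can be sketched with numpy (made-up shapes, not the library's actual code): the gramian of the column-concatenated jacobian equals the sum of the per-layer gramians, so the concatenated matrix never needs to exist.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical per-layer flattened jacobians: batch of 4, varying param counts.
per_layer = [rng.standard_normal((4, p)) for p in (3, 5, 2)]

# Naive path: concatenate along the parameter axis, then one big gramian.
J = np.concatenate(per_layer, axis=1)  # (4, 10) -- the memory peak
full = J @ J.T

# Optimized path: sum of small per-layer gramians, no concatenation.
summed = sum(Ji @ Ji.T for Ji in per_layer)

print(np.allclose(full, summed))  # True
```

This also explains why the layers may have different parameter counts: each gramian is batch×batch regardless of the layer's width.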