perf(autojac): Add optimize_gramian_computation to jac_to_grad #525

Merged

ValerianRey merged 38 commits into main from optimize_jac_to_grad on Feb 23, 2026

Conversation

@ValerianRey
Contributor

@ValerianRey ValerianRey commented Jan 23, 2026

Avoid concatenation of the jacobians when the aggregator is gramian-based. Also use a deque to free each jacobian as soon as it is used.

  • Add gramian-based jac_to_grad
  • Update changelog
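A minimal sketch of the idea, with hypothetical names and shapes (not the library's actual API): since the Gramian of a Jacobian concatenated along the parameter dimension equals the sum of the per-matrix Gramians, the concatenation never needs to be materialized, and popping each Jacobian from a deque lets it be freed as soon as it is used.

```python
from collections import deque

import torch


def gramian_without_concatenation(jacobians: deque[torch.Tensor]) -> torch.Tensor:
    """Sum J_i @ J_i.T over per-parameter Jacobians, each of shape [m, n_i].

    If J is the concatenation of the J_i along dim 1, then J @ J.T equals the
    sum of the per-matrix Gramians, so the concatenated matrix is never built.
    Popping from the deque drops each Jacobian's reference right after use.
    """
    gramian = None
    while jacobians:
        j = jacobians.popleft()
        g = j @ j.T
        gramian = g if gramian is None else gramian + g
        del j  # last reference gone: the Jacobian's memory can be reclaimed
    return gramian
```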

@ValerianRey ValerianRey added the `package: autojac` and `cc: refactor` labels Jan 23, 2026
@ValerianRey ValerianRey self-assigned this Jan 23, 2026
@codecov

codecov bot commented Jan 23, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

| Files with missing lines | Coverage | Δ |
|---|---|---|
| `src/torchjd/_linalg/_gramian.py` | 100.00% <100.00%> | ø |
| `src/torchjd/autojac/_jac_to_grad.py` | 100.00% <100.00%> | ø |

@claude

claude bot commented Jan 23, 2026

Code review

No issues found. Checked for bugs and CLAUDE.md compliance.

ValerianRey and others added 9 commits January 28, 2026 13:43
Tbh I don't like it very much (because it's an extra function + some cast is required) but it's the only way to easily test that the correct aggregators use the optimized _gramian_based method. I also tried using a return type hint of TypeGuard[GramianWeightedAggregator] instead of bool for _can_skip_jacobian_combination, but it's not really correct since we also check that the aggregator has no forward hook, so that TypeGuard would be really weird. So in the end we have to use this cast.
@ValerianRey
Contributor Author

See #524 (comment) for a (slightly outdated) performance comparison.

@ValerianRey
Contributor Author

ValerianRey commented Jan 29, 2026

With this PR, we rely much more on compute_gramian being competitively fast, because we call it hundreds of times instead of just once per iteration. So I optimized it for the case contracted_dims=-1 (the most common case). This has a non-negligible impact on performance (for instance, reducing the time of jac_to_grad from 38 ms to 32 ms for InstanceNormMobileNetV2 with batch size 2).
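I don't know the actual implementation, but one way such a fast path could look (a hypothetical sketch, with `gramian_last_dim` as a made-up name): when only the last dimension is contracted, flatten the leading dimensions into rows, do a single matmul, and restore the shape, instead of a general tensordot.

```python
import torch


def gramian_last_dim(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical fast path for contracting only the last dimension.

    Equivalent to torch.tensordot(t, t, dims=([t.ndim - 1], [t.ndim - 1])),
    but expressed as one 2D matmul after flattening the leading dimensions.
    """
    lead = t.shape[:-1]
    m = t.reshape(-1, t.shape[-1])  # (prod(lead), last)
    return (m @ m.T).reshape(*lead, *lead)
```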

Contributor

@PierreQuinton PierreQuinton left a comment


You can accept all commits together if you want, I prefer to use TypeGuards, it is prettier and clearer.

@ValerianRey
Contributor Author

You can accept all commits together if you want, I prefer to use TypeGuards, it is prettier and clearer.

I think you missed one of my commit messages:

Tbh I don't like it very much (because it's an extra function + some cast is required) but it's the only way to easily test that the correct aggregators use the optimized _gramian_based method. I also tried using a return type hint of TypeGuard[GramianWeightedAggregator] instead of bool for _can_skip_jacobian_combination, but it's not really correct since we also check that the aggregator has no forward hook, so that TypeGuard would be really weird. So in the end we have to use this cast.

=> I asked Claude if it was OK to use a TypeGuard here, but it said that it should really only be used if the function only checks that the type of the aggregator is GramianWeightedAggregator. So I'd rather use cast just to be safe.

@PierreQuinton
Contributor

PierreQuinton commented Jan 29, 2026

I disagree with Claude here. The logic of type guards is "if the check passes, then the type is A", not "the check passes if and only if the type is A". So here it is correct. Otherwise every function returning a TypeGuard would be exactly:

def check_B(a: A) -> TypeGuard[B]:
    return isinstance(a, B)

But of course if this was the only intended use of TypeGuards, then they would have provided exactly this function instead.

@ValerianRey
Contributor Author

I disagree with Claude here. The logic of type guards is "if the check passes, then the type is A", not "the check passes if and only if the type is A". So here it is correct. Otherwise every function returning a TypeGuard would be exactly:

def check_B(a: A) -> TypeGuard[B]:
    return isinstance(a, B)

But of course if this was the only intended use of TypeGuards, then they would have provided exactly this function instead.

You're right. Please read this, it's short and very interesting: https://docs.python.org/3/library/typing.html#typing.TypeGuard

Apparently there's also a TypeIs, which has a few differences with TypeGuard. Here, we wouldn't be allowed to use TypeIs, because returning False with a TypeIs would tell the type checker that the type is not what is in the brackets of the TypeIs. But with TypeGuard, it doesn't say anything to the type checker when returning False.

So I think that:

  • This can safely be a TypeGuard
  • TypeGuard can be changed to TypeIs in is_matrix, is_psd_tensor and is_psd_matrix

Should we fire @claude?

@claude

claude bot commented Jan 30, 2026

Claude finished @ValerianRey's task in 52s


Analysis of TypeGuard vs TypeIs

You're absolutely right, and I apologize for my earlier incorrect guidance! Let me clarify the semantics:

Todo List

  • Read the current implementation in _jac_to_grad.py
  • Read other files with type guard functions
  • Analyze TypeGuard vs TypeIs semantics
  • Provide recommendations

TypeGuard vs TypeIs Semantics

TypeGuard:

  • When the function returns True, the type checker narrows the type to the guarded type
  • When the function returns False, the type checker learns nothing about the type
  • This is appropriate when the check is sufficient but not necessary for the type

TypeIs (Python 3.13+):

  • When the function returns True, the type checker narrows the type to the guarded type
  • When the function returns False, the type checker narrows the type to exclude the guarded type
  • This is appropriate when the check is both necessary and sufficient for the type (i.e., if and only if)

Analysis of Your Functions

1. _can_skip_jacobian_combination (src/torchjd/autojac/_jac_to_grad.py:85)

Current implementation:

def _can_skip_jacobian_combination(aggregator: Aggregator) -> bool:
    return isinstance(aggregator, GramianWeightedAggregator) and not _has_forward_hook(aggregator)

Recommendation: Use TypeGuard[GramianWeightedAggregator]

This is the correct choice because:

  • When True: The aggregator is indeed a GramianWeightedAggregator (and has no hooks, but that's an additional runtime constraint)
  • When False: The aggregator might still be a GramianWeightedAggregator (if it has hooks), so we can't exclude that type

Should NOT use TypeIs because returning False doesn't mean the aggregator is NOT a GramianWeightedAggregator - it just means we can't use the optimization.

2. is_matrix, is_psd_tensor, is_psd_matrix (src/torchjd/_linalg/_matrix.py:26-38)

Current implementation:

def is_matrix(t: Tensor) -> TypeGuard[Matrix]:
    return t.ndim == 2

def is_psd_tensor(t: Tensor) -> TypeGuard[PSDTensor]:
    half_dim = t.ndim // 2
    return t.ndim % 2 == 0 and t.shape[:half_dim] == t.shape[: half_dim - 1 : -1]

def is_psd_matrix(t: Tensor) -> TypeGuard[PSDMatrix]:
    return t.ndim == 2 and t.shape[0] == t.shape[1]

Recommendation: Use TypeIs for is_matrix

For is_matrix:

  • When True: The tensor is a 2D tensor (Matrix)
  • When False: The tensor is definitively NOT a 2D tensor (not a Matrix)
  • This is an if-and-only-if condition, so TypeIs is semantically correct

Recommendation: Keep TypeGuard for is_psd_tensor and is_psd_matrix

For is_psd_tensor and is_psd_matrix:

  • The checks verify shape requirements but NOT the actual PSD property (as noted in comments)
  • When False: A tensor could still conceptually be a PSDTensor/PSDMatrix even if it fails the shape check (e.g., if we had a different representation)
  • More importantly, these functions don't check the actual positive semi-definiteness, which is part of the semantic contract of these types
  • So these are sufficient but not strictly necessary conditions

However, there's an argument for TypeIs here too: if we consider the type definitions to be purely about shape (since PSD checking is expensive and done separately in tests), then TypeIs would be appropriate. But TypeGuard is safer and more conservative.

Summary

  1. _can_skip_jacobian_combination: Should use TypeGuard[GramianWeightedAggregator] ✅ (PierreQuinton is correct)
  2. is_matrix: Could use TypeIs[Matrix] (it's an exact if-and-only-if condition)
  3. is_psd_tensor and is_psd_matrix: Should keep TypeGuard (they don't check the full semantic meaning of PSD)

And no, you shouldn't fire me - I was being overly conservative before! 😄 The key insight is that TypeGuard is appropriate when you're checking sufficient conditions that don't rule out the type when they fail.


@ValerianRey
Contributor Author

CUDA memory timelines: (seven memory-timeline screenshots, not reproduced here)

@PierreQuinton
Contributor

Following our discussion where we considered making the generics Weighting[PSDMatrix] and Weighting[Matrix] into abstract classes, which we could then use to differentiate the method used for jac_to_grad: we could in principle avoid testing instances and instead have methods in each class specifying how a list of Jacobians is handled. This is good in that it avoids case analysis, but bad in that it gives a less clear-cut algebraic meaning to the aggregation package.

@ValerianRey
Contributor Author

CUDA Traces Comparison

| Model | Batch size | Time before (jac_to_grad) | Time after (jac_to_grad) | Difference (jac_to_grad) | Time before (autojac_forward_backward) | Time after (autojac_forward_backward) | Difference (autojac_forward_backward) |
|---|---|---|---|---|---|---|---|
| AlexNet | 4 | 98 ms | 98 ms | 0 ms (+1%) | 132 ms | 129 ms | -2 ms (-2%) |
| Cifar10Model | 64 | 121 ms | 163 ms | +42 ms (+35%) | 184 ms | 210 ms | +26 ms (+14%) |
| GroupNormMobileNetV3Small | 8 | 27 ms | 39 ms | +12 ms (+44%) | 87 ms | 99 ms | +12 ms (+14%) |
| InstanceNormMobileNetV2 | 2 | 6 ms | 19 ms | +13 ms (+201%) | 93 ms | 105 ms | +12 ms (+13%) |
| InstanceNormResNet18 | 4 | 41 ms | 41 ms | 0 ms (0%) | 84 ms | 85 ms | +2 ms (+2%) |
| SqueezeNet | 4 | 18 ms | 22 ms | +4 ms (+21%) | 40 ms | 43 ms | +3 ms (+8%) |
| WithTransformerLarge | 4 | 76 ms | 78 ms | +2 ms (+3%) | 175 ms | 184 ms | +9 ms (+5%) |

=> This seems too slow on CUDA. I think we can't merge this until we have a way to parallelize the CUDA kernel launches, or whatever it is that takes this extra time.

@github-actions github-actions bot changed the title from "refactor(autojac): Optimize jac_to_grad" to "perf(autojac): Optimize jac_to_grad" Feb 23, 2026
@ValerianRey
Contributor Author

@claude Please review this.

@ValerianRey ValerianRey changed the title from "perf(autojac): Optimize jac_to_grad" to "perf(autojac): Add optimize_gramian_computation to jac_to_grad" Feb 23, 2026
@claude

claude bot commented Feb 23, 2026

Claude finished @ValerianRey's task in 3m 3s


Code Review


Summary

The approach is mathematically sound and the CPU memory savings are compelling (up to 54% peak memory reduction for WithTransformerLarge). The code is generally well-structured. Four issues worth addressing:


Issues

1. _compute_gramian_sum uses sum() starting from integer 0 — fragile

src/torchjd/autojac/_jac_to_grad.py:199

def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix:
    gramian = sum([compute_gramian(matrix) for matrix in jacobians])
    return cast(PSDMatrix, gramian)

sum([t1, t2, ...]) starts with 0 + t1, relying on Tensor.__radd__(0) working correctly. This is fragile. A clearer pattern:

from functools import reduce
import operator

def _compute_gramian_sum(jacobians: deque[Tensor]) -> PSDMatrix:
    gramian = reduce(operator.add, (compute_gramian(matrix) for matrix in jacobians))
    return cast(PSDMatrix, gramian)

Or equivalently: initialize with the first element and accumulate over the rest.


2. Error message references a private class

src/torchjd/autojac/_jac_to_grad.py:125–129

raise ValueError(
    "In order to use `jac_to_grad` with `optimize_gramian_computation=True`, you must "
    "provide a `GramianWeightedAggregator` that doesn't have any forward hooks attached"
    " to it."
)

GramianWeightedAggregator is defined in _aggregator_bases.py (private) and is not in aggregation.__all__. Users relying on the public API won't know what this refers to. The message should describe the requirement in terms users can act on — either list the compatible public aggregators (AlignedMTL, DualProj, IMTLG, Krum, MGDA, PCGrad, UPGrad, CAGrad) or say "an aggregator that computes weights based on the Gramian".

This is also what PierreQuinton's recent review comment refers to.


3. The `.. note::` in the docstring is inaccurate for the new path

src/torchjd/autojac/_jac_to_grad.py:72–76

The note says the function "concatenates those matrices into a combined Jacobian matrix", but when optimize_gramian_computation=True it deliberately skips this step. The note should clarify that this only applies to the default (optimize_gramian_computation=False) path, or be updated to describe both cases.

4. CUDA regression (acknowledged as blocking)

Per ValerianRey's own benchmarks, CUDA shows significant slowdowns for smaller models:

| Model | jac_to_grad | forward_backward |
|---|---|---|
| Cifar10Model | +35% | +14% |
| GroupNormMobileNetV3Small | +44% | +14% |
| InstanceNormMobileNetV2 | +201% | +13% |

The root cause is kernel launch overhead from calling compute_gramian (which uses matrix @ matrix.T) hundreds of times with small matrices on CUDA, versus one large tensordot in the original path. ValerianRey already flagged this as blocking — listing it here for completeness.

A potential mitigation would be to batch the Gramian computations using torch.bmm: stack the flattened Jacobians into a 3D tensor (n_layers, batch, params) and compute torch.bmm(J, J.transpose(1, 2)).sum(dim=0). This would replace the current loop of individual kernel launches with a single batched launch.
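For same-shaped Jacobians the suggested batching could look like this (a sketch of the mitigation, not tested against the real code; `batched_gramian_sum` is a made-up name, and since real per-layer Jacobians differ in parameter count, grouping or padding would be needed in practice):

```python
import torch


def batched_gramian_sum(jacobians: list[torch.Tensor]) -> torch.Tensor:
    """Sum of J_i @ J_i.T over Jacobians that all flatten to the same (b, p).

    Stacking lets one torch.bmm launch compute every Gramian at once, instead
    of one small matmul kernel launch per Jacobian in a Python loop.
    """
    j = torch.stack([m.flatten(start_dim=1) for m in jacobians])  # (n, b, p)
    return torch.bmm(j, j.transpose(1, 2)).sum(dim=0)             # (b, b)
```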


Minor observations

  • _utils.py fix is correct: Removing jacobians[1:] (which raises TypeError on deque since slices aren't supported) and replacing with jacobians is necessary and functionally equivalent.
  • _compute_gramian_sum holds all gramians before summing: The list comprehension [compute_gramian(m) for m in jacobians] materializes all gramians before sum() accumulates them. Since gramians are n×n (small), this is negligible, but an incremental accumulation would be marginally more memory-efficient.
  • TypeGuard usage (_can_skip_jacobian_combination): Correct — TypeGuard is appropriate here because returning False doesn't mean the aggregator is not a GramianWeightedAggregator (it could be one with hooks). Well-resolved in the thread.

@ValerianRey ValerianRey merged commit 77309f4 into main Feb 23, 2026
15 checks passed
@ValerianRey ValerianRey deleted the optimize_jac_to_grad branch February 23, 2026 17:46

Labels

`cc: perf`, `package: autojac`
