Skip to content

Add sketch of new Tensor interface#2457

Draft
dennisYatunin wants to merge 7 commits intomainfrom
dy/tensors
Draft

Add sketch of new Tensor interface#2457
dennisYatunin wants to merge 7 commits intomainfrom
dy/tensors

Conversation

@dennisYatunin
Copy link
Copy Markdown
Member

This PR refactors the Geometry module to use a much simpler interface, so that all vector/tensor operations can expressed using standard math operations instead of custom API functions. This will allow us to remove a large amount of duplicate code, reduce compilation latency, and speed up GPU runs by optimizing the geometry data passed to each kernel.

  • Code follows the style guidelines OR N/A.
  • Unit tests are included OR N/A.
  • Code is exercised in an integration test OR N/A.
  • Documentation has been added/updated OR N/A.

@dennisYatunin dennisYatunin force-pushed the dy/tensors branch 23 times, most recently from d1c85a9 to dc71352 Compare February 21, 2026 03:40
@dennisYatunin dennisYatunin force-pushed the dy/tensors branch 6 times, most recently from 271e3cd to 1f8a383 Compare February 25, 2026 06:57
Copy link
Copy Markdown
Member Author

@dennisYatunin dennisYatunin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is looking pretty great! My main comments are that we should avoid allowing tensor adjoints to be specified in two different ways, and that a lot of the code remaining in conversions.jl can be simplified or eliminated. Also, it would be helpful to add detailed docstrings for Basis, Metric, Tensor, and TensorWithAnyBasis, with some examples that indicate how reshape works and when it gets called.

Comment thread test/runtests.jl
UnitTest("Geometry" ,"Geometry/geometry.jl"),
UnitTest("rmul_with_projection" ,"Geometry/rmul_with_projection.jl"),
UnitTest("AxisTensors" ,"Geometry/axistensors.jl"),
UnitTest("Tensors" ,"Geometry/tensors.jl"),
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
UnitTest("Tensors" ,"Geometry/tensors.jl"),
UnitTest("Tensors" ,"Geometry/tensors.jl"),

Comment thread test/Spaces/opt_spaces.jl
Comment on lines -42 to +44
test_n_failures(1147, TU.SphereSpectralElementSpace, context)
test_n_failures(1146, TU.CenterExtrudedFiniteDifferenceSpace, context)
test_n_failures(1146, TU.FaceExtrudedFiniteDifferenceSpace, context)
test_n_failures(872, TU.SphereSpectralElementSpace, context)
test_n_failures(881, TU.CenterExtrudedFiniteDifferenceSpace, context)
test_n_failures(881, TU.FaceExtrudedFiniteDifferenceSpace, context)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, nice to see these go down by so much. I was not expecting AxisTensors to account for a quarter of our inference failures here.

@test S ≈ S_ref

@test norm(S_scalar) ≈ norm(Geometry.components(S_scalar))
@test norm(S_scalar) ≈ norm(Geometry.parent(S_scalar))
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@test norm(S_scalar) norm(Geometry.parent(S_scalar))
@test norm(S_scalar) norm(parent(S_scalar))

Comment on lines -41 to 46
AxisTensor{
GFT,
Tensor{
2,
Tuple{CovariantAxis{(3,)}, ContravariantAxis{(3,)}},
SArray{Tuple{1, 1}, GFT, 2, 1},
GFT,
Tuple{Basis{Covariant, (3,)}, Basis{Contravariant, (3,)}},
SMatrix{1, 1, GFT, 1},
},
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
AxisTensor{
GFT,
Tensor{
2,
Tuple{CovariantAxis{(3,)}, ContravariantAxis{(3,)}},
SArray{Tuple{1, 1}, GFT, 2, 1},
GFT,
Tuple{Basis{Covariant, (3,)}, Basis{Contravariant, (3,)}},
SMatrix{1, 1, GFT, 1},
},
typeof(C3(GFT(0)) * CT3(GFT(0))'),

Bit easier to read if you write out the value, since C3 and CT3 are already defined

Comment thread test/Geometry/tensor_conversion_benchmarks.jl
Comment on lines +109 to 115
function blockmat(a::Tensor{2}, b::Tensor{2}, ::Nothing = nothing)
new_bases = (
combine_bases(axes(a, 1), axes(b, 1)),
combine_bases(axes(a, 2), axes(b, 2)),
)
return reshape(a, new_bases) + reshape(b, new_bases)
end
Copy link
Copy Markdown
Member Author

@dennisYatunin dennisYatunin Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this equivalent to the automatic reshaping performed by +? I feel like you can just replace blockmat with + without changing any results, but maybe I'm missing something.

Comment on lines +5 to +7
# AbstractCovector (Tensor{2} with ScalarBasis) is already covered by AbstractTensor.
# Adjoint{T, <:AbstractTensor} covers the case where adjoint() returns a Julia Adjoint
# wrapper rather than our Covector type (e.g., from composition or old codepaths).
Copy link
Copy Markdown
Member Author

@dennisYatunin dennisYatunin Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd strongly recommend dropping support for Adjoint{<:Any, <:AbstractTensor}, if that's possible. The new interface is designed to ensure that all tensors are subtypes of AbstractTensor, so that we don't need to define duplicate code for adjoints. This simplification would let you eliminate a lot of the code that was added here and in the MatrixFields module.

Comment on lines +299 to +302
_∂x∂ξ_bases2D = (
Geometry.Basis{Geometry.Orthonormal, AIdx}(),
Geometry.Basis{Geometry.Covariant, AIdx}(),
)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_∂x∂ξ_bases2D = (
Geometry.Basis{Geometry.Orthonormal, AIdx}(),
Geometry.Basis{Geometry.Covariant, AIdx}(),
)

I don't see this used anywhere


function promote_axis_tensor(
at::Geometry.AxisTensor{T, N, A, S},
at::Geometry.Tensor{N, T, B, S},
Copy link
Copy Markdown
Member Author

@dennisYatunin dennisYatunin Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
at::Geometry.Tensor{N, T, B, S},
at::Geometry.Tensor,

Unused parameters should be avoided in type constraints (it's not great for readability, and I think it also sightly increases latency)

Comment thread src/Spaces/pointspace.jl
(Geometry.LocalAxis{AIdx}(), Geometry.CovariantAxis{AIdx}()),
FT(1.0),
Geometry.Tensor(
SMatrix{1, 1}(FT(1.0)),
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
SMatrix{1, 1}(FT(1.0)),
I,

If you can avoid converting I to an SMatrix when constructing a Tensor, using I here might eliminate some unnecessary floating-point multiplications.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants