Skip to content

Commit 3f514e8

Browse files
Ickasertpappdevmotion
authored
[FR] Transform which maps a tuple onto a user-provided struct (#163)
* Sketch out the idea * Rename to Transform, not Transformation * Add ConstructionBase * Try implementing proper methods * minor fixes * Add methods for tuple, namedtuple, and scalar; add tests for each, with inverse * Add printing for tuple and scalar cases * Make a separate ScalarWrapperTransform * Eliminate method ambiguity by repeating a StaticArrays method for `as` * Drop full scalar transform interface for typewrapper, in favor of convenience constructor wrapping in tuple * Make precompile work * match namedtuple to struct fields * Switch semantics: NamedTuples unpack as kwargs * Implement inverses which do their best; add inference tests by using utility function * Help 1.10 constprop out a little bit to reach full inference, add a couple more tests * Formatting improvements Co-authored-by: Tamas K. Papp <tkpapp@gmail.com> * Dead code * Error for different number of fields than tuple fields, rather than returning first n fields * Test pretty printing on aggregations, make array of scalars go inline * Add an inline array transform test for longer scalar transform * Chase coverage on array pretty printing test * Docs suggestions from review Co-authored-by: David Müller-Widmann <devmotion@users.noreply.github.com> Co-authored-by: Tamas K. Papp <tkpapp@gmail.com> * Remove scalar case from struct transforms --------- Co-authored-by: Tamás K. Papp <tkpapp@gmail.com> Co-authored-by: David Müller-Widmann <devmotion@users.noreply.github.com>
1 parent c34d342 commit 3f514e8

5 files changed

Lines changed: 266 additions & 13 deletions

File tree

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.8.20"
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
99
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
10+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
1011
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1112
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -32,6 +33,7 @@ ArgCheck = "1, 2"
3233
ChangesOfVariables = "0.1"
3334
Compat = "4.10.0"
3435
CompositionsBase = "0.1.2"
36+
ConstructionBase = "1.6.0"
3537
DocStringExtensions = "0.8, 0.9"
3638
ForwardDiff = "0.10, 1"
3739
InverseFunctions = "0.1"

docs/src/index.md

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,39 @@ t_abc = merge(t_a, t_b, t_c)
9898
t_collision = merge(t_a, as((;a = asℝ₋))) # Will have a = asℝ₋, from rightmost
9999
```
100100

101+
In some of these cases, it may be helpful for some of the transformed variables to be wrapped in a user-provided type.
102+
For example, given
103+
```julia
104+
struct Foo
105+
a
106+
b
107+
end
108+
struct Bar
109+
c
110+
d
111+
end
112+
```
113+
it may be useful to transform a flat vector `[1, 2, 3, 4]` into a named tuple like `(e = 1, foo = Foo(2, Bar(3, 4)))`. This can be achieved with a transform like the following:
114+
```julia
115+
tb = as(Bar, (Identity(), Identity()))
116+
tf = as(Foo, (Identity(), tb))
117+
t = as((; e = Identity(), foo = tf))
118+
```
119+
where each instance of `Identity()` here could be replaced with an arbitrary scalar transform.
120+
121+
This relies on [`constructorof` from ConstructionBase](https://juliaobjects.github.io/ConstructionBase.jl/stable/#ConstructionBase.constructorof). If a `Tuple` transform is wrapped in a type this way, transform results will be unpacked and passed to the constructor, hence in the same order as the transform.
122+
If a `NamedTuple` transform is used, it will be unpacked as keyword arguments to the constructor of the type.
123+
`Tuple` or `NamedTuple` transforms that do not provide proper arguments to the constructor of a given type will simply result in `MethodError`s.
124+
125+
For structs that are constructed from a single argument, wrap a single transform in a `Tuple`, like so:
126+
```julia
127+
struct Baz; f; end
128+
tbaz = as(Baz, (Identity(),))
129+
```
130+
131+
Inverting these transforms would, in the most general case, require inverting the constructor of the given type, which may itself have several valid dispatches.
132+
Rather than conduct a close inspection of user types and their available constructors, inverting a `Tuple`-based struct transform will check that the struct has exactly `n` fields, where `n` is the length of the `Tuple` transform, and use those fields in their struct order; if there are not `n` fields the inversion will error. This will work for any structs with only the default constructor. Inverting a `NamedTuple`-based struct transform will attempt to use struct fields with names matching the names in the `NamedTuple` transform, which fails if the struct does not have matching fields.
133+
101134
## Scalar transforms
102135

103136
The symbol `` is a placeholder for infinity. It does not correspond to `Inf`, but acts as a placeholder for the correct dispatch. `-∞` is valid.
@@ -131,7 +164,7 @@ TVShift
131164
TVNeg
132165
```
133166

134-
Consistent with common notation, transforms are applied right-to-left; for example, `as(Real, ∞, 3)` is equivalent to `TVShift(3) ∘ TVNeg() ∘ TVExp()`.
167+
Consistent with common notation, transforms are applied right-to-left; for example, `as(Real, -∞, 3)` is equivalent to `TVShift(3) ∘ TVNeg() ∘ TVExp()`.
135168
If you are working in an editor where typing Unicode is difficult, `TransformVariables.compose` is also available, as in `TransformVariables.compose(TVScale(5.0), TVNeg(), TVExp())`.
136169

137170
This composition works with any scalar transform in any order, so `TVScale(4) ∘ as(Real, 2, ∞) ∘ TVShift(1e3)` is a valid transform.

src/TransformVariables.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using LinearAlgebra: UpperTriangular, logabsdet, norm, rmul!
99
using Random: AbstractRNG, GLOBAL_RNG
1010
using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst
1111
using CompositionsBase
12+
using ConstructionBase
1213

1314
include("utilities.jl")
1415
include("generic.jl")

src/aggregation.jl

Lines changed: 89 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ end
1717
function _summary_rows(transformation::ArrayTransformation, mime)
1818
(; inner_transformation, dims) = transformation
1919
_dims = foldr((a,b) -> "$(string(a))×$(string(b))", dims, init = "")
20+
if inner_transformation isa ScalarTransform
21+
return _summary_row(transformation, _dims*string(inner_transformation))
22+
end
2023
rows = _summary_row(transformation, _dims)
2124
for row in _summary_rows(inner_transformation, mime)
2225
push!(rows, (level = row.level + 1, indices = nothing, repr = row.repr))
@@ -177,7 +180,14 @@ as(SArray{2,3}, asℝ₊, 2, 3) # transform to a 2x3 SMatrix of positive num
177180
as(SVector{3}) # ℝ³ → ℝ³, identity, but an SVector
178181
```
179182
"""
180-
function as(::Type{<:SArray{S}}, inner_transformation = Identity()) where S
183+
function as(::Type{<:SArray{S}}, inner_transformation::AbstractTransform) where S
184+
dim = fieldtypes(S)
185+
@argcheck all(x -> x 1, dim)
186+
StaticArrayTransformation{prod(dim),S,typeof(inner_transformation)}(inner_transformation)
187+
end
188+
# Repeated with more specific typing to eliminate method ambiguity with
189+
# the ScalarWrapperTransform method for `as`
190+
function as(::Type{<:SArray{S}}, inner_transformation::ScalarTransform = Identity()) where S
181191
dim = fieldtypes(S)
182192
@argcheck all(x -> x 1, dim)
183193
StaticArrayTransformation{prod(dim),S,typeof(inner_transformation)}(inner_transformation)
@@ -264,7 +274,6 @@ struct TransformTuple{T} <: VectorTransform
264274
end
265275
end
266276

267-
268277
"""
269278
$(SIGNATURES)
270279
@@ -282,20 +291,25 @@ Base.getindex(t::TransformTuple, i::Int) = getindex(_inner(t), i)
282291
Base.propertynames(t::TransformTuple) = propertynames(_inner(t))
283292
Base.getproperty(t::TransformTuple, i::Int) = getproperty(_inner(t), i)
284293
Base.getproperty(t::TransformTuple{<:NamedTuple}, i::Symbol) = getproperty(_inner(t), i)
294+
285295
"""
286296
$(SIGNATURES)
287297
288298
Merge multiple `TransformTuple{<:NamedTuple}` by merging the underlying `NamedTuple`s.
289299
"""
290-
function Base.merge(t1::TransformTuple{<:NamedTuple},
300+
function Base.merge(t1::TransformTuple{<:NamedTuple},
291301
ts::Vararg{TransformTuple{<:NamedTuple}})
292302
TransformTuple(merge(_inner(t1), map(_inner, ts)...))
293303
end
294304

295305
function _summary_rows(transformation::TransformTuple, mime)
296306
inner = _inner(transformation)
297307
repr1 = string(nameof(typeof(inner)), " of transformations")
298-
rows = _summary_row(transformation, repr1)
308+
_tuple_summary_rows(repr1, transformation, mime)
309+
end
310+
function _tuple_summary_rows(named, transformation::TransformTuple, mime)
311+
inner = _inner(transformation)
312+
rows = _summary_row(transformation, named)
299313
_index = 0
300314
for (key, t) in pairs(inner)
301315
for row in _summary_rows(t, mime)
@@ -482,3 +496,74 @@ function _domain_label(t::TransformTuple, index::Int)
482496
end
483497
error("internal error")
484498
end
499+
500+
####
501+
#### type wrapper transformation
502+
####
503+
504+
"""
505+
$(TYPEDEF)
506+
"""
507+
struct TypeWrapperTransform{T,S} <: VectorTransform
508+
inner_transformation::S
509+
end
510+
511+
function as(::Type{T}, inner_transformation::S) where {T,S<:TransformTuple}
512+
@argcheck isstructtype(T)
513+
TypeWrapperTransform{T,S}(inner_transformation)
514+
end
515+
516+
as(::Type{T}, inner_transformation::NTransforms) where T = as(T, as(inner_transformation))
517+
518+
dimension(t::TypeWrapperTransform) = dimension(t.inner_transformation)
519+
520+
function _summary_rows(transformation::TypeWrapperTransform{T, S}, mime) where {T, S<:TransformTuple}
521+
(; inner_transformation) = transformation
522+
innerinner = _inner(inner_transformation)
523+
name = string("$T wrapper on ", nameof(typeof(innerinner)), " of transformations")
524+
_tuple_summary_rows(name, inner_transformation, mime)
525+
end
526+
527+
function transform_with(flag::LogJacFlag, t::TypeWrapperTransform{T}, x, index) where T
528+
(; inner_transformation) = t
529+
y, ℓ, index′ = transform_with(flag, inner_transformation, x, index)
530+
ctor = constructorof(T)
531+
ctor(y...), ℓ, index′
532+
end
533+
function transform_with(flag::LogJacFlag, t::TypeWrapperTransform{C, T}, x, index) where {C, N, T<:TransformTuple{<:NamedTuple{N}}}
534+
(; inner_transformation) = t
535+
y, ℓ, index′ = transform_with(flag, inner_transformation, x, index)
536+
ctor = constructorof(C)
537+
ctor(;y...), ℓ, index′
538+
end
539+
540+
# NamedTuple inner transformations
541+
function inverse_eltype(t::TypeWrapperTransform{C, S}, ::Type{T}) where {C, N, T<:C, S<:TransformTuple{<:NamedTuple{N}}}
542+
used_names = filter(n->n fieldnames(T), N)
543+
types = map(n -> fieldtype(T, n), used_names)
544+
inverse_eltype(t.inner_transformation, NamedTuple{used_names,Tuple{types...}})
545+
end
546+
function inverse_at!(x, index, t::TypeWrapperTransform{T, S}, y::T) where {T, N, S<:TransformTuple{<:NamedTuple{N}}}
547+
yvals = NamedTuple{N}(map(n->(getfield(y,n)), N))
548+
inverse_at!(x, index, t.inner_transformation, yvals)
549+
end
550+
551+
# Regular Tuple inner transformation
552+
function inverse_eltype(t::TypeWrapperTransform{C, S}, ::Type{T}) where {C, T<:C, S<:TransformTuple}
553+
num_args = length(t.inner_transformation)
554+
inverse_eltype(t.inner_transformation, Tuple{fieldtypes(T)[begin:num_args]...})
555+
end
556+
function inverse_at!(x, index, t::TypeWrapperTransform{T, S}, y::T) where {T, S<:TransformTuple}
557+
inner = t.inner_transformation
558+
num_args = length(inner)
559+
if length(fieldnames(typeof(y))) != num_args
560+
throw(ArgumentError("The provided type $T has a different number of fields than the inner transformation, so it cannot be inverted."))
561+
end
562+
yvals = Tuple(getfield(y, i) for i in 1:num_args)
563+
inverse_at!(x, index, inner, yvals)
564+
end
565+
566+
# Informative error for trying to invert an incompatible type
567+
function inverse_eltype(t::TypeWrapperTransform{C, S}, ::Type{T}) where {C, T, S<:TransformTuple}
568+
throw(ArgumentError("Cannot invert a $T as if it were a $C"))
569+
end

test/runtests.jl

Lines changed: 140 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,7 @@ end
412412
@test tt[1] == inner[1]
413413
@test tt[2] == inner[2]
414414
@test tt[3] == inner[3]
415+
@test tt[1] == inner[1]
415416
@test_throws BoundsError tt[4]
416417
@test propertynames(tt) == propertynames(inner)
417418
@test (@set tt[3] = t2) == as((t1, t2, t2))
@@ -542,6 +543,107 @@ end
542543

543544
end
544545

546+
@testset "transform to custom type" begin
547+
548+
struct CustomType{A, B}
549+
a::A
550+
b::B
551+
end
552+
@kwdef struct KwCustomType{A, B}
553+
a::A
554+
b::B
555+
end
556+
struct MyType{C}
557+
c::C
558+
end
559+
struct YourType
560+
d::Float64
561+
end
562+
563+
# From named tuple to type
564+
t1 = as((a = asℝ, b = asℝ))
565+
t2 = as(CustomType, t1)
566+
@test_throws MethodError transform(t2, [1.0, 2.0])
567+
568+
t1 = as((a = asℝ, b = asℝ))
569+
t2 = as(KwCustomType, t1)
570+
x = [1.0, 2.0]
571+
test_transformation(t2, y -> y isa KwCustomType; N=1, jac=false)
572+
@test_throws ArgumentError inverse(t2, [1.0, 2.0])
573+
@test_throws ArgumentError inverse(t2, MyType(3.0))
574+
575+
# Named tuple with different ordering
576+
t1 = as((b = asℝ, a = asℝ))
577+
t2 = as(KwCustomType, t1)
578+
y = @inferred transform(t2, [1.0, 2.0])
579+
@test y == KwCustomType(a = 2.0, b = 1.0)
580+
test_transformation(t2, y -> y isa KwCustomType; N=1, jac=false)
581+
582+
# Named tuple with wrong number or names of fields
583+
t1 = as((;b = asℝ))
584+
t2 = as(KwCustomType, t1)
585+
@test_throws UndefKeywordError transform(t2, [1.0])
586+
@test inverse(t2, KwCustomType(1.0, 3.0)) == [3.0]
587+
t1 = as((a = asℝ, c = asℝ))
588+
t2 = as(KwCustomType, t1)
589+
@test_throws UndefKeywordError transform(t2, [1.0, 3.0])
590+
@test_throws ArgumentError inverse(t2, KwCustomType(1.0, 3.0))
591+
t1 = as((b = asℝ, a = asℝ, c = asℝ))
592+
t2 = as(KwCustomType, t1)
593+
@test_throws MethodError transform(t2, [1.0, 2.0, 3.0])
594+
@test_throws ArgumentError inverse(t2, KwCustomType(1.0, 3.0))
595+
596+
# Type with shortened constructor
597+
struct MaskedType{A, B}
598+
a::A
599+
b::B
600+
end
601+
MaskedType(x) = MaskedType(x, nothing)
602+
MaskedType(;b=0.0) = MaskedType(nothing, b)
603+
t = as(MaskedType, as((;b=asℝ)))
604+
test_transformation(t, y -> y isa MaskedType; N=1, jac=false)
605+
606+
# No kwarg constructor accepts `a` arg, so errors
607+
t = as(MaskedType, as((;a=asℝ)))
608+
@test_throws MethodError transform(t, [1.0])
609+
610+
# When constructor accepts less args than struct has fields,
611+
# inverse errors
612+
t = as(MaskedType, (asℝ,))
613+
x = [1.0]
614+
y = transform(t, x)
615+
@test y == MaskedType(1.0, nothing)
616+
@test_throws ArgumentError inverse(t, MaskedType(1.0, nothing))
617+
# test_transformation(t, y -> y isa MaskedType; N=1, jac=false)
618+
619+
# Insufficient arguments in provided tuple for constructor
620+
# Not specially caught, but good to check
621+
t1 = as(CustomType, (asℝ₊,))
622+
@test_throws MethodError transform(t1, [1.0])
623+
624+
# From tuple to type
625+
t1 = as(ntuple(i->asℝ₊, Val(2)))
626+
t2 = as(CustomType, t1)
627+
test_transformation(t2, y -> y isa CustomType; N=1, jac=false)
628+
# Trying to invert from another type should error
629+
@test_throws ArgumentError inverse(t2, [1.0, 2.0])
630+
@test_throws ArgumentError inverse(t2, MyType(3.0))
631+
632+
# Nested custom types
633+
t1 = as(MyType, (asℝ₋,))
634+
t2 = as(YourType, (asℝ₋,))
635+
t3 = as(KwCustomType, as((a = t1, b = t2)))
636+
x = [0.0, -1]
637+
y = transform(t3, x)
638+
test_transformation(t3, y -> y isa KwCustomType; N=1, jac=false)
639+
# Switched order should still work
640+
@test y == KwCustomType(;b = YourType(-exp(-1)), a = MyType(-1.0))
641+
# Inverting with wrong type
642+
@test_throws ArgumentError inverse(t3, KwCustomType(-1.0, YourType(-exp(-1))))
643+
@test_throws ArgumentError inverse(t3, CustomType(MyType(-1.0), YourType(-exp(-1))))
644+
645+
end
646+
545647
####
546648
#### log density correctness checks
547649
####
@@ -842,21 +944,51 @@ end
842944
end
843945

844946
@testset "pretty printing" begin
947+
struct SmallType
948+
f1
949+
end
950+
struct LargerType
951+
f2
952+
f3
953+
end
954+
LargerType(;f2, f3) = LargerType(f2, f3)
845955
t = as((a = asℝ₊,
846956
b = as(Array, asℝ₋, 3, 3),
847957
c = corr_cholesky_factor(13),
848-
d = as((asℝ, corr_cholesky_factor(SMatrix{3,3}), UnitSimplex(3), unit_vector_norm(4)))))
958+
d = as((asℝ, corr_cholesky_factor(SMatrix{3,3}), UnitSimplex(3), unit_vector_norm(4))),
959+
e = as(LargerType, as((f3 = as(SmallType, (asℝ₊,)), f2 = as𝕀))),
960+
))
849961
repr_t = """
850-
[1:97] NamedTuple of transformations
962+
[1:100] NamedTuple of transformations
851963
[1:1] :a → asℝ₊
852964
[2:10] :b → 3×3×asℝ₋
853965
[11:88] :c → 13×13 correlation cholesky factor
854-
[89:97] :d → Tuple of transformations
855-
[98:98] 1 → asℝ
856-
[108:110] 2 → SMatrix{3,3} correlation cholesky factor
857-
[120:121] 3 → 3 element unit simplex transformation
858-
[131:133] 4 → 4 element (unit vector, norm) transformation"""
859-
repr(MIME("text/plain"), t) == repr_t
966+
[89:98] :d → Tuple of transformations
967+
[89:89] 1 → asℝ
968+
[90:92] 2 → SMatrix{3,3} correlation cholesky factor
969+
[93:94] 3 → 3 element unit simplex transformation
970+
[95:98] 4 → 4 element (unit vector, norm) transformation
971+
[99:100] :e → LargerType wrapper on NamedTuple of transformations
972+
[99:99] :f3 → SmallType wrapper on Tuple of transformations
973+
[99:99] 1 → asℝ₊
974+
[100:100] :f2 → as𝕀"""
975+
@test repr(MIME("text/plain"), t) == repr_t
976+
977+
t = as((as(Array, asℝ₊, 3),
978+
as(Array, asℝ₋, 3, 3),
979+
as(Array, TVScale(5.0) asℝ₋, 3, 3, 3),
980+
as(Array, as((a=asℝ₊, b = as𝕀)), 3)
981+
))
982+
repr_t = """
983+
[1:45] Tuple of transformations
984+
[1:3] 1 → 3×asℝ₊
985+
[4:12] 2 → 3×3×asℝ₋
986+
[13:39] 3 → 3×3×3×TVScale(5.0) ∘ TVNeg() ∘ asℝ₊
987+
[40:45] 4 → 3×
988+
NamedTuple of transformations
989+
:a → asℝ₊
990+
:b → as𝕀"""
991+
@test repr(MIME("text/plain"), t) == repr_t
860992
end
861993

862994
@testset "print ∞" begin

0 commit comments

Comments
 (0)