|
412 | 412 | @test tt[1] == inner[1] |
413 | 413 | @test tt[2] == inner[2] |
414 | 414 | @test tt[3] == inner[3] |
| 415 | + @test tt[1] == inner[1] |
415 | 416 | @test_throws BoundsError tt[4] |
416 | 417 | @test propertynames(tt) == propertynames(inner) |
417 | 418 | @test (@set tt[3] = t2) == as((t1, t2, t2)) |
@@ -542,6 +543,107 @@ end |
542 | 543 |
|
543 | 544 | end |
544 | 545 |
|
| 546 | +@testset "transform to custom type" begin |
| 547 | + |
| 548 | + struct CustomType{A, B} |
| 549 | + a::A |
| 550 | + b::B |
| 551 | + end |
| 552 | + @kwdef struct KwCustomType{A, B} |
| 553 | + a::A |
| 554 | + b::B |
| 555 | + end |
| 556 | + struct MyType{C} |
| 557 | + c::C |
| 558 | + end |
| 559 | + struct YourType |
| 560 | + d::Float64 |
| 561 | + end |
| 562 | + |
| 563 | + # From named tuple to type |
| 564 | + t1 = as((a = asℝ, b = asℝ)) |
| 565 | + t2 = as(CustomType, t1) |
| 566 | + @test_throws MethodError transform(t2, [1.0, 2.0]) |
| 567 | + |
| 568 | + t1 = as((a = asℝ, b = asℝ)) |
| 569 | + t2 = as(KwCustomType, t1) |
| 570 | + x = [1.0, 2.0] |
| 571 | + test_transformation(t2, y -> y isa KwCustomType; N=1, jac=false) |
| 572 | + @test_throws ArgumentError inverse(t2, [1.0, 2.0]) |
| 573 | + @test_throws ArgumentError inverse(t2, MyType(3.0)) |
| 574 | + |
| 575 | + # Named tuple with different ordering |
| 576 | + t1 = as((b = asℝ, a = asℝ)) |
| 577 | + t2 = as(KwCustomType, t1) |
| 578 | + y = @inferred transform(t2, [1.0, 2.0]) |
| 579 | + @test y == KwCustomType(a = 2.0, b = 1.0) |
| 580 | + test_transformation(t2, y -> y isa KwCustomType; N=1, jac=false) |
| 581 | + |
| 582 | + # Named tuple with wrong number or names of fields |
| 583 | + t1 = as((;b = asℝ)) |
| 584 | + t2 = as(KwCustomType, t1) |
| 585 | + @test_throws UndefKeywordError transform(t2, [1.0]) |
| 586 | + @test inverse(t2, KwCustomType(1.0, 3.0)) == [3.0] |
| 587 | + t1 = as((a = asℝ, c = asℝ)) |
| 588 | + t2 = as(KwCustomType, t1) |
| 589 | + @test_throws UndefKeywordError transform(t2, [1.0, 3.0]) |
| 590 | + @test_throws ArgumentError inverse(t2, KwCustomType(1.0, 3.0)) |
| 591 | + t1 = as((b = asℝ, a = asℝ, c = asℝ)) |
| 592 | + t2 = as(KwCustomType, t1) |
| 593 | + @test_throws MethodError transform(t2, [1.0, 2.0, 3.0]) |
| 594 | + @test_throws ArgumentError inverse(t2, KwCustomType(1.0, 3.0)) |
| 595 | + |
| 596 | + # Type with shortened constructor |
| 597 | + struct MaskedType{A, B} |
| 598 | + a::A |
| 599 | + b::B |
| 600 | + end |
| 601 | + MaskedType(x) = MaskedType(x, nothing) |
| 602 | + MaskedType(;b=0.0) = MaskedType(nothing, b) |
| 603 | + t = as(MaskedType, as((;b=asℝ))) |
| 604 | + test_transformation(t, y -> y isa MaskedType; N=1, jac=false) |
| 605 | + |
| 606 | + # No kwarg constructor accepts `a` arg, so errors |
| 607 | + t = as(MaskedType, as((;a=asℝ))) |
| 608 | + @test_throws MethodError transform(t, [1.0]) |
| 609 | + |
| 610 | + # When constructor accepts less args than struct has fields, |
| 611 | + # inverse errors |
| 612 | + t = as(MaskedType, (asℝ,)) |
| 613 | + x = [1.0] |
| 614 | + y = transform(t, x) |
| 615 | + @test y == MaskedType(1.0, nothing) |
| 616 | + @test_throws ArgumentError inverse(t, MaskedType(1.0, nothing)) |
| 617 | + # test_transformation(t, y -> y isa MaskedType; N=1, jac=false) |
| 618 | + |
| 619 | + # Insufficient arguments in provided tuple for constructor |
| 620 | + # Not specially caught, but good to check |
| 621 | + t1 = as(CustomType, (asℝ₊,)) |
| 622 | + @test_throws MethodError transform(t1, [1.0]) |
| 623 | + |
| 624 | + # From tuple to type |
| 625 | + t1 = as(ntuple(i->asℝ₊, Val(2))) |
| 626 | + t2 = as(CustomType, t1) |
| 627 | + test_transformation(t2, y -> y isa CustomType; N=1, jac=false) |
| 628 | + # Trying to invert from another type should error |
| 629 | + @test_throws ArgumentError inverse(t2, [1.0, 2.0]) |
| 630 | + @test_throws ArgumentError inverse(t2, MyType(3.0)) |
| 631 | + |
| 632 | + # Nested custom types |
| 633 | + t1 = as(MyType, (asℝ₋,)) |
| 634 | + t2 = as(YourType, (asℝ₋,)) |
| 635 | + t3 = as(KwCustomType, as((a = t1, b = t2))) |
| 636 | + x = [0.0, -1] |
| 637 | + y = transform(t3, x) |
| 638 | + test_transformation(t3, y -> y isa KwCustomType; N=1, jac=false) |
| 639 | + # Switched order should still work |
| 640 | + @test y == KwCustomType(;b = YourType(-exp(-1)), a = MyType(-1.0)) |
| 641 | + # Inverting with wrong type |
| 642 | + @test_throws ArgumentError inverse(t3, KwCustomType(-1.0, YourType(-exp(-1)))) |
| 643 | + @test_throws ArgumentError inverse(t3, CustomType(MyType(-1.0), YourType(-exp(-1)))) |
| 644 | + |
| 645 | +end |
| 646 | + |
545 | 647 | #### |
546 | 648 | #### log density correctness checks |
547 | 649 | #### |
@@ -842,21 +944,51 @@ end |
842 | 944 | end |
843 | 945 |
|
844 | 946 | @testset "pretty printing" begin |
| 947 | + struct SmallType |
| 948 | + f1 |
| 949 | + end |
| 950 | + struct LargerType |
| 951 | + f2 |
| 952 | + f3 |
| 953 | + end |
| 954 | + LargerType(;f2, f3) = LargerType(f2, f3) |
845 | 955 | t = as((a = asℝ₊, |
846 | 956 | b = as(Array, asℝ₋, 3, 3), |
847 | 957 | c = corr_cholesky_factor(13), |
848 | | - d = as((asℝ, corr_cholesky_factor(SMatrix{3,3}), UnitSimplex(3), unit_vector_norm(4))))) |
| 958 | + d = as((asℝ, corr_cholesky_factor(SMatrix{3,3}), UnitSimplex(3), unit_vector_norm(4))), |
| 959 | + e = as(LargerType, as((f3 = as(SmallType, (asℝ₊,)), f2 = as𝕀))), |
| 960 | + )) |
849 | 961 | repr_t = """ |
850 | | -[1:97] NamedTuple of transformations |
| 962 | +[1:100] NamedTuple of transformations |
851 | 963 | [1:1] :a → asℝ₊ |
852 | 964 | [2:10] :b → 3×3×asℝ₋ |
853 | 965 | [11:88] :c → 13×13 correlation cholesky factor |
854 | | - [89:97] :d → Tuple of transformations |
855 | | - [98:98] 1 → asℝ |
856 | | - [108:110] 2 → SMatrix{3,3} correlation cholesky factor |
857 | | - [120:121] 3 → 3 element unit simplex transformation |
858 | | - [131:133] 4 → 4 element (unit vector, norm) transformation""" |
859 | | - repr(MIME("text/plain"), t) == repr_t |
| 966 | + [89:98] :d → Tuple of transformations |
| 967 | + [89:89] 1 → asℝ |
| 968 | + [90:92] 2 → SMatrix{3,3} correlation cholesky factor |
| 969 | + [93:94] 3 → 3 element unit simplex transformation |
| 970 | + [95:98] 4 → 4 element (unit vector, norm) transformation |
| 971 | + [99:100] :e → LargerType wrapper on NamedTuple of transformations |
| 972 | + [99:99] :f3 → SmallType wrapper on Tuple of transformations |
| 973 | + [99:99] 1 → asℝ₊ |
| 974 | + [100:100] :f2 → as𝕀""" |
| 975 | + @test repr(MIME("text/plain"), t) == repr_t |
| 976 | + |
| 977 | + t = as((as(Array, asℝ₊, 3), |
| 978 | + as(Array, asℝ₋, 3, 3), |
| 979 | + as(Array, TVScale(5.0) ∘ asℝ₋, 3, 3, 3), |
| 980 | + as(Array, as((a=asℝ₊, b = as𝕀)), 3) |
| 981 | + )) |
| 982 | + repr_t = """ |
| 983 | +[1:45] Tuple of transformations |
| 984 | + [1:3] 1 → 3×asℝ₊ |
| 985 | + [4:12] 2 → 3×3×asℝ₋ |
| 986 | + [13:39] 3 → 3×3×3×TVScale(5.0) ∘ TVNeg() ∘ asℝ₊ |
| 987 | + [40:45] 4 → 3× |
| 988 | + NamedTuple of transformations |
| 989 | + :a → asℝ₊ |
| 990 | + :b → as𝕀""" |
| 991 | + @test repr(MIME("text/plain"), t) == repr_t |
860 | 992 | end |
861 | 993 |
|
862 | 994 | @testset "print ∞" begin |
|
0 commit comments