From 2bb90676968198a4053d9391dd363ad92cfa941d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Thu, 18 May 2023 14:13:19 +0200 Subject: [PATCH 1/5] add view transformations --- src/aggregation.jl | 47 ++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 14 ++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/src/aggregation.jl b/src/aggregation.jl index e31aab2..d881d20 100644 --- a/src/aggregation.jl +++ b/src/aggregation.jl @@ -106,6 +106,53 @@ function _domain_label(transformation::ArrayTransformation, index::Int) _array_domain_label(inner_transformation, dims, index) end +#### +#### array view +#### + +""" +$(TYPEDEF) + +View of an array with `dims`. + +!!! note + This feature is experimental, and not part of the stable API; it may disappear or change without + relevant changes in SemVer or deprecations. Inner transformations are not supported. +""" +struct ViewTransformation{M} <: VectorTransform + dims::NTuple{M, Int} +end + +function as(::typeof(view), dims::Tuple{Vararg{Int}}) + @argcheck all(d -> d ≥ 0, dims) "All dimensions need to be non-negative." + ViewTransformation(dims) +end + +as(::typeof(view), dims::Int...) = as(view, dims) + +dimension(transformation::ViewTransformation) = prod(transformation.dims) + +function transform_with(flag::LogJacFlag, t::ViewTransformation, x, index) + index′ = index + dimension(t) + y = reshape(@view(x[index:(index′-1)]), t.dims) + y, logjac_zero(flag, robust_eltype(x)), index′ +end + +function _domain_label(transformation::ViewTransformation, index::Int) + @unpack dims = transformation + _array_domain_label(asℝ, dims, index) +end + +inverse_eltype(transformation::ViewTransformation, y) = eltype(y) + +function inverse_at!(x::AbstractVector, index, transformation::ViewTransformation, + y::AbstractArray) + @argcheck size(y) == transformation.dims + index′ = index + dimension(transformation) + copy!(@view(x[index:(index′-1)]), vec(y)) + index′ +end + #### #### static array #### diff --git a/test/runtests.jl b/test/runtests.jl index 928c11b..240c727 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -644,3 +644,17 @@ end @testset "static arrays inference" begin @test @inferred transform_with(NOLOGJAC, as(SVector{3, Float64}), zeros(3), 1) == (SVector(0.0, 0.0, 0.0), NOLOGJAC, 4) end + +@testset "view transformations" begin + x = randn(10) + t = as((a = asℝ, b = as(view, 2, 4), c = asℝ)) + y, lj = transform_and_logjac(t, x) + @test typeof(y.b) <: AbstractMatrix + @test size(y.b) == (2, 4) + # test inverse + @test inverse(t, y) == x + # test that it is a view + z = y.b[3] + y.b[3] += 1 + @test x[4] == z + 1 +end From f47e210d5a7813243a7e07fbe1f2333c7553b46d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Thu, 18 May 2023 14:13:24 +0200 Subject: [PATCH 2/5] incidental: fix Enzyme test --- test/runtests.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 240c727..d3faa46 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -411,9 +411,9 @@ end y, lj = transform_and_logjac(ss, x) return -abs2(y) + lj end - ge = autodiff(enzyme, Const(ss), Active(0.5)) - g = ForwardDiff.derivative(x->enzyme(ss, x), 0.5) - @test g ≈ first(ge) + g, _ = autodiff(ReverseWithPrimal, enzyme, Const(ss), Active(0.5)) + g2 = ForwardDiff.derivative(x -> enzyme(ss, x), 0.5) + @test g[2] ≈ g2 end end From b4e44461fbeb04f4080d69f0ec33b4e78b1ef2ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Thu, 18 May 2023 14:19:12 +0200 Subject: [PATCH 3/5] replace UnPack with SimpleUnPack --- Project.toml | 4 ++-- src/TransformVariables.jl | 2 +- test/Project.toml | 2 +- test/runtests.jl | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 4041244..0e9aaa5 100644 --- a/Project.toml +++ b/Project.toml @@ -13,8 +13,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" -UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" [compat] ArgCheck = "1, 2" @@ -23,6 +23,6 @@ DocStringExtensions = "0.8, 0.9" ForwardDiff = "0.10" InverseFunctions = "0.1" LogExpFunctions = "0.3" +SimpleUnPack = "1" StaticArrays = "1" -UnPack = "1" julia = "1.6" diff --git a/src/TransformVariables.jl b/src/TransformVariables.jl index 14ed447..16c60a2 100644 --- a/src/TransformVariables.jl +++ b/src/TransformVariables.jl @@ -5,9 +5,9 @@ using DocStringExtensions: FUNCTIONNAME, SIGNATURES, TYPEDEF import ForwardDiff using LogExpFunctions using LinearAlgebra: UpperTriangular, logabsdet -using UnPack: @unpack using Random: AbstractRNG, GLOBAL_RNG using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst +using SimpleUnPack: @unpack import ChangesOfVariables import InverseFunctions diff --git a/test/Project.toml b/test/Project.toml index 61108d7..cc42be0 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,7 +9,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TransformedLogDensities = "f9bc47f6-f3f8-4f3b-ab21-f8bc73906f26" -UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" diff --git a/test/runtests.jl b/test/runtests.jl index d3faa46..a2476bb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,4 +1,4 @@ -using DocStringExtensions, LinearAlgebra, LogDensityProblems, OffsetArrays, UnPack, +using DocStringExtensions, LinearAlgebra, LogDensityProblems, OffsetArrays, SimpleUnPack, Random, Test, TransformVariables, StaticArrays, TransformedLogDensities, LogDensityProblemsAD import ForwardDiff From 7fa8d292408e1a9a35560d2a7d4bb9d4b8c33564 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Thu, 18 May 2023 14:19:20 +0200 Subject: [PATCH 4/5] fix imports --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index a2476bb..cbc30bc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,7 @@ using TransformVariables: AbstractTransform, ScalarTransform, VectorTransform, ArrayTransformation, unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, NOLOGJAC, transform_with import ChangesOfVariables, InverseFunctions -using Enzyme: autodiff, Reverse, Active, Const +using Enzyme: autodiff, ReverseWithPrimal, Active, Const const CIENV = get(ENV, "CI", "") == "true" From 4ec444fe21f20b3eac78614bd7909e8ca61df626 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20K=2E=20Papp?= Date: Thu, 18 May 2023 14:27:16 +0200 Subject: [PATCH 5/5] bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0e9aaa5..8018627 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TransformVariables" uuid = "84d833dd-6860-57f9-a1a7-6da5db126cff" authors = ["Tamas K. Papp "] -version = "0.8.6" +version = "0.8.7" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"