Skip to content

add experimental as(view, dims...) transformation #115

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TransformVariables"
uuid = "84d833dd-6860-57f9-a1a7-6da5db126cff"
authors = ["Tamas K. Papp <tkpapp@gmail.com>"]
version = "0.8.6"
version = "0.8.7"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand All @@ -13,8 +13,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[compat]
ArgCheck = "1, 2"
Expand All @@ -23,6 +23,6 @@ DocStringExtensions = "0.8, 0.9"
ForwardDiff = "0.10"
InverseFunctions = "0.1"
LogExpFunctions = "0.3"
SimpleUnPack = "1"
StaticArrays = "1"
UnPack = "1"
julia = "1.6"
2 changes: 1 addition & 1 deletion src/TransformVariables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ using DocStringExtensions: FUNCTIONNAME, SIGNATURES, TYPEDEF
import ForwardDiff
using LogExpFunctions
using LinearAlgebra: UpperTriangular, logabsdet
using UnPack: @unpack
using Random: AbstractRNG, GLOBAL_RNG
using StaticArrays: MMatrix, SMatrix, SArray, SVector, pushfirst
using SimpleUnPack: @unpack

import ChangesOfVariables
import InverseFunctions
Expand Down
47 changes: 47 additions & 0 deletions src/aggregation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,53 @@ function _domain_label(transformation::ArrayTransformation, index::Int)
_array_domain_label(inner_transformation, dims, index)
end

####
#### array view
####

"""
$(TYPEDEF)

View of an array with `dims`.

!!! note
This feature is experimental, and not part of the stable API; it may disappear or change without
relevant changes in SemVer or deprecations. Inner transformations are not supported.
"""
struct ViewTransformation{M} <: VectorTransform
dims::NTuple{M, Int}
end

function as(::typeof(view), dims::Tuple{Vararg{Int}})
@argcheck all(d -> d ≥ 0, dims) "All dimensions need to be non-negative."
ViewTransformation(dims)
end

as(::typeof(view), dims::Int...) = as(view, dims)

dimension(transformation::ViewTransformation) = prod(transformation.dims)

function transform_with(flag::LogJacFlag, t::ViewTransformation, x, index)
index′ = index + dimension(t)
y = reshape(@view(x[index:(index′-1)]), t.dims)
y, logjac_zero(flag, robust_eltype(x)), index′
end

function _domain_label(transformation::ViewTransformation, index::Int)
@unpack dims = transformation
_array_domain_label(asℝ, dims, index)
end

inverse_eltype(transformation::ViewTransformation, y) = eltype(y)

function inverse_at!(x::AbstractVector, index, transformation::ViewTransformation,
y::AbstractArray)
@argcheck size(y) == transformation.dims
index′ = index + dimension(transformation)
copy!(@view(x[index:(index′-1)]), vec(y))
index′
end

####
#### static array
####
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TransformedLogDensities = "f9bc47f6-f3f8-4f3b-ab21-f8bc73906f26"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
24 changes: 19 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DocStringExtensions, LinearAlgebra, LogDensityProblems, OffsetArrays, UnPack,
using DocStringExtensions, LinearAlgebra, LogDensityProblems, OffsetArrays, SimpleUnPack,
Random, Test, TransformVariables, StaticArrays, TransformedLogDensities,
LogDensityProblemsAD
import ForwardDiff
Expand All @@ -8,7 +8,7 @@ using TransformVariables:
AbstractTransform, ScalarTransform, VectorTransform, ArrayTransformation,
unit_triangular_dimension, logistic, logistic_logjac, logit, inverse_and_logjac, NOLOGJAC, transform_with
import ChangesOfVariables, InverseFunctions
using Enzyme: autodiff, Reverse, Active, Const
using Enzyme: autodiff, ReverseWithPrimal, Active, Const

const CIENV = get(ENV, "CI", "") == "true"

Expand Down Expand Up @@ -411,9 +411,9 @@ end
y, lj = transform_and_logjac(ss, x)
return -abs2(y) + lj
end
ge = autodiff(enzyme, Const(ss), Active(0.5))
g = ForwardDiff.derivative(x->enzyme(ss, x), 0.5)
@test g ≈ first(ge)
g, _ = autodiff(ReverseWithPrimal, enzyme, Const(ss), Active(0.5))
g2 = ForwardDiff.derivative(x -> enzyme(ss, x), 0.5)
@test g[2]g2
end
end

Expand Down Expand Up @@ -644,3 +644,17 @@ end
@testset "static arrays inference" begin
@test @inferred transform_with(NOLOGJAC, as(SVector{3, Float64}), zeros(3), 1) == (SVector(0.0, 0.0, 0.0), NOLOGJAC, 4)
end

@testset "view transformations" begin
x = randn(10)
t = as((a = asℝ, b = as(view, 2, 4), c = asℝ))
y, lj = transform_and_logjac(t, x)
@test typeof(y.b) <: AbstractMatrix
@test size(y.b) == (2, 4)
# test inverse
@test inverse(t, y) == x
# test that it is a view
z = y.b[3]
y.b[3] += 1
@test x[4] == z + 1
end