Skip to content

Make inverse(::VectorTransfrom, x) return a vector of floats #133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "TransformVariables"
uuid = "84d833dd-6860-57f9-a1a7-6da5db126cff"
authors = ["Tamas K. Papp <tkpapp@gmail.com>"]
version = "0.8.15"
version = "0.8.16"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
11 changes: 10 additions & 1 deletion src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,17 @@ function transform_and_logjac(t::VectorTransform, x::AbstractVector)
y, ℓ
end

# We want to avoid vectors with non-numerical element types
# Ref https://github.com/tpapp/TransformVariables.jl/issues/132
function inverse(t::VectorTransform, y)
inverse!(Vector{inverse_eltype(t, y)}(undef, dimension(t)), t, y)
inverse!(Vector{_float_or_Float64(inverse_eltype(t, y))}(undef, dimension(t)), t, y)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that this catches too much. I would simply condition on dimension(t) == 0, I don't think that Union{} can happen outside that. Something like

T = inverse_eltype(t, y)
d = dimension(t)    
if T === Union{} || d == 0
    T = Float64
end
inverse!(Vector{T)}(undef, d), t, y)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In your suggestion you seem to condition on Union{} as well?

I was actually considering checking only dimsnion(t) == 0 as for my use case it shouldn't matter. I went with checking only whether inverse_eltype(t, y) === Union{} since I was worried that branching on the value of dimension(t) could introduce a type instability.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that makes sense.

But now that I think about it, it is kind of funny to make eltype(inverse(t, y)) != inverse_eltype(t, y). I think it would be best to fix it in inverse_eltype, the following way:

  1. Introduce a function, eg narrow_inverse_eltype, basically rename inverse_eltype to that, in the state before this PR. Document that types should extent it.

  2. In inverse_eltype, just check for Union{} and replace with Float64.

Now that I look at the code, the eltype determination is a bit flaky in a lot of places:

julia> t = as(Array, 3)
[1:3] 3×
  asℝ

julia> inverse(t, Any[1, 2.1, 3])
ERROR: InexactError: Int64(2.1)

We should probably rethink it as a whole for all kinds of corner cases. Suggestions welcome.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially I wanted to change inverse_eltype. But making it concrete seemed wrong - similar to the scalar transforms (such as inverse_eltype(::Constant) = Union{}), per se an inverse_eltype seems completely correct: It avoids prematurely introducing e.g. Float64 in a "non-leaf" transform (ie before applying all compositions or combinations). IMO the element type is only relevant when the output array is actually instantiated, and changes (such as changing to floating point number types or turning Union{} to Float64) should only be performed at this final stage - as done in this PR.

Now that I look at the code, the eltype determination is a bit flaky in a lot of places:

I think this example is quite special as it is caused by the same bug as #73: The element type of the output (inverse_eltype) is in this case computed based on the first element of input x. Whereas actually for non-concrete element types, I think inverse_eltype take into account the full input x, e.g. by promoting the inverse_eltypes for each element (can't be inferred in this case) or returning an abstract type such as Any.

julia> TransformVariables.inverse_eltype(t, Any[1, 2.1, 3])
Int64

julia> TransformVariables.inverse_eltype(t, Any[2.1, 1, 3])
Float64

By construction of this failing example, it is actually fixed by this PR since Int would be changed to Float64 when constructing the Vector. With this PR,

julia> inverse(t, Any[1, 2.1, 3])
3-element Vector{Float64}:
 1.0
 2.1
 3.0

end
function _float_or_Float64(::Type{T}) where T
if T !== Union{} && T <: Number # heuristic: it is assumed that every `Number` type defines `float`
return float(T)
else
return Float64
end
end

"""
Expand Down
34 changes: 34 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -730,3 +730,37 @@ end
end
@test d1 == d2
end

@testset "inverse of VectorTransform" begin
# Empty `inverse(::VectorTransform, _)`
for a in (3, 4.7, [5], 3f0, 4.7f0, [5f0])
x = @inferred(inverse(as((; a = Constant(a))), (; a)))
@test x isa Vector{Float64}
@test isempty(x)

x = @inferred(inverse(as((Constant(a),)), (a,)))
@test x isa Vector{Float64}
@test isempty(x)

x = @inferred(inverse(as(Vector, Constant(a), 1), [a]))
@test x isa Vector{Float64}
@test isempty(x)
end

# Element type of `inverse(::VectorTransform, _)`
for a in (3, 3.0, 3f0)
T = float(typeof(a))

x = @inferred(inverse(as((; a = asℝ)), (; a)))
@test x isa Vector{T}
@test x == [3]

x = @inferred(inverse(as((asℝ,)), (a,)))
@test x isa Vector{T}
@test x == [3]

x = @inferred(inverse(as(Vector, asℝ, 1), [a]))
@test x isa Vector{T}
@test x == [3]
end
end
Loading