Skip to content

Commit

Permalink
Add Fix{N} for fixing a single positional argument at any position (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer authored Aug 7, 2024
1 parent f8af0d1 commit 1e623ad
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Compat"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "4.15.0"
version = "4.16.0"

[deps]
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ changes in `julia`.

## Supported features

* `Compat.Fix{N}` which fixes an argument at the `N`th position ([#54653]) (since Compat 4.16.0)

* `chopprefix(s, prefix)` and `chopsuffix(s, suffix)` ([#40995]) (since Compat 4.15.0)

* `logrange(lo, hi; length)` is like `range` but with a constant ratio, not difference. ([#39071]) (since Compat 4.14.0) Note that on Julia 1.8 and earlier, the version from Compat has slightly lower floating-point accuracy than the one in Base (Julia 1.11 and later).
Expand Down Expand Up @@ -192,3 +194,4 @@ Note that you should specify the correct minimum version for `Compat` in the
[#47679]: https://github.com/JuliaLang/julia/pull/47679
[#48038]: https://github.com/JuliaLang/julia/issues/48038
[#50105]: https://github.com/JuliaLang/julia/issues/50105
[#54653]: https://github.com/JuliaLang/julia/issues/54653
68 changes: 68 additions & 0 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,74 @@ if VERSION < v"1.8.0-DEV.1016"
export chopprefix, chopsuffix
end

# https://github.com/JuliaLang/julia/pull/54653: add Fix
@static if !isdefined(Base, :Fix) # VERSION < v"1.12.0-DEV.981"
@static if !isdefined(Base, :_stable_typeof)
_stable_typeof(x) = typeof(x)
_stable_typeof(::Type{T}) where {T} = Type{T}
else
using Base: _stable_typeof
end

@doc """
Fix{N}(f, x)
A type representing a partially-applied version of a function `f`, with the argument
`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to
`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
!!! note
When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current
available arguments, rather than an absolute ordering on the target function. For example,
`Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)`
fixes the first and third arg.
!!! note
Note that `Compat.Fix{1}`/`Fix{2}` are not the same as `Base.Fix1`/`Fix2` on Julia
versions earlier than `1.12.0-DEV.981`. Therefore, if you wish to use this as a way
to _dispatch_ on `Fix{N}`, you may wish to declare a method for both
`Compat.Fix{1}`/`Fix{2}` as well as `Base.Fix1`/`Fix2`, conditional on
a `@static if !isdefined(Base, :Fix); ...; end`.
""" Fix

struct Fix{N,F,T} <: Function
f::F
x::T

function Fix{N}(f::F, x) where {N,F}
if !(N isa Int)
throw(ArgumentError("expected type parameter in `Fix` to be `Int`, but got `$N::$(typeof(N))`"))
elseif N < 1
throw(ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, but got $N"))
end
new{N,_stable_typeof(f),_stable_typeof(x)}(f, x)
end
end

function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M}
M < N-1 && throw(ArgumentError("expected at least $(N-1) arguments to `Fix{$N}`, but got $M"))
return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...)
end

# Special cases for improved constant propagation
(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...)
(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...)

@doc """
Alias for `Fix{1}`. See [`Fix`](@ref Compat.Fix).
""" Fix1

const Fix1{F,T} = Fix{1,F,T}

@doc """
Alias for `Fix{2}`. See [`Fix`](@ref Compat.Fix).
""" Fix2

const Fix2{F,T} = Fix{2,F,T}
else
using Base: Fix, Fix1, Fix2
end

include("deprecated.jl")

end # module Compat
131 changes: 131 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -907,3 +907,134 @@ end
@test isa(chopsuffix(S("foo"), "oo"), SubString)
end
end

# https://github.com/JuliaLang/julia/pull/54653: add Fix
@testset "Fix" begin
function test_fix1(Fix1=Compat.Fix1)
increment = Fix1(+, 1)
@test increment(5) == 6
@test increment(-1) == 0
@test increment(0) == 1
@test map(increment, [1, 2, 3]) == [2, 3, 4]

concat_with_hello = Fix1(*, "Hello ")
@test concat_with_hello("World!") == "Hello World!"
# Make sure inference is good:
@inferred concat_with_hello("World!")

one_divided_by = Fix1(/, 1)
@test one_divided_by(10) == 1/10.0
@test one_divided_by(-5) == 1/-5.0

return nothing
end

function test_fix2(Fix2=Compat.Fix2)
return_second = Fix2((x, y) -> y, 999)
@test return_second(10) == 999
@inferred return_second(10)
@test return_second(-5) == 999

divide_by_two = Fix2(/, 2)
@test map(divide_by_two, (2, 4, 6)) == (1.0, 2.0, 3.0)
@inferred map(divide_by_two, (2, 4, 6))

concat_with_world = Fix2(*, " World!")
@test concat_with_world("Hello") == "Hello World!"
@inferred concat_with_world("Hello World!")

return nothing
end

# Test with normal Base.Fix1 and Base.Fix2
test_fix1()
test_fix2()

# Now, repeat the Fix1 and Fix2 tests, but
# with a Fix lambda function used in their place
test_fix1((op, arg) -> Compat.Fix{1}(op, arg))
test_fix2((op, arg) -> Compat.Fix{2}(op, arg))

# Now, we do more complex tests of Fix:
let Fix=Compat.Fix
@testset "Argument Fixation" begin
let f = (x, y, z) -> x + y * z
fixed_f1 = Fix{1}(f, 10)
@test fixed_f1(2, 3) == 10 + 2 * 3

fixed_f2 = Fix{2}(f, 5)
@test fixed_f2(1, 4) == 1 + 5 * 4

fixed_f3 = Fix{3}(f, 3)
@test fixed_f3(1, 2) == 1 + 2 * 3
end
end
@testset "Helpful errors" begin
let g = (x, y) -> x - y
# Test minimum N
fixed_g1 = Fix{1}(g, 100)
@test fixed_g1(40) == 100 - 40

# Test maximum N
fixed_g2 = Fix{2}(g, 100)
@test fixed_g2(150) == 150 - 100

# One over
fixed_g3 = Fix{3}(g, 100)
@test_throws ArgumentError("expected at least 2 arguments to `Fix{3}`, but got 1") fixed_g3(1)
end
end
@testset "Type Stability and Inference" begin
let h = (x, y) -> x / y
fixed_h = Fix{2}(h, 2.0)
@test @inferred(fixed_h(4.0)) == 2.0
end
end
@testset "Interaction with varargs" begin
vararg_f = (x, y, z...) -> x + 10 * y + sum(z; init=zero(x))
fixed_vararg_f = Fix{2}(vararg_f, 6)

# Can call with variable number of arguments:
@test fixed_vararg_f(1, 2, 3, 4) == 1 + 10 * 6 + sum((2, 3, 4))
if VERSION >= v"1.7.0"
@inferred fixed_vararg_f(1, 2, 3, 4)
end
@test fixed_vararg_f(5) == 5 + 10 * 6
if VERSION >= v"1.7.0"
@inferred fixed_vararg_f(5)
end
end
@testset "Errors should propagate normally" begin
error_f = (x, y) -> sin(x * y)
fixed_error_f = Fix{2}(error_f, Inf)
@test_throws DomainError fixed_error_f(10)
end
@testset "Chaining Fix together" begin
f1 = Fix{1}(*, "1")
f2 = Fix{1}(f1, "2")
f3 = Fix{1}(f2, "3")
@test f3() == "123"

g1 = Fix{2}(*, "1")
g2 = Fix{2}(g1, "2")
g3 = Fix{2}(g2, "3")
@test g3("") == "123"
end
@testset "Zero arguments" begin
f = Fix{1}(x -> x, 'a')
@test f() == 'a'
end
@testset "Dummy-proofing" begin
@test_throws ArgumentError("expected `N` in `Fix{N}` to be integer greater than 0, but got 0") Fix{0}(>, 1)
@test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `0.5::Float64`") Fix{0.5}(>, 1)
@test_throws ArgumentError("expected type parameter in `Fix` to be `Int`, but got `1::UInt64`") Fix{UInt64(1)}(>, 1)
end
@testset "Specialize to structs not in `Base`" begin
struct MyStruct
x::Int
end
f = Fix{1}(MyStruct, 1)
@test f isa Fix{1,Type{MyStruct},Int}
end
end
end

0 comments on commit 1e623ad

Please sign in to comment.