Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improved function partial application design #56518

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ New library functions
* `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]).
* `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]).
* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims`
* The new `Fix` type is a generalization of `Fix1/Fix2` for fixing a single argument ([#54653]).
* `Fix1`/`Fix2` are now generalized by `fix` ([#54653], [#56518]).

New library features
--------------------
Expand Down
1 change: 1 addition & 0 deletions base/Base_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ include("error.jl")
include("bool.jl")
include("number.jl")
include("int.jl")
include("typedomainnumbers.jl")
include("operators.jl")
include("pointer.jl")
include("refvalue.jl")
Expand Down
158 changes: 126 additions & 32 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1153,55 +1153,149 @@ julia> filter(!isletter, str)
!(f::Function) = (!) ∘ f
!(f::ComposedFunction{typeof(!)}) = f.inner #allows !!f === f

const _PositiveInteger = _TypeDomainNumbers.PositiveIntegers.PositiveInteger

struct PartiallyAppliedFunction{Position <: _PositiveInteger, Func, Arg} <: Function
partially_applied_argument_position::Position
f::Func
x::Arg

function (::Type{PartiallyAppliedFunction{Position}})(func::Func, arg) where {Position <: _PositiveInteger, Func}
Pos = Position::DataType
pos = Pos.instance
new{Pos, _stable_typeof(func), _stable_typeof(arg)}(pos, func, arg)
end
end

function getproperty((@nospecialize v::PartiallyAppliedFunction), s::Symbol)
getfield(v, s)
end # avoid overspecialization

function Base.show(
(@nospecialize io::Base.IO),
(@nospecialize unused::Type{PartiallyAppliedFunction{Position}}),
) where {Position <: _PositiveInteger}
if Position isa DataType
print(io, "fix(")
show(io, Position.instance)
print(io, ')')
else
show(io, PartiallyAppliedFunction)
print(io, '{')
show(io, Position)
print(io, '}')
end
end

function Base.show(
(@nospecialize io::Base.IO),
(@nospecialize unused::Type{PartiallyAppliedFunction{Position, Func}}),
) where {Position <: _PositiveInteger, Func}
show(io, PartiallyAppliedFunction{Position})
print(io, '{')
show(io, Func)
print(io, '}')
end

function Base.show(
(@nospecialize io::Base.IO),
(@nospecialize unused::Type{PartiallyAppliedFunction{Position, Func, Arg}}),
) where {Position <: _PositiveInteger, Func, Arg}
show(io, PartiallyAppliedFunction{Position, Func})
print(io, '{')
show(io, Arg)
print(io, '}')
end

function Base.show((@nospecialize io::Base.IO), @nospecialize p::PartiallyAppliedFunction)
print(io, "fix(")
show(io, p.partially_applied_argument_position)
print(io, ")(")
show(io, p.f)
print(io, ", ")
show(io, p.x)
print(io, ')')
end

function _partially_applied_function_check(m::Int, nm1::Int)
if m < nm1
throw(ArgumentError(LazyString("expected at least ", nm1, " arguments to `fix(", nm1 + 1, ")`, but got ", m)))
end
end

function (partial::PartiallyAppliedFunction)(args::Vararg{Any,M}; kws...) where {M}
n = partial.partially_applied_argument_position
nm1 = _TypeDomainNumbers.PositiveIntegers.natural_predecessor(n)
_partially_applied_function_check(M, Int(nm1))
(args_left, args_right) = _TypeDomainNumberTupleUtils.split_tuple(args, nm1)
partial.f(args_left..., partial.x, args_right...; kws...)
end

"""
Fix{N}(f, x)
fix(::Integer)::UnionAll

Return a [`UnionAll`](@ref) type such that:
* It's a constructor taking two arguments:
1. A function to be partially applied
2. An argument of the above function to be fixed
* Its instances are partial applications of the function, with one positional argument fixed. The argument to `fix` is the one-based index of the position argument to be fixed.

A type representing a partially-applied version of a function `f`, with the argument
`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to
`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
For example, `fix(3)(f, x)` behaves similarly to `(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.

See also: [`Fix1`](@ref), [`Fix2`](@ref).

!!! compat "Julia 1.12"
This general functionality requires at least Julia 1.12, while `Fix1` and `Fix2`
are available earlier.
Requires at least Julia 1.12 (`Fix1` and `Fix2` are available earlier, too).

!!! note
When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current
When nesting multiple `fix`, note that the `n` in `fix(n)` is _relative_ to the current
available arguments, rather than an absolute ordering on the target function. For example,
`Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)`
`fix(1)(fix(2)(f, 4), 4)` fixes the first and second arg, while `fix(2)(fix(1)(f, 4), 4)`
fixes the first and third arg.
"""
struct Fix{N,F,T} <: Function
f::F
x::T

function Fix{N}(f::F, x) where {N,F}
if !(N isa Int)
throw(ArgumentError(LazyString("expected type parameter in `Fix` to be `Int`, but got `", N, "::", typeof(N), "`")))
elseif N < 1
throw(ArgumentError(LazyString("expected `N` in `Fix{N}` to be integer greater than 0, but got ", N)))
end
new{N,_stable_typeof(f),_stable_typeof(x)}(f, x)
end
end
### Examples

function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M}
M < N-1 && throw(ArgumentError(LazyString("expected at least ", N-1, " arguments to `Fix{", N, "}`, but got ", M)))
return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...)
end
```jldoctest
julia> Base.fix(2)(-, 3)(7)
4

# Special cases for improved constant propagation
(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...)
(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...)
julia> Base.fix(2) === Base.Fix2
true

julia> Base.fix(1)(Base.fix(2)(muladd, 3), 2)(5) === (x -> muladd(2, 3, x))(5)
true
```
"""
Alias for `Fix{1}`. See [`Fix`](@ref Base.Fix).
function fix(@nospecialize m::Integer)
n = Int(m)::Int
if n ≤ 0
throw(ArgumentError("the index of the partially applied argument must be positive"))
end
k = _TypeDomainNumbers.Utils.from_abs_int(n)
PartiallyAppliedFunction{typeof(k)}
end

"""
Fix1::UnionAll

[`fix(1)`](@ref Base.fix).
"""
const Fix1{F,T} = Fix{1,F,T}
const Fix1 = fix(1)

"""
Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).
Fix2::UnionAll

[`fix(2)`](@ref Base.fix).
"""
const Fix2{F,T} = Fix{2,F,T}
const Fix2 = fix(2)

# Special cases for improved constant propagation
function (partial::Fix1)(x; kws...)
partial.f(partial.x, x; kws...)
end
function (partial::Fix2)(x; kws...)
partial.f(x, partial.x; kws...)
end


"""
Expand Down
2 changes: 1 addition & 1 deletion base/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public
AsyncCondition,
CodeUnits,
Event,
Fix,
fix,
Fix1,
Fix2,
Generator,
Expand Down
19 changes: 14 additions & 5 deletions base/tuple.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module _TupleTypeByLength
export Tuple1OrMore, Tuple2OrMore, Tuple32OrMore
const Tuple1OrMore = Tuple{Any, Vararg}
const Tuple2OrMore = Tuple{Any, Any, Vararg}
const Tuple32OrMore = Tuple{
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Vararg{Any, N},
} where {N}
end

# Document NTuple here where we have everything needed for the doc system
"""
NTuple{N, T}
Expand Down Expand Up @@ -358,11 +371,7 @@ map(f, t::Tuple{Any, Any}) = (@inline; (f(t[1]), f(t[2])))
map(f, t::Tuple{Any, Any, Any}) = (@inline; (f(t[1]), f(t[2]), f(t[3])))
map(f, t::Tuple) = (@inline; (f(t[1]), map(f,tail(t))...))
# stop inlining after some number of arguments to avoid code blowup
const Any32{N} = Tuple{Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Vararg{Any,N}}
const Any32 = _TupleTypeByLength.Tuple32OrMore
const All32{T,N} = Tuple{T,T,T,T,T,T,T,T,
T,T,T,T,T,T,T,T,
T,T,T,T,T,T,T,T,
Expand Down
Loading
Loading