Skip to content

Commit

Permalink
improved function partial application design
Browse files Browse the repository at this point in the history
This replaces `Fix` (xref #54653) with `fix`. The usage is similar: use
`fix(i)(f, x)` instead of `Fix{i}(f, x)`.

Benefits:
* Improved type safety: creating an invalid type such as
  `Fix{:some_symbol}` or `Fix{-7}` is not possible.
* The design should be friendlier to future extensions. E.g., suppose
  that publicly-facing functionality for fixing a keyword (instead of
  positional) argument was desired, it could be achieved by adding a
  new method to `fix` taking a `Symbol`, instead of adding new public
  names.

Lots of changes are shared with PR #56425, if one of them gets merged
the other will be greatly simplified.
  • Loading branch information
nsajko committed Nov 10, 2024
1 parent afdba95 commit ac0a69e
Show file tree
Hide file tree
Showing 11 changed files with 354 additions and 72 deletions.
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ New library functions
* `waitany(tasks; throw=false)` and `waitall(tasks; failfast=false, throw=false)` which wait multiple tasks at once ([#53341]).
* `uuid7()` creates an RFC 9652 compliant UUID with version 7 ([#54834]).
* `insertdims(array; dims)` allows to insert singleton dimensions into an array which is the inverse operation to `dropdims`
* The new `Fix` type is a generalization of `Fix1/Fix2` for fixing a single argument ([#54653]).
* `Fix1`/`Fix2` are now generalized by `fix` ([#54653], [#56518]).

New library features
--------------------
Expand Down
1 change: 1 addition & 0 deletions base/Base_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ include("error.jl")
include("bool.jl")
include("number.jl")
include("int.jl")
include("typedomainnumbers.jl")
include("operators.jl")
include("pointer.jl")
include("refvalue.jl")
Expand Down
136 changes: 104 additions & 32 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1153,55 +1153,127 @@ julia> filter(!isletter, str)
!(f::Function) = (!) f
!(f::ComposedFunction{typeof(!)}) = f.inner #allows !!f === f

struct PartiallyAppliedFunction{Position <: _TypeDomainNumbers.PositiveIntegers.PositiveInteger, Func, Arg} <: Function
partially_applied_argument_position::Position
f::Func
x::Arg

global function new_partially_applied_function(pos::_TypeDomainNumbers.PositiveIntegers.PositiveInteger, f::Func, x) where {Func}
new{typeof(pos), _stable_typeof(f), _stable_typeof(x)}(pos, f, x)
end
end

function (::Type{PartiallyAppliedFunction{Position}})(func::Func, arg) where {Position <: _TypeDomainNumbers.PositiveIntegers.PositiveInteger, Func}
pos = (Position::DataType).instance
new_partially_applied_function(pos, func, arg)
end

function getproperty((@nospecialize v::PartiallyAppliedFunction), s::Symbol)
getfield(v, s)
end # avoid overspecialization

function Base.show((@nospecialize io::Base.IO), @nospecialize unused::Type{PartiallyAppliedFunction{Position}}) where {Position <: _TypeDomainNumbers.PositiveIntegers.PositiveInteger}
if Position isa DataType
print(io, "fix(")
show(io, Position.instance)
print(io, ')')
else
show(io, PartiallyAppliedFunction)
print(io, '{')
show(io, Position)
print(io, '}')
end
end

function Base.show((@nospecialize io::Base.IO), @nospecialize p::PartiallyAppliedFunction)
print(io, "fix(")
show(io, p.partially_applied_argument_position)
print(io, ")(")
show(io, p.f)
print(io, ", ")
show(io, p.x)
print(io, ')')
end

function _partially_applied_function_check(m::Int, nm1::Int)
if m < nm1
throw(ArgumentError(LazyString("expected at least ", nm1, " arguments to `fix(", nm1 + 1, ")`, but got ", m)))
end
end

function (partial::PartiallyAppliedFunction)(args::Vararg{Any,M}; kws...) where {M}
n = partial.partially_applied_argument_position
nm1 = _TypeDomainNumbers.PositiveIntegers.natural_predecessor(n)
_partially_applied_function_check(M, Int(nm1))
(args_left, args_right) = _TypeDomainNumberTupleUtils.split_tuple(args, nm1)
partial.f(args_left..., partial.x, args_right...; kws...)
end

"""
Fix{N}(f, x)
fix(::Integer)::UnionAll
Return a [`UnionAll`](@ref) type such that:
* It's a constructor taking two arguments:
1. A function to be partially applied
2. An argument of the above function to be fixed
* Its instances are partial applications of the function, with one positional argument fixed. The argument to `fix` is the one-based index of the position argument to be fixed.
For example, `fix(3)(f, x)` behaves similarly to `(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
A type representing a partially-applied version of a function `f`, with the argument
`x` fixed at position `N::Int`. In other words, `Fix{3}(f, x)` behaves similarly to
`(y1, y2, y3...; kws...) -> f(y1, y2, x, y3...; kws...)`.
See also: [`Fix1`](@ref), [`Fix2`](@ref).
!!! compat "Julia 1.12"
This general functionality requires at least Julia 1.12, while `Fix1` and `Fix2`
are available earlier.
Requires at least Julia 1.12 (`Fix1` and `Fix2` are available earlier, too).
!!! note
When nesting multiple `Fix`, note that the `N` in `Fix{N}` is _relative_ to the current
When nesting multiple `fix`, note that the `n` in `fix(n)` is _relative_ to the current
available arguments, rather than an absolute ordering on the target function. For example,
`Fix{1}(Fix{2}(f, 4), 4)` fixes the first and second arg, while `Fix{2}(Fix{1}(f, 4), 4)`
`fix(1)(fix(2)(f, 4), 4)` fixes the first and second arg, while `fix(2)(fix(1)(f, 4), 4)`
fixes the first and third arg.
"""
struct Fix{N,F,T} <: Function
f::F
x::T
function Fix{N}(f::F, x) where {N,F}
if !(N isa Int)
throw(ArgumentError(LazyString("expected type parameter in `Fix` to be `Int`, but got `", N, "::", typeof(N), "`")))
elseif N < 1
throw(ArgumentError(LazyString("expected `N` in `Fix{N}` to be integer greater than 0, but got ", N)))
end
new{N,_stable_typeof(f),_stable_typeof(x)}(f, x)
end
end
### Examples
function (f::Fix{N})(args::Vararg{Any,M}; kws...) where {N,M}
M < N-1 && throw(ArgumentError(LazyString("expected at least ", N-1, " arguments to `Fix{", N, "}`, but got ", M)))
return f.f(args[begin:begin+(N-2)]..., f.x, args[begin+(N-1):end]...; kws...)
end
```jldoctest
julia> Base.fix(2)(-, 3)(7)
4
# Special cases for improved constant propagation
(f::Fix{1})(arg; kws...) = f.f(f.x, arg; kws...)
(f::Fix{2})(arg; kws...) = f.f(arg, f.x; kws...)
julia> Base.fix(2) === Base.Fix2
true
julia> Base.fix(1)(Base.fix(2)(muladd, 3), 2)(5) === (x -> muladd(2, 3, x))(5)
true
```
"""
function fix(@nospecialize m::Integer)
n = Int(m)::Int
if n 0
throw(ArgumentError("the index of the partially applied argument must be positive"))
end
k = _TypeDomainNumbers.Utils.from_abs_int(n)
PartiallyAppliedFunction{typeof(k)}
end

"""
Alias for `Fix{1}`. See [`Fix`](@ref Base.Fix).
Fix1::UnionAll
[`fix(1)`](@ref Base.fix).
"""
const Fix1{F,T} = Fix{1,F,T}
const Fix1 = fix(1)

"""
Alias for `Fix{2}`. See [`Fix`](@ref Base.Fix).
Fix2::UnionAll
[`fix(2)`](@ref Base.fix).
"""
const Fix2{F,T} = Fix{2,F,T}
const Fix2 = fix(2)

# Special cases for improved constant propagation
function (partial::Fix1)(x; kws...)
partial.f(partial.x, x; kws...)
end
function (partial::Fix2)(x; kws...)
partial.f(x, partial.x; kws...)
end


"""
Expand Down
2 changes: 1 addition & 1 deletion base/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public
AsyncCondition,
CodeUnits,
Event,
Fix,
fix,
Fix1,
Fix2,
Generator,
Expand Down
19 changes: 14 additions & 5 deletions base/tuple.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module _TupleTypeByLength
export Tuple1OrMore, Tuple2OrMore, Tuple32OrMore
const Tuple1OrMore = Tuple{Any, Vararg}
const Tuple2OrMore = Tuple{Any, Any, Vararg}
const Tuple32OrMore = Tuple{
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any, Any,
Vararg{Any, N},
} where {N}
end

# Document NTuple here where we have everything needed for the doc system
"""
NTuple{N, T}
Expand Down Expand Up @@ -358,11 +371,7 @@ map(f, t::Tuple{Any, Any}) = (@inline; (f(t[1]), f(t[2])))
map(f, t::Tuple{Any, Any, Any}) = (@inline; (f(t[1]), f(t[2]), f(t[3])))
map(f, t::Tuple) = (@inline; (f(t[1]), map(f,tail(t))...))
# stop inlining after some number of arguments to avoid code blowup
const Any32{N} = Tuple{Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Any,Any,Any,Any,Any,Any,Any,Any,
Vararg{Any,N}}
const Any32 = _TupleTypeByLength.Tuple32OrMore
const All32{T,N} = Tuple{T,T,T,T,T,T,T,T,
T,T,T,T,T,T,T,T,
T,T,T,T,T,T,T,T,
Expand Down
167 changes: 167 additions & 0 deletions base/typedomainnumbers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# Adapted from the TypeDomainNaturalNumbers.jl package.
module _TypeDomainNumbers
module Zeros
export Zero
struct Zero end
end

module PositiveIntegers
module RecursiveStep
using ...Zeros
export recursive_step
function recursive_step(@nospecialize t::Type)
Union{Zero, t}
end
end
module UpperBounds
using ..RecursiveStep
abstract type A end
abstract type B{P <: recursive_step(A)} <: A end
abstract type C{P <: recursive_step(B)} <: B{P} end
abstract type D{P <: recursive_step(C)} <: C{P} end
end
using .RecursiveStep
const PositiveIntegerUpperBound = UpperBounds.A
const PositiveIntegerUpperBoundTighter = UpperBounds.D
export
natural_successor, natural_predecessor,
NonnegativeInteger, NonnegativeIntegerUpperBound,
PositiveInteger, PositiveIntegerUpperBound
struct PositiveInteger{
Predecessor <: recursive_step(PositiveIntegerUpperBoundTighter),
} <: PositiveIntegerUpperBoundTighter{Predecessor}
predecessor::Predecessor
global const NonnegativeInteger = recursive_step(PositiveInteger)
global const NonnegativeIntegerUpperBound = recursive_step(PositiveIntegerUpperBound)
global function natural_successor(p::P) where {P <: NonnegativeInteger}
new{P}(p)
end
end
function natural_predecessor(@nospecialize o::PositiveInteger)
getfield(o, :predecessor) # avoid specializing `getproperty` for each number
end
end

module IntegersGreaterThanOne
using ..PositiveIntegers
export
IntegerGreaterThanOne, IntegerGreaterThanOneUpperBound,
natural_predecessor_predecessor
const IntegerGreaterThanOne = let t = PositiveInteger
t{P} where {P <: t}
end
const IntegerGreaterThanOneUpperBound = let t = PositiveIntegerUpperBound
PositiveIntegers.UpperBounds.B{P} where {P <: t}
end
function natural_predecessor_predecessor(@nospecialize x::IntegerGreaterThanOne)
natural_predecessor(natural_predecessor(x))
end
end

module Constants
using ..Zeros, ..PositiveIntegers
export n0, n1
const n0 = Zero()
const n1 = natural_successor(n0)
end

module Utils
using ..PositiveIntegers, ..IntegersGreaterThanOne, ..Constants
using Base: @_foldable_meta
function subtracted_nonnegative((@nospecialize l::NonnegativeInteger), @nospecialize r::NonnegativeInteger)
@_foldable_meta
if r isa PositiveIntegerUpperBound
let a = natural_predecessor(l), b = natural_predecessor(r)
subtracted_nonnegative(a, b)
end
else
l
end
end
function abs_decrement(n::Int)
@_foldable_meta
if signbit(n)
n + true
else
n - true
end
end
function to_int(@nospecialize o::NonnegativeInteger)
@_foldable_meta
if o isa PositiveIntegerUpperBound
let p = natural_predecessor(o), t = to_int(p)
t + true
end
else
0
end
end
function from_abs_int(n::Int)
@_foldable_meta
ret = n0
while !iszero(n)
n = abs_decrement(n)
ret = natural_successor(ret)
end
ret
end
end

module Overloads
using ..PositiveIntegers, ..Utils
function (::Type{Int})(@nospecialize o::NonnegativeInteger)
Utils.to_int(o)
end
function Base.show((@nospecialize io::Base.IO), @nospecialize n::NonnegativeInteger)
i = Int(n)
Base.show(io, i)
end
end
end

module _TypeDomainNumberTupleUtils
using
.._TypeDomainNumbers.PositiveIntegers, .._TypeDomainNumbers.IntegersGreaterThanOne,
.._TypeDomainNumbers.Constants, .._TypeDomainNumbers.Utils, .._TupleTypeByLength
using Base: @_total_meta, @_foldable_meta, front, tail
export tuple_type_domain_length, split_tuple, skip_from_front, skip_from_tail
function tuple_type_domain_length(@nospecialize tup::Tuple)
@_total_meta
if tup isa Tuple1OrMore
let t = tail(tup), rec = tuple_type_domain_length(t)
natural_successor(rec)
end
else
n0
end
end
function skip_from_front((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger)
@_foldable_meta
if skip_count isa PositiveIntegerUpperBound
let cm1 = natural_predecessor(skip_count), t = tail(tup)
@inline skip_from_front(t, cm1)
end
else
tup
end
end
function skip_from_tail((@nospecialize tup::Tuple), @nospecialize skip_count::NonnegativeInteger)
@_foldable_meta
if skip_count isa PositiveIntegerUpperBound
let cm1 = natural_predecessor(skip_count), t = front(tup)
@inline skip_from_tail(t, cm1)
end
else
tup
end
end
function split_tuple((@nospecialize tup::Tuple), @nospecialize len_l::NonnegativeInteger)
len = tuple_type_domain_length(tup)
len_r = Utils.subtracted_nonnegative(len, len_l)
tup_l = skip_from_tail(tup, len_r)
tup_r = skip_from_front(tup, len_l)
(tup_l, tup_r)
end
end
2 changes: 1 addition & 1 deletion doc/src/base/base.md
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ Base.:(|>)
Base.:(∘)
Base.ComposedFunction
Base.splat
Base.Fix
Base.fix
Base.Fix1
Base.Fix2
```
Expand Down
Loading

0 comments on commit ac0a69e

Please sign in to comment.