Skip to content

Commit

Permalink
traits only get applied and defined on instances
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Mar 7, 2023
1 parent c5886a7 commit b460b71
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 71 deletions.
53 changes: 24 additions & 29 deletions docs/src/algorithm_traits.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,46 +4,41 @@
> *This algorithm supports per-observation weights, which must appear as the third
> argument of `fit`*, or *This algorithm's `transform` method predicts `Real` vectors*.
For any (non-trivial) algorithm, [`LearnAPI.functions`](@ref)`(algorithm)` must be
overloaded to list the LearnAPI methods that have been explicitly implemented/overloaded
(algorithm traits excluded). Overloading other traits is optional, except where required
by the implementation of some LearnAPI method and explicitly documented in that method's
docstring.

Traits are often called on instances but are usually *defined* on algorithm *types*, as in
Algorithm traits are functions whose first (and usually only) argument is an algorithm. In
a new implementation, a single-argument trait is declared following this pattern:

```julia
LearnAPI.is_pure_julia(::Type{<:MyAlgorithmType}) = true
LearnAPI.is_pure_julia(algorithm::MyAlgorithmType) = true
```

which has the shorthand
!!! important

```julia
@trait MyAlgorithmType is_pure_julia=true
```
The value of a trait must be the same for all algorithms of the same type,
even if the types differ only in type parameters. There are exceptions for
some traits, if
`is_wrapper(algorithm) = true` for all instances `algorithm` of some type
(composite algorithms). This requirement occasionally requires that
an existing algorithm implementation be split into separate LearnAPI
implementations (e.g., one for regression and another for classification).

So, for convenience, every trait `t` is provided the fallback implementation
The declaration above has the shorthand

```julia
t(algorithm) = t(typeof(algorithm))
@trait MyAlgorithmType is_pure_julia=true
```

This means `LearnAPI.is_pure_julia(algorithm) = true` whenever `algorithm isa MyAlgorithmType` in the
above example.

Every trait has a global fallback implementation for `::Type`. See the table below.
Multiple traits can be declared like this:

## When traits depdend on more than algorithm type

Traits that vary from instance to instance of the same type are disallowed, except in the
case of composite algorithms (`is_wrapper(algorithm) = true`) where this is typically
unavoidable. The reason for this is so one can associate, with each non-composite
algorithm type, unique trait-based "algorithm metadata", for inclusion in searchable
algorithm databases. This requirement occasionally requires that an existing algorithm
implementation be split into separate LearnAPI implementations (e.g., one for regression
and another for classification).
```julia
@trait(
MyAlgorithmType,
is_pure_julia = true,
pkg_name = "MyPackage",
)
```

## Special two-argument traits
### Special two-argument traits

The two-argument version of [`LearnAPI.predict_output_scitype`](@ref) and
[`LearnAPI.predict_output_scitype`](@ref) are the only overloadable traits with more than
Expand All @@ -55,7 +50,7 @@ one argument. They cannot be declared using the `@trait` macro.
implementation. **Derived traits** are not, and should not be called by performance
critical code

## Overloadable traits
### Overloadable traits

In the examples column of the table below, `Table`, `Continuous`, `Sampleable` are names owned by the
package [ScientificTypesBase.jl](https://github.com/JuliaAI/ScientificTypesBase.jl/).
Expand Down Expand Up @@ -100,7 +95,7 @@ include the variable.
for the general case.


## Derived Traits
### Derived Traits

The following convenience methods are provided but intended for overloading:

Expand Down
8 changes: 4 additions & 4 deletions docs/src/anatomy_of_an_implementation.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
> returning the absolute values of the linear coefficients. The ridge regressor has a
> target variable and outputs literal predictions of the target (rather than, say,
> probabilistic predictions); accordingly the overloaded `predict` method is dispatched on
> the `LiteralTarget` subtype of `KindOfProxy`. An **algorithm trait** declares this type as the
> preferred kind of target proxy. Other traits articulate the algorithm's training data type
> requirements and the input/output type of `predict`.
> the `LiteralTarget` subtype of `KindOfProxy`. An **algorithm trait** declares this type
> as the preferred kind of target proxy. Other traits articulate the algorithm's training
> data type requirements and the input/output type of `predict`.
We begin by describing an implementation of LearnAPI.jl for basic ridge regression
(without intercept) to introduce the main actors in any implementation.
Expand Down Expand Up @@ -159,7 +159,7 @@ list). Accordingly, we are required to declare a preferred target proxy, which w
[`LearnAPI.preferred_kind_of_proxy`](@ref):

```@example anatomy
LearnAPI.preferred_kind_of_proxy(::Type{<:MyRidge}) = LearnAPI.LiteralTarget()
LearnAPI.preferred_kind_of_proxy(::MyRidge) = LearnAPI.LiteralTarget()
nothing # hide
```
Or, you can use the shorthand
Expand Down
66 changes: 29 additions & 37 deletions src/algorithm_traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ this list, do `LearnAPI.functions()`.
See also [`LearnAPI.Algorithm`](@ref).
"""
functions(::Type) = ()
functions(::Any) = ()


"""
Expand Down Expand Up @@ -104,13 +104,13 @@ Then we can declare
which is shorthand for
```julia
LearnAPI.preferred_kind_of_proxy(::Type{<:MyNewAlgorithmType}) = LearnAPI.Distribution()
LearnAPI.preferred_kind_of_proxy(::MyNewAlgorithmType) = LearnAPI.Distribution()
```
For more on target variables and target proxies, refer to the LearnAPI documentation.
"""
preferred_kind_of_proxy(::Type) = nothing
preferred_kind_of_proxy(::Any) = nothing

"""
LearnAPI.position_of_target(algorithm)
Expand All @@ -122,7 +122,7 @@ If this number is `0`, then no target is expected. If this number exceeds `lengt
then `data` is understood to exclude the target variable.
"""
position_of_target(::Type) = 0
position_of_target(::Any) = 0

"""
LearnAPI.position_of_weights(algorithm)
Expand All @@ -135,7 +135,7 @@ If this number is `0`, then no weights are expected. If this number exceeds
uniform.
"""
position_of_weights(::Type) = 0
position_of_weights(::Any) = 0

descriptors() = [
:regression,
Expand Down Expand Up @@ -180,7 +180,7 @@ Lists one or more suggestive algorithm descriptors from this list: $DOC_DESCRIPT
This trait should return a tuple of symbols, as in `(:classifier, :probabilistic)`.
"""
descriptors(::Type) = ()
descriptors(::Any) = ()

"""
LearnAPI.is_pure_julia(algorithm)
Expand All @@ -192,7 +192,7 @@ Returns `true` if training `algorithm` requires evaluation of pure Julia code on
The fallback is `false`.
"""
is_pure_julia(::Type) = false
is_pure_julia(::Any) = false

"""
LearnAPI.pkg_name(algorithm)
Expand All @@ -208,7 +208,7 @@ $DOC_UNKNOWN
Must return a string, as in `"DecisionTree"`.
"""
pkg_name(::Type) = "unknown"
pkg_name(::Any) = "unknown"

"""
LearnAPI.pkg_license(algorithm)
Expand All @@ -217,7 +217,7 @@ Return the name of the software license, such as `"MIT"`, applying to the packag
core algorithm for `algorithm` is implemented.
"""
pkg_license(::Type) = "unknown"
pkg_license(::Any) = "unknown"

"""
LearnAPI.doc_url(algorithm)
Expand All @@ -231,7 +231,7 @@ $DOC_UNKNOWN
Must return a string, such as `"https://en.wikipedia.org/wiki/Decision_tree_learning"`.
"""
doc_url(::Type) = "unknown"
doc_url(::Any) = "unknown"

"""
LearnAPI.load_path(algorithm)
Expand All @@ -250,7 +250,7 @@ $DOC_UNKNOWN
"""
load_path(::Type) = "unknown"
load_path(::Any) = "unknown"


"""
Expand All @@ -268,7 +268,7 @@ $DOC_ON_TYPE
"""
is_wrapper(::Type) = false
is_wrapper(::Any) = false

"""
LearnAPI.human_name(algorithm)
Expand All @@ -284,7 +284,7 @@ to return `"K-nearest neighbors regressor"`. Ideally, this is a "concrete" noun
`"ridge regressor"` rather than an "abstract" noun like `"ridge regression"`.
"""
human_name(M::Type{}) = snakecase(name(M), delim=' ') # `name` defined below
human_name(M) = snakecase(name(M), delim=' ') # `name` defined below

"""
LearnAPI.iteration_parameter(algorithm)
Expand All @@ -297,7 +297,7 @@ iterative.
Implement if algorithm is iterative. Returns a symbol or `nothing`.
"""
iteration_parameter(::Type) = nothing
iteration_parameter(::Any) = nothing

"""
LearnAPI.fit_keywords(algorithm)
Expand All @@ -314,7 +314,7 @@ Here's a sample implementation for a classifier that implements a `LearnAPI.fit`
with signature `fit(algorithm::MyClassifier, verbosity, X, y; class_weights=nothing)`:
```
LearnAPI.fit_keywords(::Type{<:MyClassifier}) = (:class_weights,)
LearnAPI.fit_keywords(::Any{<:MyClassifier}) = (:class_weights,)
```
or the shorthand
Expand All @@ -325,7 +325,7 @@ or the shorthand
"""
fit_keywords(::Type) = ()
fit_keywords(::Any) = ()

"""
LearnAPI.fit_scitype(algorithm)
Expand Down Expand Up @@ -353,7 +353,7 @@ See also [`LearnAPI.fit_type`](@ref), [`LearnAPI.fit_observation_scitype`](@ref)
Optional. The fallback return value is `Union{}`. $DOC_ONLY_ONE
"""
fit_scitype(::Type) = Union{}
fit_scitype(::Any) = Union{}

"""
LearnAPI.fit_observation_scitype(algorithm)
Expand Down Expand Up @@ -386,7 +386,7 @@ See also See also [`LearnAPI.fit_type`](@ref), [`LearnAPI.fit_scitype`](@ref),
Optional. The fallback return value is `Union{}`. $DOC_ONLY_ONE
"""
fit_observation_scitype(::Type) = Union{}
fit_observation_scitype(::Any) = Union{}

"""
LearnAPI.fit_type(algorithm)
Expand All @@ -413,7 +413,7 @@ See also [`LearnAPI.fit_scitype`](@ref), [`LearnAPI.fit_observation_type`](@ref)
Optional. The fallback return value is `Union{}`. $DOC_ONLY_ONE
"""
fit_type(::Type) = Union{}
fit_type(::Any) = Union{}

"""
LearnAPI.fit_observation_type(algorithm)
Expand Down Expand Up @@ -446,7 +446,7 @@ See also See also [`LearnAPI.fit_type`](@ref), [`LearnAPI.fit_scitype`](@ref),
Optional. The fallback return value is `Union{}`. $DOC_ONLY_ONE
"""
fit_observation_type(::Type) = Union{}
fit_observation_type(::Any) = Union{}

DOC_INPUT_SCITYPE(op) =
"""
Expand Down Expand Up @@ -543,22 +543,22 @@ DOC_OUTPUT_TYPE(op) =
"""

"$(DOC_INPUT_SCITYPE(:predict))"
predict_input_scitype(::Type) = Union{}
predict_input_scitype(::Any) = Union{}

"$(DOC_INPUT_TYPE(:predict))"
predict_input_type(::Type) = Union{}
predict_input_type(::Any) = Union{}

"$(DOC_INPUT_SCITYPE(:transform))"
transform_input_scitype(::Type) = Union{}
transform_input_scitype(::Any) = Union{}

"$(DOC_OUTPUT_SCITYPE(:transform))"
transform_output_scitype(::Type) = Any
transform_output_scitype(::Any) = Any

"$(DOC_INPUT_TYPE(:transform))"
transform_input_type(::Type) = Union{}
transform_input_type(::Any) = Union{}

"$(DOC_OUTPUT_TYPE(:transform))"
transform_output_type(::Type) = Any
transform_output_type(::Any) = Any


# # TWO-ARGUMENT TRAITS
Expand Down Expand Up @@ -591,7 +591,7 @@ const DOC_PREDICT_OUTPUT(s) =
regressor type `MyRgs` that only predicts actual values of the target:
LearnAPI.predict(alogrithm::MyRgs, ::LearnAPI.LiteralTarget, data...) = ...
LearnAPI.predict_output_$(s)(::Type{<:MyRgs}, ::LearnAPI.LiteralTarget) =
LearnAPI.predict_output_$(s)(::MyRgs, ::LearnAPI.LiteralTarget) =
AbstractVector{ScientificTypesBase.Continuous}
The fallback method returns `Any`.
Expand All @@ -607,9 +607,9 @@ predict_output_type(algorithm, kind_of_proxy) = Any

# # DERIVED TRAITS

name(A::Type) = string(typename(A))
name(A) = string(typename(A))

is_algorithm(A::Type) = !isempty(functions(A))
is_algorithm(A) = !isempty(functions(A))

const DOC_PREDICT_OUTPUT2(s) =
"""
Expand Down Expand Up @@ -651,11 +651,3 @@ predict_output_type(algorithm) =
for T in CONCRETE_TARGET_PROXY_TYPES)


# # FALLBACK FOR INSTANCES

for trait in TRAITS
ex = quote
$trait(x) = $trait(typeof(x))
end
eval(ex)
end
2 changes: 1 addition & 1 deletion src/tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ macro trait(algorithm_ex, exs...)
trait_ex, value_ex = name_value_pair(ex)
push!(
program.args,
:($LearnAPI.$trait_ex(::Type{<:$algorithm_ex}) = $value_ex),
:($LearnAPI.$trait_ex(::$algorithm_ex) = $value_ex),
)
end
return esc(program)
Expand Down
12 changes: 12 additions & 0 deletions test/tools.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
module Fruit
using LearnAPI

struct RedApple{T}
x::T
end

@trait(
RedApple,
is_pure_julia = true,
pkg_name = "Fruity",
)

end

import .Fruit
Expand All @@ -30,4 +37,9 @@ end
@test LearnAPI.snakecase(:TheLASERBeam) == :the_laser_beam
end

@testset "@trait" begin
@test LearnAPI.is_pure_julia(Fruit.RedApple(1))
@test LearnAPI.pkg_name(Fruit.RedApple(1)) == "Fruity"
end

true

0 comments on commit b460b71

Please sign in to comment.