This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat: better test integration in test_gradients #35

Merged · 3 commits · Sep 18, 2024
1 change: 0 additions & 1 deletion .JuliaFormatter.toml
@@ -5,4 +5,3 @@ indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
always_for_in = true
join_lines_based_on_source = false
5 changes: 2 additions & 3 deletions .github/workflows/CI.yml
@@ -27,6 +27,7 @@ jobs:
fail-fast: false
matrix:
version:
- "min"
- "1"
- "pre"
os:
@@ -64,7 +65,7 @@ jobs:
runs-on: ${{ matrix.os }}
timeout-minutes: 60
env:
GROUP: ${{ matrix.package.group }}
BACKEND_GROUP: ${{ matrix.package.group }}
strategy:
fail-fast: false
matrix:
@@ -126,8 +127,6 @@ jobs:
- uses: julia-actions/julia-downgrade-compat@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
LUX_TEST_GROUP: ${{ matrix.test_group }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project since the release of v1 will be documented i
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.2.0] - 2024-09-18

### Added

- By default, we no longer wrap the entire gradient computation in a `@test` macro.

## [1.1.4] - 2024-08-21

### Fixed
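The changelog entry above describes the user-facing change: gradient checks are now recorded as individual test results instead of being wrapped wholesale in `@test`, and the new `@test_gradients` macro reports failures at the caller's line. A minimal, illustrative sketch (the function `f` and the test set name are made up here, and it assumes the relevant AD backends are available in the test environment):

```julia
using Test, LuxTestUtils

f(x) = sum(abs2, x)  # hypothetical target function, not part of the package

@testset "gradient checks (illustrative)" begin
    x = randn(3)
    # Function form: each backend comparison is recorded as its own result.
    test_gradients(f, x)
    # Macro form added in this PR: same checks, but a failing backend is
    # reported at this call site rather than inside LuxTestUtils internals.
    @test_gradients(f, x)
end
```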
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "LuxTestUtils"
uuid = "ac9de150-d08f-4546-94fb-7472b5760531"
authors = ["Avik Pal <avikpal@mit.edu>"]
version = "1.1.4"
version = "1.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
2 changes: 1 addition & 1 deletion src/LuxTestUtils.jl
@@ -51,7 +51,7 @@ include("autodiff.jl")
include("jet.jl")

export AutoEnzyme, AutoFiniteDiff, AutoTracker, AutoForwardDiff, AutoReverseDiff, AutoZygote
export test_gradients
export test_gradients, @test_gradients
export @jet, jet_target_modules!
export @test_softfail

77 changes: 63 additions & 14 deletions src/autodiff.jl
@@ -128,7 +128,13 @@ julia> test_gradients(f, 1.0, x, nothing)
```
"""
function test_gradients(f, args...; skip_backends=[], broken_backends=[],
soft_fail::Union{Bool, Vector}=false, kwargs...)
soft_fail::Union{Bool, Vector}=false,
# Internal kwargs start
source::LineNumberNode=LineNumberNode(0, nothing),
test_expr::Expr=:(check_approx(∂args, ∂args_gt; kwargs...)),
# Internal kwargs end
kwargs...)
# TODO: We should add a macro version that propagates the line number info and the test_expr
on_gpu = get_device_type(args) <: AbstractGPUDevice
total_length = mapreduce(__length, +, Functors.fleaves(args); init=0)

@@ -157,36 +163,79 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[],

@testset "gradtest($(f))" begin
@testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end]
if backend in skip_backends
@test_skip begin
∂args = allow_unstable() do
return gradient(f, backend, args...)
end
check_approx(∂args, ∂args_gt; kwargs...)
end
local_test_expr = :([$(nameof(typeof(backend)))] - $(test_expr))

result = if backend in skip_backends
Broken(:skipped, local_test_expr)
elseif (soft_fail isa Bool && soft_fail) ||
(soft_fail isa Vector && backend in soft_fail)
@test_softfail begin
try
∂args = allow_unstable() do
return gradient(f, backend, args...)
end
check_approx(∂args, ∂args_gt; kwargs...)
matched = check_approx(∂args, ∂args_gt; kwargs...)
if matched
Pass(:test, local_test_expr, nothing, nothing, source)
else
Broken(:test, local_test_expr)
end
catch
Broken(:test, local_test_expr)
end
elseif backend in broken_backends
@test_broken begin
try
∂args = allow_unstable() do
return gradient(f, backend, args...)
end
check_approx(∂args, ∂args_gt; kwargs...)
matched = check_approx(∂args, ∂args_gt; kwargs...)
if matched
Error(:test_unbroken, local_test_expr, matched, nothing, source)
else
Broken(:test, local_test_expr)
end
catch
Broken(:test, local_test_expr)
end
else
@test begin
try
∂args = allow_unstable() do
return gradient(f, backend, args...)
end
check_approx(∂args, ∂args_gt; kwargs...)
matched = check_approx(∂args, ∂args_gt; kwargs...)
if matched
Pass(:test, local_test_expr, nothing, nothing, source)
else
context = "\n ∂args: $(∂args)\n∂args_gt: $(∂args_gt)"
Fail(
:test, local_test_expr, matched, nothing, context, source, false)
end
catch err
err isa InterruptException && rethrow()
Error(:test, local_test_expr, err, Base.current_exceptions(), source)
end
end
Test.record(get_testset(), result)
end
end
end

"""
@test_gradients(f, args...; kwargs...)

See the documentation of [`test_gradients`](@ref) for more details. This macro provides
correct line information for the failing tests.
"""
macro test_gradients(exprs...)
exs = reorder_macro_kw_params(exprs)
kwarg_idx = findfirst(ex -> Meta.isexpr(ex, :kw), exs)
if kwarg_idx === nothing
args = [exs...]
kwargs = []
else
args = [exs[1:(kwarg_idx - 1)]...]
kwargs = [exs[kwarg_idx:end]...]
end
push!(kwargs, Expr(:kw, :source, QuoteNode(__source__)))
push!(kwargs, Expr(:kw, :test_expr, QuoteNode(:(test_gradients($(exs...))))))
return esc(:($(test_gradients)($(args...); $(kwargs...))))
end
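For readers unfamiliar with the pattern used above: instead of wrapping the gradient computation in `@test`, the function now builds `Test.Result` objects directly and records them on the active test set, so the reported location comes from the `source` keyword. Below is a condensed sketch of that mechanism; `record_comparison` is a hypothetical helper, and the `Pass`/`Fail` constructor signatures mirror those used in the diff (Test stdlib internals, which can differ across Julia versions):

```julia
using Test
using Test: Pass, Fail, get_testset, record

# Hypothetical helper (not part of LuxTestUtils): compare two values and
# record the outcome with caller-provided line information.
function record_comparison(lhs, rhs; source::LineNumberNode=LineNumberNode(0, nothing))
    expr = :(lhs ≈ rhs)  # expression shown in the test report
    result = if isapprox(lhs, rhs)
        Pass(:test, expr, nothing, nothing, source)
    else
        context = "lhs = $(lhs), rhs = $(rhs)"
        Fail(:test, expr, nothing, nothing, context, source, false)
    end
    record(get_testset(), result)
end
```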
14 changes: 14 additions & 0 deletions src/utils.jl
@@ -109,3 +109,17 @@ check_approx(x::AbstractArray, y::NamedTuple; kwargs...) = length(x) == 0 && len
check_approx(x::NamedTuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0
check_approx(x::AbstractArray, y::Tuple; kwargs...) = length(x) == 0 && length(y) == 0
check_approx(x::Tuple, y::AbstractArray; kwargs...) = length(x) == 0 && length(y) == 0

# Taken from Discourse; normalizes the order of keyword arguments in a macro call
function reorder_macro_kw_params(exs)
exs = Any[exs...]
i = findfirst([(ex isa Expr && ex.head == :parameters) for ex in exs])
if i !== nothing
extra_kw_def = exs[i].args
for ex in extra_kw_def
push!(exs, ex isa Symbol ? Expr(:kw, ex, ex) : ex)
end
deleteat!(exs, i)
end
return Tuple(exs)
end
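A quick illustration of what `reorder_macro_kw_params` does to a call's argument list; the call expression and values below are arbitrary and only for demonstration:

```julia
# `g(1.0, x; atol = 0.001)` parses with a leading `Expr(:parameters, ...)`
# block that holds the `;`-style keyword argument.
call = :(g(1.0, x; atol = 0.001))
exs = reorder_macro_kw_params(call.args[2:end])
# exs == (1.0, :x, Expr(:kw, :atol, 0.001)): the keyword is moved after the
# positional arguments, so `findfirst(ex -> Meta.isexpr(ex, :kw), exs)` in
# `@test_gradients` can cleanly split positionals from keywords.
```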
13 changes: 13 additions & 0 deletions test/unit_tests.jl
@@ -14,25 +14,38 @@ end
test_gradients(f, 1.0, x, nothing)

test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()])
@test_gradients(f, 1.0, x, nothing; skip_backends=[AutoTracker()])

@test errors() do
test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()])
end

@test errors() do
@test_gradients(f, 1.0, x, nothing; broken_backends=[AutoTracker()])
end

@test_throws ArgumentError test_gradients(
f, 1.0, x, nothing; broken_backends=[AutoTracker()],
skip_backends=[AutoTracker(), AutoEnzyme()])
@test_throws ArgumentError @test_gradients(
f, 1.0, x, nothing; broken_backends=[AutoTracker()],
skip_backends=[AutoTracker(), AutoEnzyme()])

test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()])
@test_gradients(f, 1.0, x, nothing; soft_fail=[AutoTracker()])

test_gradients(f, 1.0, x, nothing; soft_fail=true)
@test_gradients(f, 1.0, x, nothing; soft_fail=true)

x_ca = ComponentArray(x)

test_gradients(f, 1.0, x_ca, nothing)
@test_gradients(f, 1.0, x_ca, nothing)

x_2 = (; t=x.t', x=(z=x.x.z',))

test_gradients(f, 1.0, x_2, nothing)
@test_gradients(f, 1.0, x_2, nothing)
end

@testitem "test_gradients (CUDA.jl)" skip=:(using CUDA; !CUDA.functional()) begin