Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
feat: better test integration in test_gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 18, 2024
1 parent 6e42fd2 commit 3f3cd9f
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 19 deletions.
1 change: 0 additions & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,3 @@ indent = 4
format_docstrings = true
separate_kwargs_with_semicolon = true
always_for_in = true
join_lines_based_on_source = false
5 changes: 2 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
fail-fast: false
matrix:
version:
- "min"
- "1"
- "pre"
os:
Expand Down Expand Up @@ -64,7 +65,7 @@ jobs:
runs-on: ${{ matrix.os }}
timeout-minutes: 60
env:
GROUP: ${{ matrix.package.group }}
BACKEND_GROUP: ${{ matrix.package.group }}
strategy:
fail-fast: false
matrix:
Expand Down Expand Up @@ -126,8 +127,6 @@ jobs:
- uses: julia-actions/julia-downgrade-compat@v1
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
LUX_TEST_GROUP: ${{ matrix.test_group }}
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ All notable changes to this project since the release of v1 will be documented i
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [1.2.0] - 2024-09-17

### Added

- By default, we no longer wrap the entire gradient computation in a `@test` macro.

## [1.1.4] - 2024-08-21

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxTestUtils"
uuid = "ac9de150-d08f-4546-94fb-7472b5760531"
authors = ["Avik Pal <avikpal@mit.edu>"]
version = "1.1.4"
version = "1.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
56 changes: 42 additions & 14 deletions src/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,13 @@ julia> test_gradients(f, 1.0, x, nothing)
```
"""
function test_gradients(f, args...; skip_backends=[], broken_backends=[],
soft_fail::Union{Bool, Vector}=false, kwargs...)
soft_fail::Union{Bool, Vector}=false,
# Internal kwargs start
source=LineNumberNode(0, nothing),
test_expr=:(check_approx(∂args, ∂args_gt; kwargs...)),
# Internal kwargs end
kwargs...)
# TODO: We should add a macro version that propagates the line number info and the test_expr
on_gpu = get_device_type(args) <: AbstractGPUDevice
total_length = mapreduce(__length, +, Functors.fleaves(args); init=0)

Expand Down Expand Up @@ -157,36 +163,58 @@ function test_gradients(f, args...; skip_backends=[], broken_backends=[],

@testset "gradtest($(f))" begin
@testset "$(nameof(typeof(backends[1])))() vs $(nameof(typeof(backend)))()" for backend in backends[2:end]
if backend in skip_backends
@test_skip begin
∂args = allow_unstable() do
return gradient(f, backend, args...)
end
check_approx(∂args, ∂args_gt; kwargs...)
end
local_test_expr = :([$(nameof(typeof(backend)))] - $(test_expr))

result = if backend in skip_backends
Broken(:skipped, local_test_expr)
elseif (soft_fail isa Bool && soft_fail) ||
(soft_fail isa Vector && backend in soft_fail)
@test_softfail begin
try
∂args = allow_unstable() do
return gradient(f, backend, args...)
end
check_approx(∂args, ∂args_gt; kwargs...)
matched = check_approx(∂args, ∂args_gt; kwargs...)
if matched
Pass(:test, local_test_expr, nothing, nothing, source)
else
Broken(:test, local_test_expr)
end
catch
Broken(:test, local_test_expr)
end
elseif backend in broken_backends
@test_broken begin
try
∂args = allow_unstable() do
return gradient(f, backend, args...)
end
check_approx(∂args, ∂args_gt; kwargs...)
matched = check_approx(∂args, ∂args_gt; kwargs...)
if matched
Error(:test_unbroken, local_test_expr, matched, nothing, source)
else
Broken(:test, local_test_expr)
end
catch
Broken(:test, local_test_expr)
end
else
@test begin
try
∂args = allow_unstable() do
return gradient(f, backend, args...)
end
check_approx(∂args, ∂args_gt; kwargs...)
matched = check_approx(∂args, ∂args_gt; kwargs...)
if matched
Pass(:test, local_test_expr, nothing, nothing, source)
else
context = "\n ∂args: $(∂args)\n∂args_gt: $(∂args_gt)"
Fail(
:test, local_test_expr, matched, nothing, context, source, false)
end
catch err
err isa InterruptException && rethrow()
Error(:test, local_test_expr, err, Base.current_exceptions(), source)
end
end
Test.record(get_testset(), result)
end
end
end

0 comments on commit 3f3cd9f

Please sign in to comment.