Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: use TestExtras #1099

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lib/LuxLib/test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down Expand Up @@ -61,5 +62,6 @@ Static = "0.8.4, 1"
StaticArrays = "1.9.7"
Statistics = "1.10"
Test = "1.10"
TestExtras = "0.3.1"
Tracker = "0.2.36"
Zygote = "0.6.70"
12 changes: 6 additions & 6 deletions lib/LuxLib/test/common_ops/activation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,18 @@
@test eltype(y2) == T
@test eltype(y3) == T

@test @inferred(apply_act(f, x)) isa Any
@test @inferred(apply_act_fast(f, x)) isa Any
@test @inferred(apply_act_fast2(f, x)) isa Any
@constinferred apply_act(f, x)
@constinferred apply_act_fast(f, x)
@constinferred apply_act_fast2(f, x)

@jet apply_act_fast(f, x)
@jet apply_act_fast2(f, x)

@test @inferred(Zygote.gradient(apply_act, f, x)) isa Any
@constinferred Zygote.gradient(apply_act, f, x)
if f !== lisht
@test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any
@constinferred Zygote.gradient(apply_act_fast, f, x)
end
@test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any
@constinferred Zygote.gradient(apply_act_fast2, f, x)

@test_gradients(apply_act, f, x; atol, rtol)
@test_gradients(apply_act_fast, f, x; atol, rtol, skip_backends=[AutoEnzyme()])
Expand Down
38 changes: 14 additions & 24 deletions lib/LuxLib/test/common_ops/bias_act_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,6 @@
bias_act_loss2(act, x, b) = sum(abs2, bias_activation(act, x, b))
bias_act_loss3(act, x, b) = sum(abs2, bias_activation!!(act, copy(x), b))

struct __Fix1{F, A}
f::F
act::A
end
(f::__Fix1)(x, b) = f.f(f.act, x, b)

@testset "$mode" for (mode, aType, ongpu, fp64) in MODES
@testset "$act, $T, $sz" for act in [
identity, relu, sigmoid, sigmoid_fast, softplus,
Expand All @@ -27,38 +21,34 @@
y2 = bias_act_loss2(act, x, b)
y3 = bias_act_loss3(act, x, b)

fp16 = T == Float16
atol = fp16 ? 1.0f-2 : 1.0f-3
rtol = fp16 ? 1.0f-2 : 1.0f-3
atol = 1.0f-3
rtol = 1.0f-3

@test y1≈y2 atol=atol rtol=rtol
@test y1≈y3 atol=atol rtol=rtol
@test eltype(y1) == T
@test eltype(y2) == T
@test eltype(y3) == T

@test @inferred(bias_act_loss1(act, x, b)) isa Any
@test @inferred(bias_act_loss2(act, x, b)) isa Any
@test @inferred(bias_act_loss3(act, x, b)) isa Any
@constinferred bias_act_loss1(act, x, b)
@constinferred bias_act_loss2(act, x, b)
@constinferred bias_act_loss3(act, x, b)

@jet bias_act_loss2(act, x, b)
@jet bias_act_loss3(act, x, b)

if act !== lisht && T != Float16
@test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any
@test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any
if act !== lisht
@constinferred Zygote.gradient(bias_act_loss2, act, x, b)
@constinferred Zygote.gradient(bias_act_loss3, act, x, b)
end

@test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol,
soft_fail=fp16 ? [AutoFiniteDiff()] : [])
@test_gradients(__Fix1(bias_act_loss2, act), x, b; atol, rtol,
soft_fail=fp16 ? [AutoFiniteDiff()] : [])
@test_gradients(__Fix1(bias_act_loss3, act), x, b; atol, rtol,
soft_fail=fp16 ? [AutoFiniteDiff()] : [])
@test_gradients(bias_act_loss1, act, x, b; atol, rtol)
@test_gradients(bias_act_loss2, act, x, b; atol, rtol)
@test_gradients(bias_act_loss3, act, x, b; atol, rtol)

∂x1, ∂b1 = Zygote.gradient(__Fix1(bias_act_loss1, act), x, b)
∂x2, ∂b2 = Zygote.gradient(__Fix1(bias_act_loss2, act), x, b)
∂x3, ∂b3 = Zygote.gradient(__Fix1(bias_act_loss3, act), x, b)
_, ∂x1, ∂b1 = Zygote.pullback(bias_act_loss1, act, x, b)
_, ∂x2, ∂b2 = Zygote.pullback(bias_act_loss2, act, x, b)
_, ∂x3, ∂b3 = Zygote.pullback(bias_act_loss3, act, x, b)

@test ∂x1≈∂x2 atol=atol rtol=rtol
@test ∂x1≈∂x3 atol=atol rtol=rtol
Expand Down
15 changes: 7 additions & 8 deletions lib/LuxLib/test/common_ops/conv_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module ConvSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras

expand(_, i::Tuple) = i
expand(N, i::Integer) = ntuple(_ -> i, N)
Expand Down Expand Up @@ -43,20 +43,19 @@ function run_conv_testing(gen_f::Function, activation, kernel, stride, padding,

@test eltype(y) == promote_type(Tw, Tx)

@test @inferred(fused_conv_bias_activation(activation, weight, x, bias, cdims)) isa Any
@constinferred fused_conv_bias_activation(activation, weight, x, bias, cdims)
@jet fused_conv_bias_activation(activation, weight, x, bias, cdims)

if mode != "amdgpu" && activation !== anonact
@test @inferred(Zygote.gradient(
sumabs2conv, activation, weight, x, bias, cdims
)) isa Any
@constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)
else
try
@inferred(Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims))
@test true
@constinferred Zygote.gradient(sumabs2conv, activation, weight, x, bias, cdims)
catch e
e isa ErrorException || rethrow()
@test_broken false
@constinferred_broken Zygote.gradient(
sumabs2conv, activation, weight, x, bias, cdims
)
end
end

Expand Down
14 changes: 7 additions & 7 deletions lib/LuxLib/test/common_ops/dense_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module DenseSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, StableRNGs, TestExtras

anonact = x -> x^3

Expand Down Expand Up @@ -27,14 +27,14 @@ function run_dense_testing(Tw, Tx, M, N, hasbias, activation, aType, mode, ongpu
@test y ≈ y_generic
@test eltype(y) == promote_type(Tw, Tx)

@test @inferred(fused_dense_bias_activation(activation, w, x, bias)) isa Any
@constinferred fused_dense_bias_activation(activation, w, x, bias)
@jet fused_dense_bias_activation(activation, w, x, bias)

atol = 1.0f-3
rtol = 1.0f-3

if activation !== anonact
@test @inferred(Zygote.gradient(sumabs2dense, activation, w, x, bias)) isa Any
@constinferred Zygote.gradient(sumabs2dense, activation, w, x, bias)
end

skip_backends = []
Expand Down Expand Up @@ -117,23 +117,23 @@ end
end

@testitem "Fused Dense: StaticArrays" tags=[:dense] begin
using StaticArrays, NNlib
using StaticArrays, NNlib, TestExtras

x = @SArray rand(2, 4)
weight = @SArray rand(3, 2)
bias = @SArray rand(3)

@test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa SArray
@constinferred fused_dense_bias_activation(relu, weight, x, bias)
end

@testitem "Fused Dense: CPU No Scalar Indexing" tags=[:dense] begin
using JLArrays, NNlib
using JLArrays, NNlib, TestExtras

x = JLArray(rand(Float32, 2, 4))
weight = JLArray(rand(Float32, 3, 2))
bias = JLArray(rand(Float32, 3))

@test @inferred(fused_dense_bias_activation(relu, weight, x, bias)) isa JLArray
@constinferred fused_dense_bias_activation(relu, weight, x, bias)
@test LuxLib.internal_operation_mode(x) isa LuxLib.GenericBroadcastOp
end

Expand Down
25 changes: 11 additions & 14 deletions lib/LuxLib/test/common_ops/dropout_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

x = randn(rng, T, x_shape) |> aType

@test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any
@constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims)

y, mask_, rng_ = dropout(rng, x, T(0.5), Val(true), T(2), dims)

Expand All @@ -21,10 +21,10 @@
@test rng != rng_

@jet sum(first(dropout(rng, x, T(0.5), Val(true), T(2), dims)))
@test @inferred(dropout(rng, x, T(0.5), Val(true), T(2), dims)) isa Any
@constinferred dropout(rng, x, T(0.5), Val(true), T(2), dims)

__f = x -> sum(first(dropout(StableRNG(0), x, 0.5, Val(true), 2.0, dims)))
@test @inferred(Zygote.gradient(__f, x)) isa Any
@constinferred Zygote.gradient(__f, x)

@test_gradients(sumabs2first,
dropout, rng, x, T(0.5), Val(true), T(2), dims; atol=1.0f-3, rtol=1.0f-3)
Expand Down Expand Up @@ -54,8 +54,7 @@ end
mask = rand(T, x_shape) |> aType

# Update mask
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)) isa Any
@constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)
Expand All @@ -69,7 +68,7 @@ end

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(true), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any
@constinferred Zygote.gradient(__f, x, mask)

@test_gradients(sumabs2first,
dropout, rng, x, LuxTestUtils.Constant(mask), T(0.5), Val(true), Val(true),
Expand All @@ -79,8 +78,7 @@ end
rng, x, mask, T(0.5), Val(true), Val(true), T(2), :)))

# Try using mask if possible (possible!!)
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)) isa Any
@constinferred dropout(rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(true), Val(false), T(2), :)
Expand All @@ -94,7 +92,7 @@ end

__f = (x, mask) -> sum(first(dropout(
StableRNG(0), x, mask, 0.5, Val(true), Val(false), 2.0, :)))
@test @inferred(Zygote.gradient(__f, x, mask)) isa Any
@constinferred Zygote.gradient(__f, x, mask)

@test_gradients(sumabs2first,
dropout, rng, x, LuxTestUtils.Constant(mask),
Expand All @@ -107,8 +105,7 @@ end
mask = rand(T, (x_shape[1:(end - 1)]..., 13)) |> aType

# Testing Mode
@test @inferred(dropout(
rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)) isa Any
@constinferred dropout(rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)

y, mask_, rng_ = dropout(
rng, x, mask, T(0.5), Val(false), Val(false), T(2), :)
Expand All @@ -135,7 +132,7 @@ end

x = randn(rng, T, x_shape) |> aType

@test @inferred(alpha_dropout(rng, x, T(0.5), Val(true))) isa Any
@constinferred alpha_dropout(rng, x, T(0.5), Val(true))

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(true))

Expand All @@ -146,13 +143,13 @@ end
@test_broken std(y)≈std(x) atol=1.0f-2 rtol=1.0f-2

__f = x -> sum(first(alpha_dropout(StableRNG(0), x, 0.5, Val(true))))
@test @inferred(Zygote.gradient(__f, x)) isa Any
@constinferred Zygote.gradient(__f, x)

@test_gradients(sumabs2first,
alpha_dropout, rng, x, T(0.5), Val(true); atol=1.0f-3, rtol=1.0f-3)

@jet sum(first(alpha_dropout(rng, x, T(0.5), Val(true))))
@test @inferred(alpha_dropout(rng, x, T(0.5), Val(false))) isa Any
@constinferred alpha_dropout(rng, x, T(0.5), Val(false))

y, rng_ = alpha_dropout(rng, x, T(0.5), Val(false))

Expand Down
8 changes: 3 additions & 5 deletions lib/LuxLib/test/normalization/batchnorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module BatchNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, TestExtras

function setup_batchnorm(gen_f, aType, T, sz; affine::Bool=true, track_stats::Bool)
x = gen_f(T, sz) |> aType
Expand Down Expand Up @@ -69,8 +69,7 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
end
end

@test @inferred(batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)) isa
Any
@constinferred batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)
@jet batchnorm(x, scale, bias, rm, rv, training, act, T(0.9), epsilon)

@test y isa aType{T, length(sz)}
Expand All @@ -91,8 +90,7 @@ function run_batchnorm_testing(gen_f, T, sz, training, affine, track_stats, act,
if anonact !== act
lfn = (x, sc, b, rm, rv, tr, act, ϵ) -> sum(first(batchnorm(
x, sc, b, rm, rv, tr, act, ϵ)))
@test @inferred(Zygote.gradient(
lfn, x, scale, bias, rm, rv, training, act, epsilon)) isa Any
@constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, training, act, epsilon)
end
end

Expand Down
6 changes: 3 additions & 3 deletions lib/LuxLib/test/normalization/groupnorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module GroupNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, Static, StableRNGs, TestExtras
using LuxTestUtils: check_approx

function setup_groupnorm(rng, aType, T, sz, affine)
Expand Down Expand Up @@ -58,12 +58,12 @@ function run_groupnorm_testing(T, sz, groups, affine, act, aType, mode, ongpu)
@test ∂bias≈∂bias_simple atol=atol rtol=rtol
end

@test @inferred(groupnorm(x, scale, bias, groups, act, epsilon)) isa Any
@constinferred groupnorm(x, scale, bias, groups, act, epsilon)
@jet groupnorm(x, scale, bias, groups, act, epsilon)

if anonact !== act
lfn = (x, sc, b, g, act, ϵ) -> sum(groupnorm(x, sc, b, g, act, ϵ))
@test @inferred(Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)) isa Any
@constinferred Zygote.gradient(lfn, x, scale, bias, groups, act, epsilon)
end

@test y isa aType{T, length(sz)}
Expand Down
12 changes: 5 additions & 7 deletions lib/LuxLib/test/normalization/instancenorm_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testsetup module InstanceNormSetup
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib
using LuxLib, LuxTestUtils, Random, Test, Zygote, NNlib, TestExtras

is_training(::Val{training}) where {training} = training

Expand All @@ -24,12 +24,12 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)
atol = 1.0f-2
rtol = 1.0f-2

@test @inferred(instancenorm(x, scale, bias, training, act, epsilon)) isa Any
@constinferred instancenorm(x, scale, bias, training, act, epsilon)
@jet instancenorm(x, scale, bias, training, act, epsilon)

if anonact !== act && is_training(training)
lfn = (x, sc, b, act, ϵ) -> sum(first(instancenorm(x, sc, b, Val(true), act, ϵ)))
@test @inferred(Zygote.gradient(lfn, x, scale, bias, act, epsilon)) isa Any
@constinferred Zygote.gradient(lfn, x, scale, bias, act, epsilon)
end

@test y isa aType{T, length(sz)}
Expand All @@ -46,15 +46,13 @@ function run_instancenorm_testing(gen_f, T, sz, training, act, aType)

y, nt = instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)

@test @inferred(instancenorm(
x, scale, bias, rm, rv, training, act, T(0.1), epsilon)) isa Any
@constinferred instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)
@jet instancenorm(x, scale, bias, rm, rv, training, act, T(0.1), epsilon)

if anonact !== act && is_training(training)
lfn = (x, sc, b, rm, rv, act, m, ϵ) -> sum(first(instancenorm(
x, sc, b, rm, rv, Val(true), act, m, ϵ)))
@test @inferred(Zygote.gradient(
lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)) isa Any
@constinferred Zygote.gradient(lfn, x, scale, bias, rm, rv, act, T(0.1), epsilon)
end

@test y isa aType{T, length(sz)}
Expand Down
Loading
Loading