From ff58b31c00dd4de255209884e20601ba5cc0620a Mon Sep 17 00:00:00 2001 From: Andrei Zhabinski Date: Sat, 5 Aug 2023 22:54:06 +0300 Subject: [PATCH] Fix device of seed in the :auto mode --- Project.toml | 2 +- src/grad.jl | 7 ++----- src/helpers.jl | 2 +- test/test_grad.jl | 11 +++++++++++ 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index a153a46..01c3f07 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Yota" uuid = "cd998857-8626-517d-b929-70ad188a48f0" authors = ["Andrei Zhabinski "] -version = "0.8.4" +version = "0.8.5" [deps] ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" diff --git a/src/grad.jl b/src/grad.jl index 3426f2d..e881cac 100644 --- a/src/grad.jl +++ b/src/grad.jl @@ -132,8 +132,6 @@ with the following chain or calls: where `val = fn(args...)` and `pb` is the pullback function. """ function chainrules_transform!(tape::Tape) - # global TAPE = tape - # error("") i = 1 while i <= length(tape) # tape[V(i)] isa Call && tape[V(i)].fn == Core.kwcall && break @@ -183,7 +181,6 @@ function step_back!(tape::Tape, y::Variable) end for (i, x) in enumerate(y_fargs) if x isa V - global STATE = (tape, y, y_fargs, i, x) dx = push!(tape, mkcall(getfield, dxs, i; line="d$y/d$x")) # @debug "Updating derivative: $x -> $dx" set_or_add_deriv!(tape, x, dx) @@ -208,8 +205,8 @@ function back!(tape::Tape; seed=1) error("Gradient of a vector-valued function requires a seed") elseif seed == :auto zval = tape[z].val - # @assert zval isa Number || zval isa AbstractArray - seed = zval isa AbstractArray ? ones(eltype(zval), size(zval)) : one(zval) + @assert zval isa Number || zval isa AbstractArray + seed = zval isa AbstractArray ? array_like(1, zval, size(zval)) : one(zval) end dy = push!(tape, Constant(seed; line="seed for $(tape[V(1)].val)")) # save seed var to use in compilation later diff --git a/src/helpers.jl b/src/helpers.jl index 58033b1..d9a15b7 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -88,7 +88,7 @@ end unbroadcast_prod_y(x::ArrayOrBroadcasted, y::ArrayOrBroadcasted, Δ) = unbroadcast_prod_x(y, x, Δ) # device_like(example, a) = (device = guess_device([example]); device(a)) -array_like(value, example) = fill!(similar(example, (1,)), value) +array_like(value, example, sz=(1,)) = fill!(similar(example, sz), value) # unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(device_like(y, [x]), y, Δ)[1] unbroadcast_prod_x(x::Number, y::ArrayOrBroadcasted, Δ) = unbroadcast_prod_x(array_like(x, y), y, Δ)[1] diff --git a/test/test_grad.jl b/test/test_grad.jl index 7d8e29f..41e0c5c 100644 --- a/test/test_grad.jl +++ b/test/test_grad.jl @@ -254,6 +254,17 @@ end val, g = grad(x -> 2x, [1.0, 2.0, 3.0]; seed=ones(3)) @test val == [2.0, 4.0, 6.0] @test g == (ZeroTangent(), [2.0, 2.0, 2.0]) + + val, g = grad(x -> 2x, [1.0, 2.0, 3.0]; seed=:auto) + @test val == [2.0, 4.0, 6.0] + @test g == (ZeroTangent(), [2.0, 2.0, 2.0]) + + if CUDA.functional() + CUDA.allowscalar(false) + val, g = grad(x -> 2x, cu([1.0, 2.0, 3.0]); seed=:auto) + @test val == cu([2.0, 4.0, 6.0]) + @test g == (ZeroTangent(), cu([2.0, 2.0, 2.0])) + end end