Skip to content

Commit

Permalink
Merge pull request #141 from dfdx/fix/auto-seed-on-cuda
Browse files Browse the repository at this point in the history
Fix the device of the seed in `:auto` mode
  • Loading branch information
dfdx authored Aug 5, 2023
2 parents ea4f470 + ff58b31 commit 54dc712
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Yota"
uuid = "cd998857-8626-517d-b929-70ad188a48f0"
authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
version = "0.8.4"
version = "0.8.5"

[deps]
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
Expand Down
7 changes: 2 additions & 5 deletions src/grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,6 @@ with the following chain or calls:
where `val = fn(args...)` and `pb` is the pullback function.
"""
function chainrules_transform!(tape::Tape)
# global TAPE = tape
# error("")
i = 1
while i <= length(tape)
# tape[V(i)] isa Call && tape[V(i)].fn == Core.kwcall && break
Expand Down Expand Up @@ -183,7 +181,6 @@ function step_back!(tape::Tape, y::Variable)
end
for (i, x) in enumerate(y_fargs)
if x isa V
global STATE = (tape, y, y_fargs, i, x)
dx = push!(tape, mkcall(getfield, dxs, i; line="d$y/d$x"))
# @debug "Updating derivative: $x -> $dx"
set_or_add_deriv!(tape, x, dx)
Expand All @@ -208,8 +205,8 @@ function back!(tape::Tape; seed=1)
error("Gradient of a vector-valued function requires a seed")
elseif seed == :auto
zval = tape[z].val
# @assert zval isa Number || zval isa AbstractArray
seed = zval isa AbstractArray ? ones(eltype(zval), size(zval)) : one(zval)
@assert zval isa Number || zval isa AbstractArray
seed = zval isa AbstractArray ? array_like(1, zval, size(zval)) : one(zval)
end
dy = push!(tape, Constant(seed; line="seed for $(tape[V(1)].val)"))
# save seed var to use in compilation later
Expand Down
2 changes: 1 addition & 1 deletion src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ end
# Adjoint of a broadcasted product w.r.t. the second operand `y`.
# Since `*` is commutative here, swap the operands and reuse the
# `x`-adjoint implementation — presumably `unbroadcast_prod_x` also
# sums the gradient over broadcast dimensions; verify in its definition.
unbroadcast_prod_y(x::ArrayOrBroadcasted, y::ArrayOrBroadcasted, Δ) = unbroadcast_prod_x(y, x, Δ)

# device_like(example, a) = (device = guess_device([example]); device(a))
array_like(value, example) = fill!(similar(example, (1,)), value)
"""
    array_like(value, example, sz=(1,))

Allocate an array of size `sz` with the same storage type as `example`
(via `similar`, so e.g. a GPU `example` yields a GPU result) and set
every element to `value`.
"""
function array_like(value, example, sz=(1,))
    buf = similar(example, sz)
    return fill!(buf, value)
end

# unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(device_like(y, [x]), y, Δ)[1]
# Scalar-`x` case: lift `x` into a one-element array allocated like `y`
# via `array_like` (so it lands on the same device as `y` — see the
# commented-out `device_like` variant above), delegate to the array
# method, then extract the single gradient element with `[1]`.
unbroadcast_prod_x(x::Number, y::ArrayOrBroadcasted, Δ) = unbroadcast_prod_x(array_like(x, y), y, Δ)[1]
Expand Down
11 changes: 11 additions & 0 deletions test/test_grad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,17 @@ end
val, g = grad(x -> 2x, [1.0, 2.0, 3.0]; seed=ones(3))
@test val == [2.0, 4.0, 6.0]
@test g == (ZeroTangent(), [2.0, 2.0, 2.0])

val, g = grad(x -> 2x, [1.0, 2.0, 3.0]; seed=:auto)
@test val == [2.0, 4.0, 6.0]
@test g == (ZeroTangent(), [2.0, 2.0, 2.0])

if CUDA.functional()
CUDA.allowscalar(false)
val, g = grad(x -> 2x, cu([1.0, 2.0, 3.0]); seed=:auto)
@test val == cu([2.0, 4.0, 6.0])
@test g == (ZeroTangent(), cu([2.0, 2.0, 2.0]))
end
end


Expand Down

0 comments on commit 54dc712

Please sign in to comment.