Spring cleaning #80

Merged: 6 commits, Mar 20, 2021
4 changes: 2 additions & 2 deletions Project.toml
@@ -15,9 +15,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
CUDA = "1.2, 2.3"
CUDA = "1.2, 2"
ChainRulesCore = "0.9.5"
Distributions = "0.23.2"
Espresso = "0.6.0"
IRTools = "0.4.0"
IRTools = "0.4"
julia = "1.4"
65 changes: 0 additions & 65 deletions src/alloc.jl

This file was deleted.

4 changes: 2 additions & 2 deletions src/core.jl
@@ -19,19 +19,19 @@ include("compile.jl")
include("update.jl")
include("transform.jl")
include("cuda.jl")
include("gradcheck.jl")


const BEST_AVAILABLE_DEVICE = Ref{AbstractDevice}(CPU())

if CUDA.functional()
try
BEST_AVAILABLE_DEVICE[] = GPU(1)
BEST_AVAILABLE_DEVICE[] = GPU(1)
catch ex
# something is wrong with the user's set-up (or there's a bug in CuArrays)
@warn "CUDA is installed, but not working properly" exception=(ex,catch_backtrace())

end
end


best_available_device() = BEST_AVAILABLE_DEVICE[]
35 changes: 1 addition & 34 deletions src/cuda.jl
@@ -1,8 +1,3 @@
# import CUDAnative
# using CuArrays


# CuArrays.cufunc(::typeof(^)) = CUDAnative.pow

@diffrule CUDA.exp(u::Real) u CUDA.exp(u) * dy
@diffrule CUDA.pow(u::Real, v::Real) u (v * CUDA.pow(u, (v-1)) * dy)
@@ -44,32 +39,4 @@ function to_device(device::GPU, x)
fld_vals = [to_device(device, getfield(x, fld)) for fld in flds]
return T(fld_vals...)
end
end


# function cuarray_compatible_tform(tape::Tape)
# new_tape = similar(tape)
# changed = false
# for op in tape
# if op isa Call && haskey(CUDANATIVE_OPS, op.fn)
# changed = true
# push!(new_tape, copy_with(op, fn=CUDANATIVE_OPS[op.fn]))
# else
# push!(new_tape, op)
# end
# end
# return new_tape, changed
# end


# """
# Transform function to CuArrays compatible.
# """
# function cuda_compatible(f, args)
# cf = CuArrays.cufunc(f)
# if f === cf
# return cf
# else
# return transform(cuarray_compatible_tform, f, args)
# end
# end
end
4 changes: 4 additions & 0 deletions src/devices.jl
@@ -43,5 +43,9 @@ to_device(device::CPU, f::Function, args) = f
(device::CPU)(x) = to_device(device, x)
(device::GPU)(x) = to_device(device, x)

to_cpu(A) = A
to_cpu(A::CuArray) = convert(Array, A)
to_cuda(A) = cu(A)
to_cuda(A::CuArray) = A

to_same_device(A, example) = device_of(example)(A)
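A minimal usage sketch of the helpers added above, assuming a working CUDA.jl setup and that these (unexported) functions are in scope; device_of is defined elsewhere in devices.jl:

    using CUDA

    A = rand(Float32, 4, 4)
    B = to_cuda(A)             # copies to the GPU via cu; a no-op for CuArrays
    C = to_cpu(B)              # converts back to a plain Array
    D = to_same_device(A, B)   # moves A to whatever device B lives on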
23 changes: 17 additions & 6 deletions src/diffrules/basic.jl
@@ -131,12 +131,6 @@
@diffrule __getfield__(_s::Tuple, _f) _s ∇getfield(dy, _s, _f)
@nodiff __getfield__(_s, _f::Tuple) _f

# @nodiff __new__(t, u) t
# @nodiff __new__(t, u, v) t
# @nodiff __new__(t, u, v, w) t
# @nodiff __new__(t, u, v) t
# @diffrule __new__(t, u, v) u ∇__new__(dy, t, 1)
# @diffrule __new__(t, u, v) v ∇__new__(dy, t, 2)

const LONG_VAR_NAME_LIST = (:x, :u, :v, :w, :_a, :_b, :_c, :_d, :_e, :_f, :_g)
for n=1:length(LONG_VAR_NAME_LIST)
@@ -149,6 +143,23 @@ for n=1:length(LONG_VAR_NAME_LIST)
end
end

# Base.iterate

# here we explicitly stop propagation in iteration
# over ranges (e.g. for i=1:3 ... end)
@nodiff Base.iterate(x::UnitRange) x
@nodiff Base.iterate(x::UnitRange, i::Int) x
@nodiff Base.iterate(x::UnitRange, i::Int) i

@diffrule Base.iterate(t::Tuple) t ∇getfield(getindex(dy, 1), t, 1)
@diffrule Base.iterate(t::Tuple, i::Int) t ∇getfield(getindex(dy, 1), t, i)
@nodiff Base.iterate(t::Tuple, i::Int) i

@diffrule Base.iterate(x::AbstractArray) x ungetindex(x, dy, 1)
@diffrule Base.iterate(x::AbstractArray, i::Int) x ungetindex(x, dy, i)
@nodiff Base.iterate(x::AbstractArray, i::Int) i
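An editorial sketch of what these rules mean in practice (assuming Yota's exported grad; the exact shape of the returned gradient container differs between Yota versions): iterating a UnitRange inside a traced function contributes nothing to the gradient, while values read from an array still carry gradients through ungetindex.

    function loss(x)
        s = zero(eltype(x))
        for i = 1:3          # Base.iterate(::UnitRange, ...) is @nodiff: the loop
            s += 2 * x[i]    # counter is a constant, only x[i] carries a gradient
        end
        return s
    end

    val, g = grad(loss, [1.0, 2.0, 3.0])   # gradient w.r.t. x is [2.0, 2.0, 2.0]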


# Base.indexed_iterate (tuple unpacking)

@diffrule Base.indexed_iterate(t::Tuple, i::Int) t ∇getfield(getindex(dy, 1), t, i)
14 changes: 7 additions & 7 deletions src/diffrules/diffrules.jl
@@ -11,7 +11,7 @@ const DIFF_PHS = Set([:x, :u, :v, :w, :t, :i, :j, :k,])
function reset_rules!()
empty!(DIFF_RULES)
empty!(NO_DIFF_RULES)
empty!(CONSTRUCTORS)
empty!(CONSTRUCTORS)
end


@@ -142,7 +142,7 @@ end
Define a type constructor that should not be traced, but instead recorded
to the tape as is. Here's an example:

@ctor MvNormal(μ, Σ)
@ctor MvNormal(μ, Σ)

This should be read as:

@@ -158,7 +158,7 @@ Yota will _completely bypass_ internals of the constructor and jump directly to
1st variable passed to MvNormal().

Note that if you don't want to bypass the constructor (which you normally shouldn't do),
you can rely on Yota handling it automatically.
you can rely on Yota handling it automatically.

"""
macro ctor(ex)
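A hypothetical usage sketch of the behaviour described in the docstring, assuming Distributions.jl is installed:

    using Distributions

    @ctor MvNormal(μ, Σ)    # the example from the docstring above

    d = MvNormal(zeros(2), [1.0 0.0; 0.0 1.0])
    # When a traced function builds d like this, the constructor is recorded on the
    # tape as a single op; a derivative w.r.t. d.μ is routed directly to the μ
    # argument instead of being propagated through MvNormal's internals.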
@@ -256,7 +256,7 @@ function deriv_exprs(ex, dep_types, idx::Int)
result = Tuple[] # list of tuples (rewritten_expr, field_name | nothing)
for rule in DIFF_RULES
m, fldname = match_rule(rule, ex, dep_types, idx)
if m != nothing
if m !== nothing
pat, rpat = m
rex = rewrite_with_keywords(ex, pat, rpat)
push!(result, (rex, fldname))
@@ -270,9 +270,9 @@ function deriv_exprs(ex, dep_types, idx::Int)
# record it to the tape as a constant; instead we must rewrite the forward pass as well,
# replacing all f(args...) with _, pb = rrule(f, args...),
# which breaks a significant part of the codebase
# @assert Meta.isexpr(ex, :call)
# @assert Meta.isexpr(ex, :call)
# _, df = rrule(ex.args[1], 2.0)
# dex = Expr(:call, df, )
# dex = Expr(:call, df, )
error("Can't find differentiation rule for $ex at $idx " *
"with types $dep_types)")
end
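For reference, a minimal sketch of the rrule-based flow mentioned in the comment above, assuming ChainRules.jl is installed (the rule definitions live there; Yota itself only depends on ChainRulesCore, and the exact tangent types depend on the package versions):

    using ChainRules

    y, pb = rrule(sin, 1.0)   # forward pass: y == sin(1.0), pb is the pullback
    tangents = pb(1.0)        # reverse pass: tangents w.r.t. (sin, 1.0); the last entry is ≈ cos(1.0)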
@@ -331,7 +331,7 @@ function dont_diff(tape::Tape, op::AbstractOp, idx::Int)
dep_types = [eltype(tape[arg].typ) for arg in op.args[2:end]]
idx_ = idx - 1
else
ex = to_expr(tape, op)
ex = to_expr(tape, op)
dep_types = [tape[arg].typ for arg in op.args]
idx_ = idx
end
29 changes: 2 additions & 27 deletions src/grad.jl
@@ -2,31 +2,6 @@
# GRAD RESULT #
########################################################################

# function field_paths(tape::Tape)
# paths = Dict()
# for op in reverse(tape.ops)
# _op = op
# path = []
# while _op isa Call && _op.fn in (Base.getproperty,
# Base.getfield,
# __getfield__)
# field_name = tape[_op.args[2]].val
# push!(path, field_name)
# _op_id = _op.args[1]
# _op = tape[_op_id]
# end
# if !isempty(path)
# struct_id = _op.id
# if !haskey(paths, struct_id)
# paths[struct_id] = Dict()
# end
# paths[struct_id][(reverse(path)...,)] = op.id
# end
# end
# return paths
# end


struct GradResult
tape::Tape
gvars::Vector{Any} # gradient vars
@@ -89,14 +64,14 @@ function deriv!(tape::Tape, op::AbstractOp, i::Int, dy::AbstractOp)
st = Dict(Symbol("%$i") => i for i in op.args)
st[:dy] = dy.id
st[:y] = op.id
if dex_fldnames[1][2] == nothing
if dex_fldnames[1][2] === nothing
# not a derivative of a field - take only the 1st match
dex_fldnames = dex_fldnames[1:1]
end
op_deriv_attrs = Tuple[]
for (dex, fldname) in dex_fldnames
ret_id = record_expr!(tape, dex; st=st)
derivative_of = (fldname == nothing ? tape[op.args[i]] :
derivative_of = (fldname === nothing ? tape[op.args[i]] :
field_var_from_ctor_op(tape, tape[op.args[i]], fldname))
push!(op_deriv_attrs, (tape[ret_id], derivative_of))
end
10 changes: 1 addition & 9 deletions test/gradcheck.jl → src/gradcheck.jl
@@ -16,14 +8,6 @@ function ngradient(f, xs::AbstractArray...)
end


# function gradcheck(f, xs...)
# n_grads = ngradient(f, xs...)
# y_grads = Yota._grad(f, xs...)[2] |> collect
# all(isapprox.(n_grads, y_grads, rtol = 1e-5, atol = 1e-5))
# end



function ngradient2(f, xs, n)
x = xs[n]
Δ = zero(x)
@@ -42,7 +34,7 @@


function gradcheck2(f, xs...; var_idxs=1:length(xs))
y_grads = Yota._grad(f, xs...)[2] |> collect
y_grads = _grad(f, xs...)[2] |> collect
results = []
for n in var_idxs
n_grad = ngradient2(f, xs, n)
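Most of the moved file is collapsed in this diff; as a rough idea of what ngradient and ngradient2 compute, here is a simplified central-difference sketch (an illustrative stand-in, not the exact implementation):

    function finite_diff_gradient(f, x::AbstractVector; δ=sqrt(eps(eltype(x))))
        g = zero(x)
        for i in eachindex(x)
            tmp = x[i]
            x[i] = tmp + δ / 2
            y1 = f(x)
            x[i] = tmp - δ / 2
            y2 = f(x)
            x[i] = tmp                 # restore the original value
            g[i] = (y1 - y2) / δ       # central-difference estimate of ∂f/∂xᵢ
        end
        return g
    end

    finite_diff_gradient(x -> sum(abs2, x), [1.0, 2.0, 3.0])   # ≈ [2.0, 4.0, 6.0]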
26 changes: 25 additions & 1 deletion src/helpers.jl
@@ -17,8 +17,32 @@ function ∇__new__(dy, T, idx)
end


# TODO (when https://github.com/FluxML/NNlib.jl/pull/296 is done):
# 1. uncomment this implementation
# 2. remove the next 2 functions
#
# function ungetindex!(dx::AbstractArray, ::AbstractArray, dy, I...)
# if dy isa Number
# dy = to_same_device([dy], dx)
# end
# I = Any[i for i in I]
# for (d, i) in enumerate(I)
# if i == (:)
# I[d] = 1:size(dx, d)
# end
# if i isa Number
# I[d] = [i]
# end
# I[d] = to_cpu(I[d])
# end
# # cartesian product of all concerned indices
# idx = collect(Iterators.product(I...))
# idx = to_same_device(idx, dx)
# return scatter!(+, dx, dy, idx)
# end

function ungetindex!(dx::AbstractArray, ::AbstractArray, dy, I...)
return Scatter.scatter_add2!(dx, dy, I...)
return Scatter.scatter_add2!(dx, dy, I...)
end
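To make the intent of ungetindex! concrete, here is a simplified CPU-only stand-in for the adjoint of getindex (illustrative only, not Yota's implementation; unlike a real scatter-add it does not accumulate over repeated indices):

    function ungetindex_sketch(x::AbstractArray, dy, I...)
        dx = zero(x)
        dx[I...] = dx[I...] .+ dy    # put the output gradient at the indexed positions
        return dx
    end

    ungetindex_sketch([1.0, 2.0, 3.0], 10.0, 2)   # -> [0.0, 10.0, 0.0]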

