Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issues with mismatched templates #450

Merged
merged 1 commit into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 20 additions & 24 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ defmodule Axon.Loop do
loss: Nx.tensor(0.0),
gradient_step: Nx.tensor(0),
model_state: model_state,
gradient_state: zeros_like(model_state),
gradient_state: zeros_like(model_state, type: :f32),
optimizer_state: optimizer_state,
loss_scale_state: loss_scale_state
}
Expand Down Expand Up @@ -458,30 +458,26 @@ defmodule Axon.Loop do
opts = keyword!(opts, [:steps])
steps = opts[:steps]

# TODO: temporarily disabled
# while {gradients, model_state, new_state, optimizer_state, gradient_state, gradient_step,
# flag = Nx.tensor(1)},
# flag do
# if Nx.greater_equal(gradient_step, steps - 1) do
{_, new_model_state, _, new_optimizer_state, new_gradient_state, new_gradient_step, _} =
(
{updates, new_optimizer_state} =
update_optimizer_fn.(gradients, optimizer_state, model_state)

new_gradient_state = zeros_like(model_state)
new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)

{gradients, new_model_state, new_state, new_optimizer_state, new_gradient_state, 0,
Nx.tensor(0)}
)

# else
# acc_gradients = deep_merge(gradient_state, gradients, fn x, y -> x + y end)

# {gradients, model_state, new_state, optimizer_state, acc_gradients, gradient_step + 1,
# Nx.tensor(0)}
# end
# end
while {gradients, model_state, new_state, optimizer_state, gradient_state, gradient_step,
flag = Nx.tensor(1)},
flag do
if Nx.greater_equal(gradient_step, steps - 1) do
{updates, new_optimizer_state} =
update_optimizer_fn.(gradients, optimizer_state, model_state)

new_gradient_state = zeros_like(model_state)
new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)

{gradients, new_model_state, new_state, new_optimizer_state, new_gradient_state, 0,
Nx.tensor(0)}
else
acc_gradients = deep_merge(gradient_state, gradients, fn x, y -> x + y end)

{gradients, model_state, new_state, optimizer_state, acc_gradients, gradient_step + 1,
Nx.tensor(0)}
end
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, both branches return Nx.tensor(0). Doesn't it mean you ever only do a single pass anyway?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm pretty sure this is how I got around larger graphs with if, iirc this reduces memory usage slightly

end

{new_model_state, new_optimizer_state, new_gradient_state, new_gradient_step}
end
Expand Down
17 changes: 9 additions & 8 deletions lib/axon/shared.ex
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,15 @@ defmodule Axon.Shared do
Creates a zeros-like structure which matches the structure
of the input.
"""
defn zeros_like(params) do
transform(
params,
&deep_new(&1, fn x ->
fun = Axon.Initializers.zeros()
fun.(Nx.shape(x), Nx.type(x))
end)
)
deftransform zeros_like(params, opts \\ []) do
opts = Keyword.validate!(opts, [:type])
fun = Axon.Initializers.zeros()

deep_new(params, fn x ->
type = opts[:type] || Nx.type(x)
fun = Axon.Initializers.zeros()
fun.(Nx.shape(x), type)
end)
end

@doc """
Expand Down
18 changes: 9 additions & 9 deletions lib/axon/updates.ex
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ defmodule Axon.Updates do
end

defnp init_my_update(params) do
state = zeros_like(params)
state = zeros_like(params, type: :f32)
%{state: state}
end

Expand Down Expand Up @@ -165,8 +165,8 @@ defmodule Axon.Updates do
end

defnp init_scale_by_adam(params) do
mus = zeros_like(params)
nus = zeros_like(params)
mus = zeros_like(params, type: :f32)
nus = zeros_like(params, type: :f32)
count = Nx.tensor(0)
%{mu: mus, nu: nus, count: count}
end
Expand Down Expand Up @@ -333,8 +333,8 @@ defmodule Axon.Updates do
end

defnp init_scale_by_belief(params) do
mus = zeros_like(params)
nus = zeros_like(params)
mus = zeros_like(params, type: :f32)
nus = zeros_like(params, type: :f32)
count = Nx.tensor(0)
%{mu: mus, nu: nus, count: count}
end
Expand Down Expand Up @@ -394,7 +394,7 @@ defmodule Axon.Updates do
end

defnp init_scale_by_stddev(params, value) do
mu = zeros_like(params)
mu = zeros_like(params, type: :f32)
nu = fulls_like(params, value)
%{mu: mu, nu: nu}
end
Expand Down Expand Up @@ -486,8 +486,8 @@ defmodule Axon.Updates do
end

defnp init_scale_by_radam(params) do
mu = zeros_like(params)
nu = zeros_like(params)
mu = zeros_like(params, type: :f32)
nu = zeros_like(params, type: :f32)
count = Nx.tensor(0)
%{mu: mu, nu: nu, count: count}
end
Expand Down Expand Up @@ -564,7 +564,7 @@ defmodule Axon.Updates do
end

defnp init_trace(params) do
trace = zeros_like(params)
trace = zeros_like(params, type: :f32)
%{trace: trace}
end

Expand Down