Skip to content

Commit

Permalink
Return init params on compile (#400)
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim authored Nov 15, 2022
1 parent 8d32ba8 commit bcf675f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 13 deletions.
11 changes: 4 additions & 7 deletions examples/generative/mnist_gan.exs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,7 @@ defmodule MNISTGAN do
|> Nx.divide(Nx.add(i, 1))
end

defn init(init_d, init_g, init_optim_d, init_optim_g) do
d_params = init_d.(%{})
g_params = init_g.(%{})

defn init(d_params, g_params, init_optim_d, init_optim_g) do
%{
iteration: Nx.tensor(0),
discriminator: %{
Expand Down Expand Up @@ -135,11 +132,11 @@ defmodule MNISTGAN do
{init_optim_d, optim_d} = Axon.Optimizers.adam(2.0e-3, b1: 0.5)
{init_optim_g, optim_g} = Axon.Optimizers.adam(2.0e-3, b1: 0.5)

{d_init, d_model} = Axon.compile(d_model, mode: :train)
{g_init, g_model} = Axon.compile(g_model, mode: :train)
{d_init_params, d_model} = Axon.compile(d_model, mode: :train)
{g_init_params, g_model} = Axon.compile(g_model, mode: :train)

step = &batch_step(d_model, g_model, optim_d, optim_g, &1, &2)
init = fn %{} -> init(d_init, g_init, init_optim_d, init_optim_g) end
init = fn %{} -> init(d_init_params, g_init_params, init_optim_d, init_optim_g) end

Axon.Loop.loop(step, init)
end
Expand Down
9 changes: 3 additions & 6 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3231,12 +3231,9 @@ defmodule Axon do
@doc type: :model
def compile(model, template, init_params \\ %{}, opts \\ []) when is_list(opts) do
{init_fn, predict_fn} = build(model, opts)
init_compiled_fn = Nx.Defn.compile(init_fn, [template, init_params], opts)

predict_compiled_fn =
Nx.Defn.compile(predict_fn, [init_compiled_fn.(template, init_params), template], opts)

{init_compiled_fn, predict_compiled_fn}
init_params = Nx.Defn.jit_apply(init_fn, [template, init_params], opts)
predict_compiled_fn = Nx.Defn.compile(predict_fn, [init_params, template], opts)
{init_params, predict_compiled_fn}
end

@doc """
Expand Down

0 comments on commit bcf675f

Please sign in to comment.