diff --git a/exla/lib/exla/defn.ex b/exla/lib/exla/defn.ex index 652a5fefc7..b95b91aa4a 100644 --- a/exla/lib/exla/defn.ex +++ b/exla/lib/exla/defn.ex @@ -32,13 +32,8 @@ defmodule EXLA.Defn do @doc false def __stream__(key, input, acc, vars, fun, [args], options) do - {debug?, options} = Keyword.pop(options, :debug, false) {run_options, compile_options} = Keyword.pop(options, :run_options, []) - - {client_name, compile_options} = - Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0) - - client = EXLA.Client.fetch!(client_name) + debug? = Keyword.get(compile_options, :debug, false) compile_options = Keyword.put(compile_options, :lazy_transfers, :never) input_length = length(Nx.Defn.Composite.flatten_list([input])) @@ -50,21 +45,10 @@ defmodule EXLA.Defn do used_inputs = Enum.to_list(input_length..(input_length + acc_length - 1)//1) comp_fun = - &to_stream_computation(client, input_length, acc_length, &1, &2, &3, &4, compile_options) + &to_stream_computation(input_length, acc_length, &1, &2, &3, &4, &5, compile_options) {executable, {used_inputs, {output, acc_output}, outfeed, input_typespecs}} = - compile( - client, - key, - vars, - fun, - compile_options, - used_buffers, - used_inputs, - _stream = true, - debug?, - comp_fun - ) + compile(key, vars, fun, compile_options, used_buffers, used_inputs, true, comp_fun) # Now discard the infeed from used inputs, similar to how it is done to buffers. # Note we discard all lazy transfers too, as they are not possible with streams. @@ -136,13 +120,13 @@ defmodule EXLA.Defn do end defp to_stream_computation( - client, input_length, acc_length, %Function{} = builder, expr, used_typespecs, outfeed, + client, options ) do %{token: root_token, infeeds: []} = outfeed @@ -237,18 +221,12 @@ defmodule EXLA.Defn do @doc false def __compile__(key, vars, fun, options) do - {debug?, options} = Keyword.pop(options, :debug, false) {run_options, compile_options} = Keyword.pop(options, :run_options, []) - - {client_name, compile_options} = - Keyword.pop_lazy(compile_options, :client, &EXLA.Client.default_name/0) - - client = EXLA.Client.fetch!(client_name) - - callback = &to_root_computation(&1, &2, &3, &4, Keyword.put(compile_options, :client, client)) + debug? = Keyword.get(compile_options, :debug, false) + callback = &to_root_computation(&1, &2, &3, &4, &5, compile_options) {executable, {used_inputs, outputs, outfeed, _input_typespecs?}} = - compile(client, key, vars, fun, compile_options, 0, [], _stream = false, debug?, callback) + compile(key, vars, fun, compile_options, 0, [], _stream = false, callback) fn [args] -> {time, lock} = @@ -270,14 +248,12 @@ defmodule EXLA.Defn do end end - defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, options) do + defp to_root_computation(%Function{} = function, expr, used_typespecs, outfeed, client, options) do params = Enum.zip_with(used_typespecs, Function.get_arguments(function), fn {pos, _typespec}, arg -> {pos, arg} end) - client = Keyword.fetch!(options, :client) - unless client do raise ArgumentError, "missing client" end @@ -342,22 +318,15 @@ defmodule EXLA.Defn do ## Compile - defp compile( - client, - key, - vars, - fun, - options, - used_buffers, - used_inputs, - stream?, - debug?, - to_computation - ) do + defp compile(key, vars, fun, options, used_buffers, used_inputs, stream?, to_computation) do {cache, options} = Keyword.pop(options, :cache, true) {hooks, options} = Keyword.pop(options, :hooks, %{}) + {debug?, options} = Keyword.pop(options, :debug, false) {lazy_transfers, options} = Keyword.pop(options, :lazy_transfers, :opt_in) + {client_name, options} = Keyword.pop_lazy(options, :client, &EXLA.Client.default_name/0) + client = EXLA.Client.fetch!(client_name) + {args_key, reverse_args_identifiers} = Enum.map_reduce(vars, [], fn var, acc -> Nx.Defn.Composite.traverse(var, acc, fn @@ -453,7 +422,7 @@ defmodule EXLA.Defn do end expr = Nx.Defn.Composite.traverse(expr || fun.(vars), &Nx.devectorize/1) - outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed) + outfeed = to_computation.(builder, expr, inputs_and_typespecs, outfeed, client) {xla_time, executable} = :timer.tc(fn ->