diff --git a/nx/lib/nx/serving.ex b/nx/lib/nx/serving.ex
index 3dd5272f1b..20ecd777f6 100644
--- a/nx/lib/nx/serving.ex
+++ b/nx/lib/nx/serving.ex
@@ -56,10 +56,12 @@ defmodule Nx.Serving do
 
   When defining a `Nx.Serving`, we can also customize how the data is
   batched by using the `client_preprocessing` as well as the result by
-  using `client_postprocessing` hooks. Let's give it another try:
+  using `client_postprocessing` hooks. Let's give it another try,
+  this time using `jit/2` to create the serving, which automatically
+  wraps the given function in `Nx.Defn.jit/2` for us:
 
       iex> serving = (
-      ...>   Nx.Serving.new(fn opts -> Nx.Defn.jit(&MyDefn.print_and_multiply/1, opts) end)
+      ...>   Nx.Serving.jit(&MyDefn.print_and_multiply/1)
       ...>   |> Nx.Serving.client_preprocessing(fn input -> {Nx.Batch.stack(input), :client_info} end)
       ...>   |> Nx.Serving.client_postprocessing(&{&1, &2, &3})
       ...> )
@@ -111,7 +113,7 @@ defmodule Nx.Serving do
 
       children = [
         {Nx.Serving,
-         serving: Nx.Serving.new(fn opts -> Nx.Defn.jit(&MyDefn.print_and_multiply/1, opts) end),
+         serving: Nx.Serving.jit(&MyDefn.print_and_multiply/1),
          name: MyServing,
          batch_size: 10,
          batch_timeout: 100}
@@ -370,6 +372,18 @@ defmodule Nx.Serving do
     new(module, arg, [])
   end
 
+  @doc """
+  Creates a new serving by jitting the given `fun` with `defn_options`.
+
+  This is equivalent to:
+
+      new(fn opts -> Nx.Defn.jit(fun, opts) end, defn_options)
+
+  """
+  def jit(fun, defn_options \\ []) do
+    new(fn opts -> Nx.Defn.jit(fun, opts) end, defn_options)
+  end
+
   @doc """
   Creates a new module-based serving.
 
diff --git a/nx/test/nx/serving_test.exs b/nx/test/nx/serving_test.exs
index 505816a06c..252db5768a 100644
--- a/nx/test/nx/serving_test.exs
+++ b/nx/test/nx/serving_test.exs
@@ -53,12 +53,8 @@ defmodule Nx.ServingTest do
       assert Nx.Serving.run(serving, batch) == Nx.tensor([[2, 4, 6]])
     end
 
-    test "with container" do
-      serving =
-        Nx.Serving.new(fn opts ->
-          Nx.Defn.jit(fn {a, b} -> {Nx.multiply(a, 2), Nx.divide(b, 2)} end, opts)
-        end)
-
+    test "with container (and jit)" do
+      serving = Nx.Serving.jit(fn {a, b} -> {Nx.multiply(a, 2), Nx.divide(b, 2)} end)
       batch = Nx.Batch.concatenate([{Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6])}])
       assert Nx.Serving.run(serving, batch) == {Nx.tensor([2, 4, 6]), Nx.tensor([2, 2.5, 3])}
     end
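
For context, a minimal usage sketch (not part of the patch) of the `Nx.Serving.jit/2` helper this diff introduces, assuming a simplified stand-in for the `MyDefn.print_and_multiply/1` defn referenced in the moduledoc example:

    # Sketch only, not part of the patch. MyDefn stands in for the
    # print_and_multiply/1 defn from the Nx.Serving moduledoc (the real
    # version also prints the tensor before multiplying by 2).
    defmodule MyDefn do
      import Nx.Defn

      defn print_and_multiply(x), do: x * 2
    end

    # jit/2 replaces the manual new(fn opts -> Nx.Defn.jit(fun, opts) end) wrapper.
    serving =
      Nx.Serving.jit(&MyDefn.print_and_multiply/1)
      |> Nx.Serving.client_preprocessing(fn input -> {Nx.Batch.stack(input), :client_info} end)
      |> Nx.Serving.client_postprocessing(&{&1, &2, &3})

    # With the hooks above, run/2 returns {batched_output, metadata, :client_info}.
    Nx.Serving.run(serving, [Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6])])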