Skip to content

Commit

Permalink
Introduce Nx.Serving.jit/2 as a convenient shortcut (#1252)
Browse files Browse the repository at this point in the history
  • Loading branch information
josevalim authored Jun 14, 2023
1 parent 3240ffb commit 13213c4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
20 changes: 17 additions & 3 deletions nx/lib/nx/serving.ex
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ defmodule Nx.Serving do
When defining a `Nx.Serving`, we can also customize how the data is
batched by using the `client_preprocessing` as well as the result by
using `client_postprocessing` hooks. Let's give it another try:
using `client_postprocessing` hooks. Let's give it another try,
this time using `jit/2` to create the serving, which automatically
wraps the given function in `Nx.Defn.jit/2` for us:
iex> serving = (
...> Nx.Serving.new(fn opts -> Nx.Defn.jit(&MyDefn.print_and_multiply/1, opts) end)
...> Nx.Serving.jit(&MyDefn.print_and_multiply/1)
...> |> Nx.Serving.client_preprocessing(fn input -> {Nx.Batch.stack(input), :client_info} end)
...> |> Nx.Serving.client_postprocessing(&{&1, &2, &3})
...> )
Expand Down Expand Up @@ -111,7 +113,7 @@ defmodule Nx.Serving do
children = [
{Nx.Serving,
serving: Nx.Serving.new(fn opts -> Nx.Defn.jit(&MyDefn.print_and_multiply/1, opts) end),
serving: Nx.Serving.jit(&MyDefn.print_and_multiply/1),
name: MyServing,
batch_size: 10,
batch_timeout: 100}
Expand Down Expand Up @@ -370,6 +372,18 @@ defmodule Nx.Serving do
new(module, arg, [])
end

@doc """
Creates a new serving by jitting the given `fun` with `defn_options`.
This is equivalent to:
new(fn opts -> Nx.Defn.jit(fun, opts) end, defn_options)
"""
def jit(fun, defn_options \\ []) do
new(fn opts -> Nx.Defn.jit(fun, opts) end, defn_options)
end

@doc """
Creates a new module-based serving.
Expand Down
8 changes: 2 additions & 6 deletions nx/test/nx/serving_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,8 @@ defmodule Nx.ServingTest do
assert Nx.Serving.run(serving, batch) == Nx.tensor([[2, 4, 6]])
end

test "with container" do
serving =
Nx.Serving.new(fn opts ->
Nx.Defn.jit(fn {a, b} -> {Nx.multiply(a, 2), Nx.divide(b, 2)} end, opts)
end)

test "with container (and jit)" do
serving = Nx.Serving.jit(fn {a, b} -> {Nx.multiply(a, 2), Nx.divide(b, 2)} end)
batch = Nx.Batch.concatenate([{Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6])}])
assert Nx.Serving.run(serving, batch) == {Nx.tensor([2, 4, 6]), Nx.tensor([2, 2.5, 3])}
end
Expand Down

0 comments on commit 13213c4

Please sign in to comment.