Skip to content

Commit

Permalink
feat: prune tensor expr inspect results (#1272)
Browse files Browse the repository at this point in the history
Co-authored-by: José Valim <jose.valim@dashbit.co>
  • Loading branch information
polvalente and josevalim authored Jul 27, 2023
1 parent d670b29 commit 101be51
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 3 deletions.
29 changes: 26 additions & 3 deletions nx/lib/nx/defn/expr.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1566,7 +1566,14 @@ defmodule Nx.Defn.Expr do

def inspect(tensor, opts) do
{_, %{exprs: exprs, parameters: parameters, length: length}} =
recur_inspect(tensor, %{cache: %{}, exprs: [], parameters: [], opts: opts, length: 0})
recur_inspect(tensor, %{
cache: %{},
exprs: [],
parameters: [],
opts: opts,
length: 0,
limit: opts.limit
})

concat(line(), color("Nx.Defn.Expr", :map, opts))
|> append_lines(parameters, length + 3)
Expand All @@ -1591,12 +1598,25 @@ defmodule Nx.Defn.Expr do
end

defp recur_inspect(%T{data: %Expr{id: id, op: op, args: args}} = tensor, state) do
case state.cache do
%{limit: limit, cache: cache} = state

case cache do
%{^id => var_name} ->
{var_name, state}

%{} when limit == 0 ->
state = state |> decrement_limit(limit) |> store_line(:parameters, "...", "")
var_name = var_name(state)
{var_name, put_in(state.cache[id], var_name)}

%{} when limit < 0 ->
var_name = var_name(state)
{var_name, put_in(state.cache[id], var_name)}

%{} ->
{var_name, state} = cached_recur_inspect(op, args, to_type_shape(tensor), state)
{var_name, state} =
cached_recur_inspect(op, args, to_type_shape(tensor), decrement_limit(state, limit))

{var_name, put_in(state.cache[id], var_name)}
end
end
Expand Down Expand Up @@ -1678,6 +1698,9 @@ defmodule Nx.Defn.Expr do
Enum.map_reduce(args, state, &recur_inspect/2)
end

defp decrement_limit(state, :infinity), do: state
defp decrement_limit(state, limit), do: %{state | limit: limit - 1}

defp doc_inspect(term, opts) do
term
|> Inspect.Algebra.to_doc(opts)
Expand Down
92 changes: 92 additions & 0 deletions nx/test/nx/defn/expr_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -392,5 +392,97 @@ defmodule Nx.Defn.ExprTest do
>\
"""
end

defn add_sub_mult_no_tokens(a, b, c, d) do
a
|> Nx.add(b)
|> Nx.subtract(c)
|> Nx.multiply(d)
end

test "with limit option" do
t = Nx.template({}, :f32)

result = add_sub_mult_no_tokens(t, t, t, t)

# greater than the number of exprs
assert inspect(result, limit: 8) == """
#Nx.Tensor<
f32
\s\s
Nx.Defn.Expr
parameter a:0 f32
parameter b:1 f32
parameter d:2 f32
parameter f:3 f32
c = add a, b f32
e = subtract c, d f32
g = multiply e, f f32
>\
"""

# equal to the number of exprs
assert inspect(result, limit: 7) == """
#Nx.Tensor<
f32
\s\s
Nx.Defn.Expr
parameter a:0 f32
parameter b:1 f32
parameter d:2 f32
parameter f:3 f32
c = add a, b f32
e = subtract c, d f32
g = multiply e, f f32
>\
"""

# one less than the number of exprs
assert inspect(result, limit: 6) == """
#Nx.Tensor<
f32
\s\s
Nx.Defn.Expr
parameter a:0 f32
parameter b:1 f32
parameter d:2 f32
... \s
c = add a, b f32
e = subtract c, d f32
g = multiply e, f f32
>\
"""

assert inspect(result, limit: 3) == """
#Nx.Tensor<
f32
\s\s
Nx.Defn.Expr
... \s
c = add a, b f32
e = subtract c, d f32
g = multiply e, f f32
>\
"""

assert inspect(result, limit: 1) == """
#Nx.Tensor<
f32
\s\s
Nx.Defn.Expr
... \s
c = multiply a, b f32
>\
"""

assert inspect(result, limit: 0) == """
#Nx.Tensor<
f32
\s\s
Nx.Defn.Expr
... \s
>\
"""
end
end
end

0 comments on commit 101be51

Please sign in to comment.