Skip to content

Commit

Permalink
remove Dot subclass
Browse files Browse the repository at this point in the history
  • Loading branch information
purna135 authored and brandonwillard committed Oct 6, 2022
1 parent 73bffe8 commit 6cda5c3
Showing 1 changed file with 7 additions and 14 deletions.
21 changes: 7 additions & 14 deletions tests/tensor/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@
from tests.unittest_tools import check_infer_shape, verify_grad


class DotBW(Dot):
gufunc_sig = ((("m", "n"), ("n", "p")), (("m", "p"),))


dot_bw = DotBW()


def test_update_dim_sizes():
with pytest.raises(ValueError, match=".*dimensional argument.*"):
_update_dim_sizes({}, at.tensor("float64", ()), ("m",))
Expand Down Expand Up @@ -79,7 +72,7 @@ def test_parse_input_dimensions(args, arg_vals, input_core_dims, output_core_dim
"op, args, arg_vals, np_fn",
[
(
dot_bw,
Dot(),
(
at.tensor("float64", (None, None, None)),
at.tensor("float64", (None, None, None)),
Expand All @@ -88,7 +81,7 @@ def test_parse_input_dimensions(args, arg_vals, input_core_dims, output_core_dim
lambda x, y: np.dot(x, y),
),
(
dot_bw,
Dot(),
(
at.tensor("float64", (None, None, None)),
at.tensor("float64", (None, None)),
Expand Down Expand Up @@ -116,10 +109,10 @@ def test_Blockwise_perform(op, args, arg_vals, np_fn):
@pytest.mark.parametrize(
"op, s_left, s_right",
[
(dot_bw, (3, 5, 6), (3, 6, 7)),
(dot_bw, (3, 1, 2), (3, 2, 1)),
(Dot(), (3, 5, 6), (3, 6, 7)),
(Dot(), (3, 1, 2), (3, 2, 1)),
(
dot_bw,
Dot(),
(5, 4, 3),
(
3,
Expand All @@ -146,7 +139,7 @@ def test_Blockwise_infer_shape(op, s_left, s_right):
"op, args, arg_vals, np_fn",
[
(
dot_bw,
Dot(),
(
at.tensor("float64", (None, None, None)),
at.tensor("float64", (None, None, None)),
Expand All @@ -155,7 +148,7 @@ def test_Blockwise_infer_shape(op, s_left, s_right):
lambda x, y: np.dot(x, y),
),
(
dot_bw,
Dot(),
(
at.tensor("float64", (None, None, None)),
at.tensor("float64", (None, None)),
Expand Down

0 comments on commit 6cda5c3

Please sign in to comment.