Skip to content

Commit

Permalink
Added more tests and some minor documentation changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rishabh Jain committed Sep 24, 2020
1 parent b44a394 commit 8d1c8bb
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 4 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,14 +1190,14 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
diagonal : relay.Expr
Values to be filled in the diagonal.
k : int or tuple of int
k : int or tuple of int, optional
Diagonal Offset(s). The diagonal or range of diagonals to set. (0 by default)
Positive value means superdiagonal, 0 refers to the main diagonal, and
negative value means subdiagonals. k can be a single integer (for a single diagonal)
or a pair of integers specifying the low and high ends of a matrix band.
k[0] must not be larger than k[1].
align : string
align : string, optional
Some diagonals are shorter than max_diag_len and need to be padded.
align is a string specifying how superdiagonals and subdiagonals should be aligned,
respectively. There are four possible alignments: "RIGHT_LEFT" (default), "LEFT_RIGHT",
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,14 +818,14 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
diagonal : relay.Expr
Values to be filled in the diagonal.
k : int or tuple of int
k : int or tuple of int, optional
Diagonal Offset(s). The diagonal or range of diagonals to set. (0 by default)
Positive value means superdiagonal, 0 refers to the main diagonal, and
negative value means subdiagonals. k can be a single integer (for a single diagonal)
or a pair of integers specifying the low and high ends of a matrix band.
k[0] must not be larger than k[1].
align : string
align : string, optional
Some diagonals are shorter than max_diag_len and need to be padded.
align is a string specifying how superdiagonals and subdiagonals should be aligned,
respectively. There are four possible alignments: "RIGHT_LEFT" (default), "LEFT_RIGHT",
Expand Down
2 changes: 2 additions & 0 deletions tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,8 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
_verify((4, 3, 3), (4, 3), "int32")
_verify((2, 3, 4), (2, 3), "float32", 1)
_verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "LEFT_RIGHT")
_verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "LEFT_LEFT")
_verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "RIGHT_RIGHT")


if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,8 @@ def test_matrix_set_diag():
verify_matrix_set_diag((4, 3, 3), (4, 3), dtype)
verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1)
verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_RIGHT")
verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_LEFT")
verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "RIGHT_RIGHT")


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 8d1c8bb

Please sign in to comment.