diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index fb6d57e4618ef..14ac454aec646 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1190,14 +1190,14 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): diagonal : relay.Expr Values to be filled in the diagonal. - k : int or tuple of int + k : int or tuple of int, optional Diagonal Offset(s). The diagonal or range of diagonals to set. (0 by default) Positive value means superdiagonal, 0 refers to the main diagonal, and negative value means subdiagonals. k can be a single integer (for a single diagonal) or a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger than k[1]. - align : string + align : string, optional Some diagonals are shorter than max_diag_len and need to be padded. align is a string specifying how superdiagonals and subdiagonals should be aligned, respectively. There are four possible alignments: "RIGHT_LEFT" (default), "LEFT_RIGHT", diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 7d77115a9060f..c4e51a8858d17 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -818,14 +818,14 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"): diagonal : relay.Expr Values to be filled in the diagonal. - k : int or tuple of int + k : int or tuple of int, optional Diagonal Offset(s). The diagonal or range of diagonals to set. (0 by default) Positive value means superdiagonal, 0 refers to the main diagonal, and negative value means subdiagonals. k can be a single integer (for a single diagonal) or a pair of integers specifying the low and high ends of a matrix band. k[0] must not be larger than k[1]. - align : string + align : string, optional Some diagonals are shorter than max_diag_len and need to be padded. align is a string specifying how superdiagonals and subdiagonals should be aligned, respectively. There are four possible alignments: "RIGHT_LEFT" (default), "LEFT_RIGHT", diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py index 577d1c78d8e2e..68dd243363831 100644 --- a/tests/python/relay/test_op_level10.py +++ b/tests/python/relay/test_op_level10.py @@ -530,6 +530,8 @@ def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): _verify((4, 3, 3), (4, 3), "int32") _verify((2, 3, 4), (2, 3), "float32", 1) _verify((2, 3, 4), (2, 4, 3), "int32", (-1, 2), "LEFT_RIGHT") + _verify((2, 3, 4), (2, 4, 3), 'int32', (-1, 2), "LEFT_LEFT") + _verify((2, 3, 4), (2, 4, 3), 'int32', (-1, 2), "RIGHT_RIGHT") if __name__ == "__main__": diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 77e25407fc70c..f18b5397eefe8 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -1167,6 +1167,8 @@ def test_matrix_set_diag(): verify_matrix_set_diag((4, 3, 3), (4, 3), dtype) verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1) verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_RIGHT") + verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_LEFT") + verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "RIGHT_RIGHT") @tvm.testing.uses_gpu