Skip to content

Commit

Permalink
Support numpy scalar as input type case (#69139)
Browse files Browse the repository at this point in the history
* support numpy scalar as input type case

* update docstring

* fix
  • Loading branch information
HydrogenSulfate authored Nov 4, 2024
1 parent 0c3830f commit 1f4ea48
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,7 +1363,7 @@ def _check_attr(attr, message):
assert len(attr.shape) == 0 or (
len(attr.shape) == 1 and attr.shape[0] in [1, -1]
)
elif not isinstance(attr, int) or attr < 0:
elif not isinstance(attr, (int, np.integer)) or attr < 0:
raise TypeError(f"{message} should be a non-negative int.")

_check_attr(num_rows, "num_rows")
Expand Down
16 changes: 16 additions & 0 deletions test/legacy_test/test_eye_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ def test_check_output(self):
self.check_output(check_pir=True)


class TestEyeOp3(OpTest):
def setUp(self):
'''
Test eye op with np.int32 scalar
'''
self.python_api = paddle.eye
self.op_type = "eye"

self.inputs = {}
self.attrs = {'num_rows': np.int32(99), 'num_columns': np.int32(1)}
self.outputs = {'Out': np.eye(99, 1, dtype=float)}

def test_check_output(self):
self.check_output(check_pir=True)


class API_TestTensorEye(unittest.TestCase):

def test_static_out(self):
Expand Down

0 comments on commit 1f4ea48

Please sign in to comment.