Skip to content

Commit

Permalink
support output 0D, correct usage
Browse files Browse the repository at this point in the history
  • Loading branch information
zhwesky2010 committed Mar 15, 2023
1 parent f24115e commit 05b0d11
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 36 deletions.
4 changes: 2 additions & 2 deletions framework/api/loss/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run(self, model=None, expect=None):
# logging.info('at {}, res is: {}'.format(i, loss))
if self.debug:
print(loss)
self.result.append(loss.numpy()[0])
self.result.append(float(loss))
# logging.info('at {}, result is: {}'.format(i, self.result))
self.check(result=self.result, expect=expect)

Expand Down Expand Up @@ -172,7 +172,7 @@ def run(self, model=None, expect=None):
# logging.info('at {}, res is: {}'.format(i, loss))
if self.debug:
print(loss)
self.result.append(loss.numpy()[0])
self.result.append(float(loss))
# logging.info('at {}, result is: {}'.format(i, self.result))
self.check(result=self.result, expect=expect)

Expand Down
2 changes: 1 addition & 1 deletion framework/api/optimizer/lrbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def run(self):
# self.optimizer.clear_grad()
# if self.debug:
# print(loss)
# self.result.append(loss.numpy()[0])
# self.result.append(float(loss))

def check(self, result=None, expect=None, delta=1e-6, rtol=1e-7):
"""
Expand Down
2 changes: 1 addition & 1 deletion framework/api/optimizer/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def run(self):
self.optimizer.clear_grad()
if self.debug:
print(loss)
self.result.append(loss.numpy()[0])
self.result.append(float(loss))

def check(self, expect=None, rtol=1e-05, atol=1e-08):
"""
Expand Down
22 changes: 11 additions & 11 deletions framework/api/paddlebase/test_allclose.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_allclose_base():
a = 0.01
r = 0.01
res = np.allclose(x, y, rtol=r, atol=a, equal_nan=False)
res = np.array([res])
res = np.array(res)
obj.base(res=res, x=x, y=y, rtol=r, atol=a, equal_nan=False)


Expand All @@ -65,7 +65,7 @@ def test_allclose1():
a = 0.0
r = 0.01
res = np.allclose(x, y, rtol=r, atol=a, equal_nan=False)
res = np.array([res])
res = np.array(res)
obj2.run(res=res, x=x, y=y, rtol=r, atol=a, equal_nan=False)


Expand All @@ -79,7 +79,7 @@ def test_allclose2():
a = 0.001
r = 0.000001
res = np.allclose(x, y, rtol=r, atol=a, equal_nan=False)
res = np.array([res])
res = np.array(res)
obj.run(res=res, x=x, y=y, rtol=r, atol=a, equal_nan=False)


Expand All @@ -93,7 +93,7 @@ def test_allclose3():
a = 0.001
r = 0.000001
res = np.allclose(x, y, rtol=r, atol=a, equal_nan=True)
res = np.array([res])
res = np.array(res)
obj.run(res=res, x=x, y=y, rtol=r, atol=a, equal_nan=True)


Expand All @@ -107,7 +107,7 @@ def test_allclose4():
a = 0.001
r = 0.00001
res = np.allclose(x, y, rtol=r, atol=a, equal_nan=True)
res = np.array([res])
res = np.array(res)
obj.run(res=res, x=x, y=y, rtol=r, atol=a, equal_nan=True)


Expand All @@ -121,7 +121,7 @@ def test_allclose5():
a = 0.001
r = 0.00001
res = np.allclose(x, y, rtol=r, atol=a, equal_nan=False)
res = np.array([res])
res = np.array(res)
obj.run(res=res, x=x, y=y, rtol=r, atol=a, equal_nan=False)


Expand All @@ -133,7 +133,7 @@ def test_allclose6():
x = np.array([10.001])
y = np.array([10.00001])
res = np.allclose(x, y)
res = np.array([res])
res = np.array(res)
obj.run(res=res, x=x, y=y)


Expand All @@ -145,7 +145,7 @@ def test_allclose7():
x = 0.1 + np.arange(24).reshape(2, 2, 2, 3)
y = np.arange(24).reshape(2, 2, 2, 3)
res = np.allclose(x, y)
res = np.array([res])
res = np.array(res)
obj.run(res=res, x=x, y=y)


Expand All @@ -159,7 +159,7 @@ def test_allclose8():
a = -2.0
r = -3.0
res = np.allclose(x, y, rtol=r, atol=a)
res = np.array([res])
res = np.array(res)
obj.run(res=res, x=x, y=y, rtol=r, atol=a)


Expand All @@ -171,7 +171,7 @@ def test_allclose9():
x = np.array([])
y = np.array([])
res = np.allclose(x, y)
res = np.array([res])
res = np.array(res)
obj.run(res=res, x=x, y=y)


Expand Down Expand Up @@ -199,4 +199,4 @@ def test_allclose10():
"""
x = np.array([10.001])
y = np.array([10.00001])
obj1.exception(mode="c", etype="NotFoundError", x=x, y=y)
obj1.exception(mode="c", etype="NotFound", x=x, y=y)
10 changes: 5 additions & 5 deletions framework/api/paddlebase/test_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_argmax():
"""
x = randtool("float", -10, 10, [3, 3, 3])
res = np.argmax(a=x)
obj.run(res=[res], x=x)
obj.run(res=res, x=x)


@pytest.mark.api_base_argmax_parameters
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_argmax3():
x = randtool("int", -10, 10, [3, 3])
dtype = "int32"
res = np.argmax(a=x)
obj.run(res=[res], x=x, dtype=dtype)
obj.run(res=res, x=x, dtype=dtype)


@pytest.mark.api_base_argmax_parameters
Expand Down Expand Up @@ -133,7 +133,7 @@ def test_argmax7():
x = np.array([[-1], [2], [3]])
keepdim = None
res = np.argmax(a=x)
obj.run(res=[res], x=x, keepdim=keepdim)
obj.run(res=res, x=x, keepdim=keepdim)


@pytest.mark.api_base_argmax_parameters
Expand All @@ -155,7 +155,7 @@ def test_argmax9():
"""
x = randtool("float", -1, 1, [3, 3])
axis = 2
obj.exception(mode="c", etype="InvalidArgumentError", x=x, axis=axis)
obj.exception(mode="c", etype="InvalidArgument", x=x, axis=axis)


@pytest.mark.api_base_argmax_exception
Expand All @@ -175,7 +175,7 @@ def test_argmax11():
"""
x = randtool("float", -10, 10, [3, 3])
dtype = np.float32
obj.exception(mode="c", etype="InvalidArgumentError", x=x, dtype=dtype)
obj.exception(mode="c", etype="InvalidArgument", x=x, dtype=dtype)


@pytest.mark.api_base_argmax_exception
Expand Down
10 changes: 5 additions & 5 deletions framework/api/paddlebase/test_argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_argmin():
"""
x = randtool("float", -10, 10, [3, 3, 3])
res = np.argmin(a=x)
obj.run(res=[res], x=x)
obj.run(res=res, x=x)


@pytest.mark.api_base_argmin_parameters
Expand Down Expand Up @@ -87,7 +87,7 @@ def test_argmin3():
x = randtool("int", -10, 10, [3, 3])
dtype = "int32"
res = np.argmin(a=x)
obj.run(res=[res], x=x, dtype=dtype)
obj.run(res=res, x=x, dtype=dtype)


@pytest.mark.api_base_argmin_parameters
Expand Down Expand Up @@ -133,7 +133,7 @@ def test_argmin7():
x = np.array([[-1], [2], [3]])
keepdim = None
res = np.argmin(a=x)
obj.run(res=[res], x=x, keepdim=keepdim)
obj.run(res=res, x=x, keepdim=keepdim)


@pytest.mark.api_base_argmin_parameters
Expand All @@ -155,7 +155,7 @@ def test_argmin9():
"""
x = randtool("float", -1, 1, [3, 3])
axis = 2
obj.exception(mode="c", etype="InvalidArgumentError", x=x, axis=axis)
obj.exception(mode="c", etype="InvalidArgument", x=x, axis=axis)


@pytest.mark.api_base_argmin_exception
Expand All @@ -175,7 +175,7 @@ def test_argmin11():
"""
x = randtool("float", -10, 10, [3, 3])
dtype = np.float32
obj.exception(mode="c", etype="InvalidArgumentError", x=x, dtype=dtype)
obj.exception(mode="c", etype="InvalidArgument", x=x, dtype=dtype)


@pytest.mark.api_base_argmin_exception
Expand Down
16 changes: 8 additions & 8 deletions framework/api/paddlebase/test_equal_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_equal_all_base():
"""
x = randtool("int", -10, 10, [3, 3, 3])
y = x
res = np.array([True])
res = np.array(True)
obj.base(res=res, x=x, y=y)


Expand All @@ -48,7 +48,7 @@ def test_equal_all():
"""
x = randtool("float", -10, 10, [3, 3, 3])
y = x
res = np.array([True])
res = np.array(True)
obj.run(res=res, x=x, y=y)


Expand All @@ -59,7 +59,7 @@ def test_equal_all1():
"""
x = randtool("float", -10, 10, [3, 3, 3])
y = randtool("int", -10, 10, [3, 3, 3])
res = np.array([False])
res = np.array(False)
obj.run(res=res, x=x, y=y)


Expand All @@ -70,7 +70,7 @@ def test_equal_all2():
"""
x = randtool("float", -10, 10, [3, 3, 3])
y = randtool("int", -10, 10, [3, 3])
res = np.array([False])
res = np.array(False)
obj.run(res=res, x=x, y=y)


Expand All @@ -81,7 +81,7 @@ def test_equal_all4():
"""
x = np.array([[3, 3, 3], [3, 3, 3]])
y = np.array([[3, 3, 3]])
res = np.array([False])
res = np.array(False)
obj.run(res=res, x=x, y=y)


Expand All @@ -102,7 +102,7 @@ def test_equal_all6():
"""
x = np.array([[True, False, True], [True, False, True]])
y = np.array([[True, False, True]])
res = np.array([False])
res = np.array(False)
obj.run(res=res, x=x, y=y)


Expand All @@ -113,7 +113,7 @@ def test_equal_all7():
"""
x = np.array([[[True, False, True], [True, False, True]]])
y = np.array([[True, False, True], [True, False, True]])
res = np.array([False])
res = np.array(False)
obj.run(res=res, x=x, y=y)


Expand All @@ -124,5 +124,5 @@ def test_equal_all8():
"""
x = np.array([[[[[True, False, True], [True, False, True]]]]])
y = x
res = np.array([True])
res = np.array(True)
obj.run(res=res, x=x, y=y)
2 changes: 1 addition & 1 deletion framework/api/paddlebase/test_numel.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ def test_numel_base():
base
"""
x = np.arange(20).reshape(4, 5)
res = np.array([20])
res = np.array(20)
obj.base(res=res, x=x)
4 changes: 2 additions & 2 deletions framework/e2e/scene/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def run(self, model=None, expect=None):
# logging.info('at {}, res is: {}'.format(i, loss))
if self.debug:
print(loss)
self.result.append(loss.numpy()[0])
self.result.append(float(loss))
# logging.info('at {}, result is: {}'.format(i, self.result))
self.check(result=self.result, expect=expect)

Expand Down Expand Up @@ -172,7 +172,7 @@ def run(self, model=None, expect=None):
# logging.info('at {}, res is: {}'.format(i, loss))
if self.debug:
print(loss)
self.result.append(loss.numpy()[0])
self.result.append(float(loss))
# logging.info('at {}, result is: {}'.format(i, self.result))
self.check(result=self.result, expect=expect)

Expand Down

0 comments on commit 05b0d11

Please sign in to comment.