Skip to content

Commit

Permalink
[Tests] Adding tests for math primitives (#412)
Browse files Browse the repository at this point in the history
Closes #195
  • Loading branch information
BolinSNLHM authored and vadiklyutiy committed Dec 20, 2024
1 parent 4a1f72d commit f859fac
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tests/frontends/torch/test_torch_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def test_torch_einsum(equation, operand_shapes):

atol = 5e-2
if equation == 'abcd,cd->ab':
atol = 8e-2
atol = 1e-1

check_module(
FunctionalModule(op=lambda *args: torch.einsum(equation, *args)), args=operands_torch, atol=atol, rtol=1e-4
Expand Down
120 changes: 120 additions & 0 deletions tests/operators/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,116 @@ def test_sqrt(shape):
check_unary(shape, np.float32, np.sqrt, ops.sqrt, positive=True)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_sin(shape):
check_unary(shape, np.float32, np.sin, ops.sin)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_cos(shape):
check_unary(shape, np.float32, np.cos, ops.cos)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_tan(shape):
check_unary(shape, np.float32, np.tan, ops.tan)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_sinh(shape):
check_unary(shape, np.float32, np.sinh, ops.sinh)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_cosh(shape):
check_unary(shape, np.float32, np.cosh, ops.cosh)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_tanh(shape):
check_unary(shape, np.float32, np.tanh, ops.tanh)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_asin(shape):
check_unary(shape, np.float32, np.arcsin, ops.asin)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_acos(shape):
check_unary(shape, np.float32, np.arccos, ops.acos)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_atan(shape):
check_unary(shape, np.float32, np.arctan, ops.atan)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_asinh(shape):
check_unary(shape, np.float32, np.arcsinh, ops.asinh)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_acosh(shape):
check_unary(shape, np.float32, np.arccosh, ops.acosh)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_atanh(shape):
check_unary(shape, np.float32, np.arctanh, ops.atanh)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_exp(shape):
check_unary(shape, np.float32, np.exp, ops.exp)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_expm1(shape):
check_unary(shape, np.float32, np.expm1, ops.expm1)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_erf(shape):
check_unary(shape, np.float32, np.vectorize(math.erf), ops.erf)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_sqrt(shape):
check_unary(shape, np.float32, np.sqrt, ops.sqrt, positive=True)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_rsqrt(shape):
check_unary(shape, np.float32, lambda v: np.reciprocal(np.sqrt(v)), ops.rsqrt, positive=True)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_log(shape):
check_unary(shape, np.float32, np.log, ops.log, positive=True)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_log2(shape):
check_unary(shape, np.float32, np.log2, ops.log2, positive=True)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_log10(shape):
check_unary(shape, np.float32, np.log10, ops.log10, positive=True)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_log1p(shape):
check_unary(shape, np.float32, np.log1p, ops.log1p, positive=True)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_round(shape):
check_unary(shape, np.float32, np.round, ops.round)


@pytest.mark.parametrize("shape", unary_op_shapes)
def test_neg(shape):
check_unary(shape, np.float32, np.negative, ops.negative)
Expand Down Expand Up @@ -142,6 +237,31 @@ def test_ceil(a_shape):
check_unary(a_shape, np.float32, np.ceil, ops.ceil)


@pytest.mark.parametrize("a_shape", unary_op_shapes)
def test_floor(a_shape):
check_unary(a_shape, np.float32, np.floor, ops.floor)


@pytest.mark.parametrize("a_shape", unary_op_shapes)
def test_trunc(a_shape):
check_unary(a_shape, np.float32, np.trunc, ops.trunc)


@pytest.mark.parametrize("a_shape", unary_op_shapes)
def test_isfinite(a_shape):
check_unary(a_shape, np.float32, np.isfinite, ops.isfinite)


@pytest.mark.parametrize("a_shape", unary_op_shapes)
def test_isinf(a_shape):
check_unary(a_shape, np.float32, np.isinf, ops.isinf)


@pytest.mark.parametrize("a_shape", unary_op_shapes)
def test_isnan(a_shape):
check_unary(a_shape, np.float32, np.isnan, ops.isnan)


@pytest.mark.parametrize("a_shape", [[20]])
def test_cast_from_fp16(a_shape):
check_unary(a_shape, np.float16, np.int8, lambda x: ops.cast(x, "int8"))
Expand Down

0 comments on commit f859fac

Please sign in to comment.