Skip to content

Testing binary operators via torch data formats #1151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions forge/test/operators/pytorch/eltwise_binary/test_binary.py
Original file line number Diff line number Diff line change
@@ -68,6 +68,7 @@
from test.operators.utils import TestResultFailing
from test.operators.utils import FailingRulesConverter
from test.operators.utils import TestCollectionCommon
from test.operators.utils import TestCollectionTorch
from test.operators.utils import FailingReasons
from test.operators.utils.compat import TestDevice

@@ -128,7 +129,7 @@ def verify(
# else:
# dev_data_format = TestCollectionCommon.single.dev_data_formats[0]

if dev_data_format in TestCollectionCommon.int.dev_data_formats:
if dev_data_format in TestCollectionTorch.int.dev_data_formats:
value_range = ValueRanges.LARGE

if value_range is None:
@@ -162,6 +163,7 @@ def verify(
value_range=value_range,
pcc=test_vector.pcc,
warm_reset=warm_reset,
convert_to_forge=False,
)


@@ -199,7 +201,7 @@ class TestParamsData:

@classmethod
def generate_kwargs_alpha(cls, test_vector: TestVector):
if test_vector.dev_data_format in TestCollectionCommon.int.dev_data_formats:
if test_vector.dev_data_format in TestCollectionTorch.int.dev_data_formats:
return cls.kwargs_alpha_int
else:
return cls.kwargs_alpha_float
@@ -309,14 +311,14 @@ class TestCollectionData:
operators=implemented.operators,
input_sources=TestCollectionCommon.all.input_sources,
input_shapes=TestCollectionCommon.all.input_shapes,
dev_data_formats=TestCollectionCommon.all.dev_data_formats,
dev_data_formats=TestCollectionTorch.all.dev_data_formats,
math_fidelities=TestCollectionCommon.all.math_fidelities,
)

single = TestCollection(
input_sources=TestCollectionCommon.single.input_sources,
input_shapes=TestCollectionCommon.single.input_shapes,
dev_data_formats=TestCollectionCommon.single.dev_data_formats,
dev_data_formats=TestCollectionTorch.single.dev_data_formats,
math_fidelities=TestCollectionCommon.single.math_fidelities,
)

9 changes: 5 additions & 4 deletions forge/test/operators/pytorch/tm/test_reshape.py
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@
from test.operators.utils.compat import TestDevice
from test.operators.utils import TestCollection
from test.operators.utils import TestCollectionCommon
from test.operators.utils import TestCollectionTorch
from test.operators.utils import ValueRanges

from test.operators.pytorch.eltwise_unary import ModelFromAnotherOp, ModelDirect, ModelConstEvalPass
@@ -132,7 +133,7 @@ def verify(

# We use AllCloseValueChecker in all cases except for integer data formats:
verify_config = VerifyConfig(value_checker=AllCloseValueChecker())
if test_vector.dev_data_format in TestCollectionCommon.int.dev_data_formats:
if test_vector.dev_data_format in TestCollectionTorch.int.dev_data_formats:
verify_config = VerifyConfig(value_checker=AutomaticValueChecker())

VerifyUtils.verify(
@@ -271,8 +272,8 @@ class TestIdsData:
kwargs=lambda test_vector: TestParamsData.generate_random_kwargs(test_vector),
dev_data_formats=[
item
for item in TestCollectionCommon.all.dev_data_formats
if item not in TestCollectionCommon.single.dev_data_formats
for item in TestCollectionTorch.all.dev_data_formats
if item not in TestCollectionTorch.single.dev_data_formats
],
math_fidelities=TestCollectionCommon.single.math_fidelities,
),
@@ -282,7 +283,7 @@ class TestIdsData:
input_sources=TestCollectionCommon.single.input_sources,
input_shapes=TestCollectionCommon.single.input_shapes,
kwargs=lambda test_vector: TestParamsData.generate_random_kwargs(test_vector),
dev_data_formats=TestCollectionCommon.single.dev_data_formats,
dev_data_formats=TestCollectionTorch.single.dev_data_formats,
math_fidelities=TestCollectionCommon.all.math_fidelities,
),
# Test specific classes of reshape operations collection:
2 changes: 1 addition & 1 deletion forge/test/operators/utils/test_data.py
Original file line number Diff line number Diff line change
@@ -321,7 +321,7 @@ class TestCollectionTorch:

float = TestCollection(
dev_data_formats=[
torch.float16,
# torch.float16,
torch.float32,
# torch.float64,
torch.bfloat16,