Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip forge verification #1147

Merged
merged 1 commit into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/src/test.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ Full list of supported query parameters
| TEST_ID | Id of a test containing test parameters | test_single |
| ID_FILE | Path to a file containing test ids | test_ids |

Test configuration parameters

| Parameter | Description | Supported by commands |
| ------------------------- | --------------------------------------------------------------------------------------------- | ------------------------------------- |
| SKIP_FORGE_VERIFICATION | Skip Forge model verification including model compiling and inference | all |

To check supported values and options for each query parameter please run command `print_query_docs`.

Expand Down
26 changes: 26 additions & 0 deletions forge/test/operators/pytorch/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ def print_query_params(cls, max_width=80):
cls.print_query_values(max_width)
print("Query examples:")
cls.print_query_examples(max_width)
print("Configuration parameters:")
cls.print_configuration_params(max_width)
print("Configuration examples:")
cls.print_configuration_examples(max_width)

@classmethod
def print_query_values(cls, max_width=80):
Expand Down Expand Up @@ -500,6 +504,28 @@ def print_query_examples(cls, max_width=80):

cls.print_formatted_parameters(parameters, max_width, headers=["Parameter", "Examples"])

@classmethod
def print_configuration_params(cls, max_width=80):

parameters = [
{
"name": "SKIP_FORGE_VERIFICATION",
"description": f"Skip Forge model verification including model compiling and inference",
"default": "false",
},
]

cls.print_formatted_parameters(parameters, max_width, headers=["Parameter", "Description", "Default"])

@classmethod
def print_configuration_examples(cls, max_width=80):

parameters = [
{"name": "SKIP_FORGE_VERIFICATION", "description": "export SKIP_FORGE_VERIFICATION=true"},
]

cls.print_formatted_parameters(parameters, max_width, headers=["Parameter", "Examples"])

@classmethod
def print_formatted_parameters(cls, parameters, max_width=80, headers=["Parameter", "Description"]):
for param in parameters:
Expand Down
2 changes: 2 additions & 0 deletions forge/test/operators/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .utils import LoggerUtils
from .utils import RateLimiter
from .utils import FrameworkModelType
from .features import TestFeaturesConfiguration
from .plan import InputSource
from .plan import TestVector
from .plan import TestCollection
Expand Down Expand Up @@ -41,6 +42,7 @@
"VerifyUtils",
"LoggerUtils",
"RateLimiter",
"TestFeaturesConfiguration",
"FrameworkModelType",
"InputSource",
"TestVector",
Expand Down
50 changes: 50 additions & 0 deletions forge/test/operators/utils/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from typing import Optional, List, Union

from forge import ForgeModule, Module, DepricatedVerifyConfig
from forge.tensor import to_pt_tensors
from forge.op_repo import TensorShape
from forge.verify.compare import compare_with_golden
from forge.verify.verify import verify
Expand Down Expand Up @@ -326,3 +327,52 @@ def verify_module_for_inputs(
forge_inputs = [forge.Tensor.create_from_torch(input, dev_data_format=dev_data_format) for input in inputs]
compiled_model = forge.compile(model, sample_inputs=forge_inputs)
verify(inputs, model, compiled_model, verify_config)


def verify_module_for_inputs_torch(
model: Module,
inputs: List[torch.Tensor],
verify_config: Optional[VerifyConfig] = VerifyConfig(),
):

verify_torch(inputs, model, verify_config)


def verify_torch(
inputs: List[torch.Tensor],
framework_model: torch.nn.Module,
verify_cfg: VerifyConfig = VerifyConfig(),
):
"""
Verify the pytorch model with the given inputs
"""
if not verify_cfg.enabled:
logger.warning("Verification is disabled")
return

# 0th step: input checks

# Check if inputs are of the correct type
if not inputs:
raise ValueError("Input tensors must be provided")
for input_tensor in inputs:
if not isinstance(input_tensor, verify_cfg.supported_tensor_types):
raise TypeError(
f"Input tensor must be of type {verify_cfg.supported_tensor_types}, but got {type(input_tensor)}"
)

if not isinstance(framework_model, verify_cfg.framework_model_types):
raise TypeError(
f"Framework model must be of type {verify_cfg.framework_model_types}, but got {type(framework_model)}"
)

# 1st step: run forward pass for the networks
fw_out = framework_model(*inputs)

# 2nd step: apply preprocessing (push tensors to cpu, perform any reshape if necessary,
# cast from tensorflow tensors to pytorch tensors if needed)
if not isinstance(fw_out, torch.Tensor):
fw_out = to_pt_tensors(fw_out)

fw_out = [fw_out] if isinstance(fw_out, torch.Tensor) else fw_out
return fw_out
17 changes: 17 additions & 0 deletions forge/test/operators/utils/features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
#
# SPDX-License-Identifier: Apache-2.0

import os


class TestFeaturesConfiguration:
"""Store test features configuration"""

__test__ = False # Disable pytest collection

@staticmethod
def get_env_property(env_var: str, default_value: str):
return os.getenv(env_var, default_value)

SKIP_FORGE_VERIFICATION = get_env_property("SKIP_FORGE_VERIFICATION", "false").lower() == "true"
17 changes: 16 additions & 1 deletion forge/test/operators/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,14 @@
from forge.verify.config import VerifyConfig

from .compat import TestDevice
from .compat import create_torch_inputs, verify_module_for_inputs, verify_module_for_inputs_deprecated
from .compat import (
create_torch_inputs,
verify_module_for_inputs,
verify_module_for_inputs_deprecated,
verify_module_for_inputs_torch,
)
from .datatypes import ValueRanges
from .features import TestFeaturesConfiguration


# All supported framework model types
Expand Down Expand Up @@ -130,6 +136,7 @@ def verify(
warm_reset: bool = False,
deprecated_verification: bool = True,
verify_config: Optional[VerifyConfig] = VerifyConfig(),
skip_forge_verification: bool = TestFeaturesConfiguration.SKIP_FORGE_VERIFICATION,
):
"""Perform Forge verification on the model
Expand All @@ -146,6 +153,8 @@ def verify(
random_seed: Random seed
warm_reset: Warm reset the device before verification
deprecated_verification: Use deprecated verification method
verify_config: Verification configuration
skip_forge_verification: Skip verification with Forge module
"""

cls.setup(
Expand All @@ -168,6 +177,12 @@ def verify(
pcc=pcc,
dev_data_format=dev_data_format,
)
elif skip_forge_verification:
verify_module_for_inputs_torch(
model=model,
inputs=inputs,
verify_config=verify_config,
)
else:
cls.verify_module_for_inputs(
model=model,
Expand Down
Loading