Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Apr 4, 2023
1 parent edde0d7 commit 792e11d
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 10 deletions.
3 changes: 0 additions & 3 deletions src/brevitas_examples/imagenet_classification/qat/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,2 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause


from .models import *
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from .models import *
4 changes: 2 additions & 2 deletions tests/brevitas/export/test_generic_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from brevitas.quant.scaled_int import Int4WeightPerTensorFloatDecoupled
from brevitas.quant.scaled_int import Int8ActPerTensorFloat
from brevitas.quant.scaled_int import Int16Bias
from brevitas_examples import imagenet_classification
from brevitas_examples.imagenet_classification import quant_models
from tests.marker import jit_disabled_for_export

OUT_CH = 50
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_generic_quant_avgpool_export_quant_input():

@jit_disabled_for_export()
def test_debug_brevitas_onnx_export():
model, cfg = imagenet_classification.model_with_cfg('quant_mobilenet_v1_4b', pretrained=False)
model, cfg = quant_models.model_with_cfg('quant_mobilenet_v1_4b', pretrained=False)
model.eval()
debug_hook = enable_debug(model, proxy_level=True)
input_tensor = torch.randn(1, 3, 224, 224)
Expand Down
11 changes: 7 additions & 4 deletions tests/brevitas_examples/test_examples_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@ def test_import_bnn_pynq():


def test_import_image_classification():
from brevitas_examples.imagenet_classification import quant_mobilenet_v1_4b
from brevitas_examples.imagenet_classification import quant_proxylessnas_mobile14_4b
from brevitas_examples.imagenet_classification import quant_proxylessnas_mobile14_4b5b
from brevitas_examples.imagenet_classification import quant_proxylessnas_mobile14_hadamard_4b
from brevitas_examples.imagenet_classification.quant_models import quant_mobilenet_v1_4b
from brevitas_examples.imagenet_classification.quant_models import \
quant_proxylessnas_mobile14_4b
from brevitas_examples.imagenet_classification.quant_models import \
quant_proxylessnas_mobile14_4b5b
from brevitas_examples.imagenet_classification.quant_models import \
quant_proxylessnas_mobile14_hadamard_4b


def test_import_tts():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from brevitas import torch_version
from brevitas.export import export_finn_onnx
from brevitas_examples.imagenet_classification import quant_mobilenet_v1_4b
from brevitas_examples.imagenet_classification.quant_models import quant_mobilenet_v1_4b

ort_mac_fail = pytest.mark.skipif(
torch_version >= parse('1.5.0') and system() == 'Darwin',
Expand Down

0 comments on commit 792e11d

Please sign in to comment.