Skip to content

Commit

Permalink
Merge branch 'main' into feature-conv-groups
Browse files Browse the repository at this point in the history
  • Loading branch information
andravin authored Oct 31, 2024
2 parents a0e266e + 6e0668a commit 42edc3c
Show file tree
Hide file tree
Showing 6 changed files with 9 additions and 41 deletions.
11 changes: 6 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ jobs:
pytorch-version: 1.6.0
- python-version: 3.9
pytorch-version: 1.7.1
- python-version: 3.9
pytorch-version: 1.8

- python-version: 3.10
pytorch-version: 1.4.0
Expand All @@ -57,14 +59,13 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-cov
pip install torch==${{ matrix.pytorch-version }} torchvision transformers
pip install compressai
python -m pip install --upgrade uv
uv pip install --system pytest pytest-cov
uv pip install --system torch==${{ matrix.pytorch-version }} torchvision transformers
- name: mypy
if: ${{ matrix.pytorch-version == '2.2' }}
run: |
python -m pip install mypy==1.9.0
uv pip install --system mypy==1.9.0
mypy --install-types --non-interactive .
- name: pytest
if: ${{ matrix.pytorch-version == '2.2' }}
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ ci:
skip: [mypy, pytest]
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.7.1
hooks:
- id: ruff
args: [--fix]
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ mypy
pytest
pytest-cov
pre-commit
ruff
transformers
compressai
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
torch
torchvision
numpy
numpy<2
7 changes: 0 additions & 7 deletions tests/torchinfo_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any

import numpy as np
import torch
from torch import nn
from torch.nn.utils import prune
Expand All @@ -26,7 +25,6 @@
ModuleDictModel,
MultipleInputNetDifferentDtypes,
NamedTuple,
NumpyModel,
PackPaddedLSTM,
ParameterFCNet,
ParameterListModel,
Expand Down Expand Up @@ -428,11 +426,6 @@ def test_namedtuple() -> None:
summary(model, input_size=input_size, z=named_tuple, device=torch.device("cpu"))


def test_numpy_model() -> None:
model = NumpyModel()
summary(model, input_data=np.ones(3, dtype=np.float32))


def test_return_dict() -> None:
input_size = [torch.Size([1, 28, 28]), [12]]

Expand Down
26 changes: 0 additions & 26 deletions tests/torchinfo_xl_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pytest
import torch
import torchvision # type: ignore[import-untyped]
from compressai.zoo import image_models # type: ignore[import-untyped]
from packaging import version

from tests.fixtures.genotype import GenotypeNetwork # type: ignore[attr-defined]
Expand All @@ -12,8 +11,6 @@
if version.parse(torch.__version__) >= version.parse("1.8"):
from transformers import ( # type: ignore[import-untyped]
AutoModelForSeq2SeqLM,
BertConfig,
BertModel,
)


Expand Down Expand Up @@ -166,26 +163,3 @@ def test_flan_t5_small() -> None:
"labels": torch.zeros(3, 100).long(),
}
summary(model, input_data=inputs)


@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse("1.8"),
reason="BertModel only works for PyTorch v1.8 and above",
)
def test_bert() -> None:
model = BertModel(BertConfig())
summary(
model,
input_size=[(2, 512), (2, 512), (2, 512)],
dtypes=[torch.int, torch.int, torch.int],
device="cpu",
)


@pytest.mark.skipif(
version.parse(torch.__version__) < version.parse("1.8"),
reason="compressai only works for PyTorch v1.8 and above",
)
def test_compressai() -> None:
model = image_models["bmshj2018-factorized"](quality=4, pretrained=True)
summary(model, (1, 3, 256, 256))

0 comments on commit 42edc3c

Please sign in to comment.