Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PVT v2 tests #29660

Merged
merged 2 commits into from
Mar 14, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/models/pvt_v2/test_modeling_pvt_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import unittest

from transformers import PvtV2Backbone, PvtV2Config, is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
require_accelerate,
require_torch,
Expand All @@ -39,6 +38,7 @@
import torch

from transformers import MODEL_MAPPING, AutoImageProcessor, PvtV2ForImageClassification, PvtV2Model
from transformers.models.auto.modeling_auto import MODEL_MAPPING_NAMES
from transformers.models.pvt_v2.modeling_pvt_v2 import PVT_V2_PRETRAINED_MODEL_ARCHIVE_LIST


Expand Down Expand Up @@ -289,7 +289,7 @@ def test_training(self):
config.return_dict = True

for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING):
if model_class.__name__ in MODEL_MAPPING_NAMES.values():
continue
model = model_class(config)
model.to(torch_device)
Expand Down
Loading