Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove the need for the config to be in the subfolder #2044

Merged
merged 8 commits into from
Oct 10, 2024
Merged
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 20 additions & 14 deletions optimum/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@

from huggingface_hub import create_repo, upload_file
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from huggingface_hub.errors import OfflineModeIsEnabled
from transformers import AutoConfig, PretrainedConfig, add_start_docstrings

from .exporters import TasksManager
from .utils import CONFIG_NAME
from .utils.file_utils import find_files_matching_pattern


if TYPE_CHECKING:
Expand Down Expand Up @@ -380,27 +382,31 @@ def from_pretrained(
)
model_id, revision = model_id.split("@")

config_folder = subfolder
try:
if len(find_files_matching_pattern(model_id, cls.config_name, subfolder=subfolder)) == 0:
logger.info(
f"{cls.config_name} not found in the specified subfolder {subfolder}. Using the top level {cls.config_name}."
)
config_folder = ""
except OfflineModeIsEnabled:
# TODO: enable this for offline mode by checking the cache
logger.info(f"Offline mode enabled, the {cls.config_name} is expected to be in the subfolder {subfolder}.")

library_name = TasksManager.infer_library_from_model(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
model_id, subfolder=config_folder, revision=revision, cache_dir=cache_dir, token=token
)

if library_name == "timm":
config = PretrainedConfig.from_pretrained(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
model_id, subfolder=config_folder, revision=revision, cache_dir=cache_dir, token=token
)

if config is None:
if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME:
if CONFIG_NAME in os.listdir(os.path.join(model_id, subfolder)):
if os.path.isdir(os.path.join(model_id, config_folder)) and cls.config_name == CONFIG_NAME:
if CONFIG_NAME in os.listdir(os.path.join(model_id, config_folder)):
config = AutoConfig.from_pretrained(
os.path.join(model_id, subfolder), trust_remote_code=trust_remote_code
)
elif CONFIG_NAME in os.listdir(model_id):
config = AutoConfig.from_pretrained(
os.path.join(model_id, CONFIG_NAME), trust_remote_code=trust_remote_code
)
logger.info(
f"config.json not found in the specified subfolder {subfolder}. Using the top level config.json."
os.path.join(model_id, config_folder), trust_remote_code=trust_remote_code
)
else:
raise OSError(f"config.json not found in {model_id} local folder")
Expand All @@ -411,7 +417,7 @@ def from_pretrained(
cache_dir=cache_dir,
token=token,
force_download=force_download,
subfolder=subfolder,
subfolder=config_folder,
trust_remote_code=trust_remote_code,
)
elif isinstance(config, (str, os.PathLike)):
Expand All @@ -421,7 +427,7 @@ def from_pretrained(
cache_dir=cache_dir,
token=token,
force_download=force_download,
subfolder=subfolder,
subfolder=config_folder,
trust_remote_code=trust_remote_code,
)

Expand Down
Loading