Automatic framework detection in TasksManager for large models #883

Merged
merged 8 commits into from
Mar 16, 2023

Contributor

@fxmarty fxmarty commented Mar 15, 2023

Fixes #867

@fxmarty fxmarty requested review from michaelbenayoun, regisss, JingyaHuang and mht-sharma and removed request for michaelbenayoun March 15, 2023 12:12
Member

@michaelbenayoun michaelbenayoun left a comment

Left a few nits, but LGTM besides that!

)
logger.info(f"Local {framework_map[framework]} model found.")
all_files = [
    os.path.relpath(os.path.join(dirpath, file), full_model_path)
Member

I would rather use pathlib here, but that's just my taste:

  • os.walk => path.glob
  • os.path.relpath => path.relative_to
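For illustration, the pathlib variant suggested above could look roughly like the sketch below. The function name `list_relative_files` is hypothetical and not part of the PR. The sketch uses `Path.rglob` rather than a plain `glob` so that subdirectories are traversed the way `os.walk` does, and it returns POSIX-style strings, which is an assumption about what the caller wants.

```python
from pathlib import Path


def list_relative_files(root):
    """Return every file under `root` as a POSIX-style path relative to `root`.

    Hypothetical helper: a pathlib counterpart to combining
    os.walk with os.path.relpath.
    """
    root = Path(root)
    return sorted(
        path.relative_to(root).as_posix()  # replaces os.path.relpath
        for path in root.rglob("*")        # replaces os.walk (recursive)
        if path.is_file()                  # skip directories
    )
```

Calling `list_relative_files(full_model_path)` would then give the list of file names that the weight-file checks operate on.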

optimum/exporters/tasks.py (resolved)
optimum/exporters/tasks.py (outdated, resolved)
fxmarty and others added 3 commits March 15, 2023 17:32
Comment on lines +936 to +945
if any(is_pt_weight_file):
    framework = "pt"
elif any(is_tf_weight_file):
    framework = "tf"
else:
    raise FileNotFoundError(
        "Cannot determine framework from given checkpoint location."
        f" There should be a {Path(WEIGHTS_NAME).stem}*{Path(WEIGHTS_NAME).suffix} for PyTorch"
        f" or {Path(TF2_WEIGHTS_NAME).stem}*{Path(TF2_WEIGHTS_NAME).suffix} for TensorFlow."
    )
Contributor

I think we should log which framework has been chosen, something like:

logger.info(f"No framework specified so {framework} has been automatically chosen.")

Because here PT takes priority over TF when the repo contains checkpoints for both frameworks (which is fine), as with BERT for instance: https://huggingface.co/bert-base-uncased/tree/main
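As a rough, self-contained sketch of this suggestion (not the code merged in the PR): the helper names `infer_framework` and `_matches` are hypothetical, the two constants are redefined locally on the assumption that they mirror `WEIGHTS_NAME` and `TF2_WEIGHTS_NAME` from transformers, and the matching rule is inferred from the `stem*suffix` pattern in the error message rather than taken from the diff.

```python
import logging
from pathlib import Path

logger = logging.getLogger(__name__)

# Assumption: same values as transformers' WEIGHTS_NAME / TF2_WEIGHTS_NAME.
WEIGHTS_NAME = "pytorch_model.bin"
TF2_WEIGHTS_NAME = "tf_model.h5"


def _matches(filename, reference):
    """True if `filename` looks like `<stem>*<suffix>` of `reference`.

    This also accepts sharded names such as pytorch_model-00001-of-00002.bin.
    """
    name = Path(filename).name
    ref = Path(reference)
    return name.startswith(ref.stem) and name.endswith(ref.suffix)


def infer_framework(all_files):
    """Hypothetical helper: choose "pt" or "tf" from checkpoint file names."""
    if any(_matches(f, WEIGHTS_NAME) for f in all_files):
        framework = "pt"  # PyTorch takes priority when both are present
    elif any(_matches(f, TF2_WEIGHTS_NAME) for f in all_files):
        framework = "tf"
    else:
        raise FileNotFoundError(
            "Cannot determine framework from given checkpoint location."
        )
    # The log line proposed in the review comment.
    logger.info(f"No framework specified so {framework} has been automatically chosen.")
    return framework
```

With a repository that ships both `pytorch_model.bin` and `tf_model.h5`, this sketch returns `"pt"` and records the choice in the log, which is the visibility the comment asks for.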

Contributor Author

This is already logged in the CLI: "Framework not specified. Using pt to export to ONNX." When passing export=True, this is indeed not logged.

@fxmarty
Contributor Author

fxmarty commented Mar 16, 2023

Merging as failing tests are unrelated.

@fxmarty fxmarty merged commit 7f62e7d into huggingface:main Mar 16, 2023
Development

Successfully merging this pull request may close these issues.

Auto-detect framework for large models at ONNX export
4 participants