-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add paddle hub #31873
Add paddle hub #31873
Changes from 28 commits
6ef62e0
a71dff5
64570eb
af43bfd
00a9786
1b613cb
884bbf9
2c52a28
f9dd61d
c8d2ae3
8acff14
2237120
8185cc4
c4347b2
aa0eac9
dfa6d03
ec5183d
cc7bbe7
58d9632
8ecb192
8d544b9
17e6517
697ccbd
3d7d9d4
3552ea2
b8d5d28
efb1a83
cb03c54
ba15ac1
ae0bb17
59e57e3
204e047
56e1042
aea875a
67d17ca
cad8c80
7408e84
7829730
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,259 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import re | ||
import sys | ||
import shutil | ||
import zipfile | ||
from paddle.utils.download import get_path_from_url | ||
|
||
# Branch used when a repo spec carries no ':branch' suffix. Despite the
# constant's name, the value is 'main'.
MASTER_BRANCH = 'main'
# NOTE(review): not referenced by any code visible in this file; HUB_DIR below
# spells out the same '~/.cache' location independently -- confirm it is needed.
DEFAULT_CACHE_DIR = '~/.cache'
# Attribute name looked up on a loaded hubconf module to find the list of
# packages it requires.
VAR_DEPENDENCY = 'dependencies'
# File name of the entrypoint module imported from the root of a hub repo.
MODULE_HUBCONF = 'hubconf.py'
# Directory under the user's home where downloaded repos are unpacked and cached.
HUB_DIR = os.path.expanduser(os.path.join('~', '.cache', 'paddle', 'hub'))
|
||
|
||
def _remove_if_exists(path): | ||
if os.path.exists(path): | ||
if os.path.isfile(path): | ||
os.remove(path) | ||
else: | ||
shutil.rmtree(path) | ||
|
||
|
||
def _git_archive_link(repo_owner, repo_name, branch, source): | ||
if source == 'github': | ||
return 'https://github.com/{}/{}/archive/{}.zip'.format( | ||
repo_owner, repo_name, branch) | ||
elif source == 'gitee': | ||
return 'https://gitee.com/{}/{}/repository/archive/{}.zip'.format( | ||
repo_owner, repo_name, branch) | ||
|
||
|
||
def _parse_repo_info(github): | ||
branch = MASTER_BRANCH | ||
if ':' in github: | ||
repo_info, branch = github.split(':') | ||
else: | ||
repo_info = github | ||
repo_owner, repo_name = repo_info.split('/') | ||
return repo_owner, repo_name, branch | ||
|
||
|
||
def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
    """Return the local directory holding ``repo``, downloading it if needed.

    The repo is cached under ``HUB_DIR`` in a folder named
    ``<owner>_<name>_<branch>``. When that folder exists and ``force_reload``
    is false it is reused; otherwise the branch archive is downloaded,
    unpacked and moved into place, replacing any previous copy.

    Args:
        repo (str): repo spec ``"repo_owner/repo_name[:tag_name]"``.
        force_reload (bool): discard an existing cached copy and download
            again.
        verbose (bool): when true, report a cache hit on stderr.
        source (str): `github` or `gitee`.

    Returns:
        str: path of the cached repo directory.

    Raises:
        RuntimeError: if the downloaded archive contains no entries.
    """
    # Setup hub_dir to save downloaded files.
    hub_dir = HUB_DIR
    try:
        os.makedirs(hub_dir)
    except OSError:
        # Several processes may reach this point together; a directory that
        # already exists (possibly created by a peer a moment ago) is fine.
        if not os.path.isdir(hub_dir):
            raise

    # Parse github/gitee repo information.
    repo_owner, repo_name, branch = _parse_repo_info(repo)
    # Github allows branch names with a slash '/', which would be read as a
    # path separator on both Linux and Windows. Backslash is not allowed in
    # Github branch names, so it needs no handling.
    normalized_br = branch.replace('/', '_')
    # Github renames folder repo/v1.x.x to repo-1.x.x, so the extracted
    # folder name is only known after inspecting the zip file. A normalized
    # folder name is used to check whether a cached repo exists.
    repo_dir = os.path.join(hub_dir,
                            '_'.join([repo_owner, repo_name, normalized_br]))

    if not force_reload and os.path.exists(repo_dir):
        if verbose:
            sys.stderr.write('Using cache found in {}\n'.format(repo_dir))
        return repo_dir

    cached_file = os.path.join(hub_dir, normalized_br + '.zip')
    _remove_if_exists(cached_file)

    url = _git_archive_link(repo_owner, repo_name, branch, source=source)

    # NOTE(review): assumes get_path_from_url returns the path of the file it
    # saved -- confirm against paddle.utils.download. The saved name is
    # derived from the URL, so for a branch containing '/' it will not match
    # ``cached_file``; move the download to the expected location.
    downloaded = get_path_from_url(url, hub_dir, decompress=False)
    if downloaded and os.path.abspath(downloaded) != os.path.abspath(
            cached_file):
        shutil.move(downloaded, cached_file)

    try:
        with zipfile.ZipFile(cached_file) as cached_zipfile:
            entries = cached_zipfile.namelist()
            if not entries:
                raise RuntimeError(
                    'Downloaded archive {} is empty'.format(url))
            # Zip member names always use '/'; the first path component is
            # the archive's top-level folder whether the first entry is that
            # folder itself or a file inside it.
            extracted_repo_name = entries[0].split('/')[0]
            extracted_repo = os.path.join(hub_dir, extracted_repo_name)
            _remove_if_exists(extracted_repo)
            # Unzip the code; the base folder is renamed below.
            cached_zipfile.extractall(hub_dir)
    finally:
        # Never leave the archive behind, even when extraction fails.
        _remove_if_exists(cached_file)

    _remove_if_exists(repo_dir)
    # Rename the extracted folder to the normalized cache name.
    shutil.move(extracted_repo, repo_dir)

    return repo_dir
|
||
|
||
def _load_entry_from_hubconf(m, name): | ||
'''load entry from hubconf | ||
''' | ||
if not isinstance(name, str): | ||
raise ValueError( | ||
'Invalid input: model should be a str of function name') | ||
|
||
func = getattr(m, name, None) | ||
|
||
if func is None or not callable(func): | ||
raise RuntimeError('Canot find callable {} in hubconf'.format(name)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. copy~ |
||
|
||
return func | ||
|
||
|
||
def _check_module_exists(name): | ||
try: | ||
__import__(name) | ||
return True | ||
except ImportError: | ||
return False | ||
|
||
|
||
def _check_dependencies(m):
    """Verify that every package a hubconf module declares is importable.

    The declaration is read from the attribute named by ``VAR_DEPENDENCY``.
    A module without that attribute (or with it set to None) is accepted
    without any check.

    Args:
        m (module): the imported hubconf module.

    Raises:
        RuntimeError: listing every declared package that cannot be
            imported.
    """
    declared = getattr(m, VAR_DEPENDENCY, None)
    if declared is None:
        return

    unavailable = []
    for pkg in declared:
        if not _check_module_exists(pkg):
            unavailable.append(pkg)

    if unavailable:
        raise RuntimeError('Missing dependencies: {}'.format(', '.join(
            unavailable)))
|
||
|
||
def list(repo_dir, source='github', force_reload=False):
    r"""
    List all entrypoints available in the `hubconf.py` of a repo.

    Args:
        repo_dir(str): github/gitee repo or local path
            github or gitee path (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional
                tag/branch. The default branch is `main` if not specified.
            local path (str): local repo path
        source (str): `github` | `gitee` | `local`, default is `github`
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download, default is `False`.
    Returns:
        entrypoints: a list of available entrypoint names

    Raises:
        ValueError: if `source` is not one of the allowed values.

    Example:
        ```python
        import paddle

        paddle.hub.list('lyuwenyu/paddlehub_demo:main', source='github', force_reload=False)

        ```
    """
    # NOTE: this public name shadows the builtin ``list`` inside this module.
    if source not in ('github', 'gitee', 'local'):
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'.
            format(source))

    if source in ('github', 'gitee'):
        repo_dir = _get_cache_or_reload(
            repo_dir, force_reload, True, source=source)

    module_name = MODULE_HUBCONF.split('.')[0]
    sys.path.insert(0, repo_dir)
    try:
        # Every repo names its entry module the same, so a module cached in
        # sys.modules by an earlier call would be returned again here even
        # though repo_dir differs. Drop it to force a fresh import.
        sys.modules.pop(module_name, None)
        hub_module = __import__(module_name)
    finally:
        # Restore sys.path even when the import fails.
        sys.path.remove(repo_dir)

    # Public callables only: names starting with '_' are private helpers.
    return [
        f for f in dir(hub_module)
        if callable(getattr(hub_module, f)) and not f.startswith('_')
    ]
|
||
|
||
def help(repo_dir, model, source='github', force_reload=False):
    """
    Show help information of model

    Args:
        repo_dir(str): github/gitee repo or local path
            github or gitee path (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional
                tag/branch. The default branch is `main` if not specified.
            local path (str): local repo path
        model (str): model name, i.e. the name of an entrypoint in `hubconf.py`
        source (str): `github` | `gitee` | `local`, default is `github`
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download, default is `False`.
    Return:
        docs: the docstring of the entrypoint

    Raises:
        ValueError: if `source` is not one of the allowed values.

    Example:
        ```python
        import paddle

        paddle.hub.help('lyuwenyu/paddlehub_demo:main', model='MM', source='github')
        ```
    """
    # NOTE: this public name shadows the builtin ``help`` inside this module.
    if source not in ('github', 'gitee', 'local'):
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'.
            format(source))

    if source in ('github', 'gitee'):
        repo_dir = _get_cache_or_reload(
            repo_dir, force_reload, True, source=source)

    module_name = MODULE_HUBCONF.split('.')[0]
    sys.path.insert(0, repo_dir)
    try:
        # Every repo names its entry module the same, so a module cached in
        # sys.modules by an earlier call would be returned again here even
        # though repo_dir differs. Drop it to force a fresh import.
        sys.modules.pop(module_name, None)
        hub_module = __import__(module_name)
    finally:
        # Restore sys.path even when the import fails.
        sys.path.remove(repo_dir)

    entry = _load_entry_from_hubconf(hub_module, model)

    return entry.__doc__
|
||
|
||
def load(repo_dir, model, source='github', force_reload=False, **kwargs):
    """
    Load model

    Args:
        repo_dir(str): github/gitee repo or local path
            github or gitee path (str): a str with format "repo_owner/repo_name[:tag_name]" with an optional
                tag/branch. The default branch is `main` if not specified.
            local path (str): local repo path
        model (str): model name, i.e. the name of an entrypoint in `hubconf.py`
        source (str): `github` | `gitee` | `local`, default is `github`
        force_reload (bool, optional): whether to discard the existing cache and force a fresh download, default is `False`.
        **kwargs: parameters passed to the entrypoint
    Return:
        whatever the entrypoint returns (a paddle model)

    Raises:
        ValueError: if `source` is not one of the allowed values.

    Example:
        ```python
        import paddle
        paddle.hub.load('lyuwenyu/paddlehub_demo:main', model='MM', source='github')
        ```
    """
    if source not in ('github', 'gitee', 'local'):
        raise ValueError(
            'Unknown source: "{}". Allowed values: "github" | "gitee" | "local".'.
            format(source))

    if source in ('github', 'gitee'):
        repo_dir = _get_cache_or_reload(
            repo_dir, force_reload, True, source=source)

    module_name = MODULE_HUBCONF.split('.')[0]
    sys.path.insert(0, repo_dir)
    try:
        # Every repo names its entry module the same, so a module cached in
        # sys.modules by an earlier call would be returned again here even
        # though repo_dir differs. Drop it to force a fresh import.
        sys.modules.pop(module_name, None)
        hub_module = __import__(module_name)
    finally:
        # Restore sys.path even when the import fails.
        sys.path.remove(repo_dir)

    # Fail early, before building the model, if a declared package is missing.
    _check_dependencies(hub_module)

    entry = _load_entry_from_hubconf(hub_module, model)

    return entry(**kwargs)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Packages this hubconf needs. The hub loader reads this attribute and raises
# if any listed package cannot be imported.
dependencies = ['paddle']
|
||
import paddle | ||
from test_hapi_hub_model import MM as _MM | ||
|
||
|
||
def MM(out_channels=8, pretrained=False):
    '''This is a test demo for paddle hub
    '''
    # The docstring above is returned verbatim by the hub `help` function
    # (it returns the entrypoint's __doc__), so it is left untouched.
    # `pretrained` is accepted but not used: only `out_channels` is
    # forwarded to the underlying test model.
    return _MM(out_channels)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that multi-process calls may cause the problem of repeated directory creation and raise an exception.
suggest:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
copy~