Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update package command and clean up root_path #153

Merged
merged 9 commits into from
Nov 23, 2021
4 changes: 4 additions & 0 deletions bioimageio/core/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def package(
path: Path = typer.Argument(Path() / "{src_name}-package.zip", help="Save package as"),
weights_priority_order: Optional[List[str]] = typer.Option(
None,
"--weights-priority-order",
"-wpo",
help="For model packages only. "
"If given only the first weights matching the given weight formats are included. "
Expand All @@ -49,6 +50,9 @@ def package(
),
verbose: bool = typer.Option(False, help="show traceback of exceptions"),
) -> int:
# typer bug: typer returns an empty tuple instead of None if weights_priority_order is not given
weights_priority_order = weights_priority_order or None

return commands.package(
rdf_source=rdf_source, path=path, weights_priority_order=weights_priority_order, verbose=verbose
)
Expand Down
6 changes: 3 additions & 3 deletions bioimageio/core/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ def package(
verbose: bool = False,
) -> int:
"""Package a BioImage.IO resource described by a BioImage.IO Resource Description File (RDF)."""
code = validate(rdf_source, update_format=True, update_format_inner=True, verbose=verbose)
code = validate(rdf_source, update_format=True, update_format_inner=True)
source_name = rdf_source.get("name") if isinstance(rdf_source, dict) else rdf_source
if code:
if code["error"]:
print(f"Cannot package invalid BioImage.IO RDF {source_name}")
return code
return 1

try:
tmp_package_path = export_resource_package(rdf_source, weights_priority_order=weights_priority_order)
Expand Down
12 changes: 4 additions & 8 deletions bioimageio/core/resource_io/io_.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,12 @@ def extract_resource_package(

package_path = cache_folder / sha256(str(root).encode("utf-8")).hexdigest()
if isinstance(root, raw_nodes.URI):
from urllib.request import urlretrieve

for rdf_name in RDF_NAMES:
if (package_path / rdf_name).exists():
download = None
break
else:
try:
download, header = urlretrieve(str(root))
except Exception as e:
raise RuntimeError(f"Failed to download {str(root)} ({e})")
download = resolve_uri(root)

local_source = download
else:
Expand Down Expand Up @@ -94,7 +89,8 @@ def _replace_relative_paths_for_remote_source(
else:
raise TypeError(root)

raw_rd.root_path = root_path
assert isinstance(root_path, pathlib.Path)
raw_rd.root_path = root_path.resolve()
return raw_rd


Expand Down Expand Up @@ -145,7 +141,7 @@ def load_resource_description(
raw_rd.weights = {wf: raw_rd.weights[wf]}
break
else:
raise ValueError(f"Not found any of the specified weights formats ({weights_priority_order})")
raise ValueError(f"Not found any of the specified weights formats {weights_priority_order}")

rd: ResourceDescription = resolve_raw_resource_description(raw_rd=raw_rd, nodes_module=nodes)
assert isinstance(rd, getattr(nodes, get_class_name_from_type(raw_rd.type)))
Expand Down
30 changes: 23 additions & 7 deletions bioimageio/core/resource_io/utils.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import dataclasses
import importlib.util
import logging
import os
import pathlib
import sys
import typing
import warnings
from functools import singledispatch
from types import ModuleType
from urllib.request import url2pathname, urlretrieve
from urllib.request import url2pathname

import requests
from marshmallow import ValidationError
from tqdm import tqdm

from bioimageio.spec.shared import fields, raw_nodes
from bioimageio.spec.shared.common import BIOIMAGEIO_CACHE_PATH
Expand All @@ -31,7 +31,7 @@ class UriNodeChecker(NodeVisitor):
"""raises FileNotFoundError for unavailable URIs and paths"""

def __init__(self, *, root_path: os.PathLike):
self.root_path = pathlib.Path(root_path)
self.root_path = pathlib.Path(root_path).resolve()

def visit_URI(self, node: raw_nodes.URI):
if not uri_available(node, self.root_path):
Expand All @@ -50,7 +50,7 @@ def visit_WindowsPath(self, leaf: pathlib.WindowsPath):

class UriNodeTransformer(NodeTransformer):
def __init__(self, *, root_path: os.PathLike):
self.root_path = pathlib.Path(root_path)
self.root_path = pathlib.Path(root_path).resolve()

def transform_URI(self, node: raw_nodes.URI) -> pathlib.Path:
local_path = resolve_uri(node, root_path=self.root_path)
Expand Down Expand Up @@ -147,7 +147,7 @@ def _resolve_uri_uri_node(uri: raw_nodes.URI, root_path: os.PathLike = pathlib.P
assert isinstance(uri, (raw_nodes.URI, nodes.URI))
path_or_remote_uri = resolve_local_uri(uri, root_path)
if isinstance(path_or_remote_uri, raw_nodes.URI):
local_path = _download_uri_to_local_path(path_or_remote_uri)
local_path = _download_url_to_local_path(path_or_remote_uri)
elif isinstance(path_or_remote_uri, pathlib.Path):
local_path = path_or_remote_uri
else:
Expand Down Expand Up @@ -268,14 +268,30 @@ def download_uri_to_local_path(uri: typing.Union[raw_nodes.URI, str]) -> pathlib
return resolve_uri(uri)


def _download_uri_to_local_path(uri: raw_nodes.URI) -> pathlib.Path:
def _download_url_to_local_path(uri: raw_nodes.URI) -> pathlib.Path:
local_path = BIOIMAGEIO_CACHE_PATH / uri.scheme / uri.authority / uri.path.strip("/") / uri.query
if local_path.exists():
warnings.warn(f"found cached {local_path}. Skipping download of {uri}.")
else:
local_path.parent.mkdir(parents=True, exist_ok=True)

try:
urlretrieve(str(uri), str(local_path))
# download with tqdm adapted from:
# https://github.com/shaypal5/tqdl/blob/189f7fd07f265d29af796bee28e0893e1396d237/tqdl/core.py
# Streaming, so we can iterate over the response.
r = requests.get(str(uri), stream=True)
# Total size in bytes.
total_size = int(r.headers.get("content-length", 0))
block_size = 1024 # 1 Kibibyte
t = tqdm(total=total_size, unit="iB", unit_scale=True, desc=local_path.name)
with local_path.open("wb") as f:
for data in r.iter_content(block_size):
t.update(len(data))
f.write(data)
t.close()
if total_size != 0 and t.n != total_size:
# todo: check more carefully and raise on real issue
warnings.warn("Download does not have expected size.")
FynnBe marked this conversation as resolved.
Show resolved Hide resolved
except Exception as e:
raise RuntimeError(f"Failed to download {uri} ({e})")

Expand Down
5 changes: 5 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@ def test_validate_model(unet2d_nuclei_broad_model):
assert ret.returncode == 0


def test_cli_package(unet2d_nuclei_broad_model):
    """Run the `bioimageio package` CLI on the fixture and expect exit code 0."""
    command = ["bioimageio", "package", unet2d_nuclei_broad_model]
    completed = subprocess.run(command)
    assert completed.returncode == 0


def test_cli_test_model(unet2d_nuclei_broad_model):
ret = subprocess.run(["bioimageio", "test-model", unet2d_nuclei_broad_model])
assert ret.returncode == 0
Expand Down