Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into bugfix/truncate-debug-output
Browse files Browse the repository at this point in the history
DhanshreeA authored Jan 27, 2025

Verified

This commit was signed with the committer’s verified signature. The key has expired.
addaleax Anna Henningsen
2 parents dd30990 + d429f07 commit 6aa8aee
Showing 6 changed files with 1,497 additions and 825 deletions.
2 changes: 1 addition & 1 deletion ersilia/auth/auth.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
try:
from github import Github
from github.GithubException import UnknownObjectException
except ModuleNotFoundError:
except (ModuleNotFoundError, ImportError):
Github = None
UnknownObjectException = None

88 changes: 61 additions & 27 deletions ersilia/cli/commands/test.py
Original file line number Diff line number Diff line change
@@ -20,11 +20,14 @@ def test_cmd():
--------
.. code-block:: console
With default settings:
$ ersilia test my_model -d /path/to/model
With basic testing:/
$ ersilia test eosxxxx --from_dir /path/to/model
With deep testing level and inspect:
$ ersilia test my_model -d /path/to/model --level deep --inspect --remote
With different sources to fetch the model:
$ ersilia test eosxxxx --from_github/--from_dockerhub/--from_s3
With different levels of testing:
$ ersilia test eosxxxx --shallow --from_github/--from_dockerhub/--from_s3
"""

@ersilia_cli.command(
@@ -38,48 +41,79 @@ def test_cmd():
"-l",
"--level",
"level",
help="Level of testing, None: for default, deep: for deep testing",
help="Level of testing, None: for default, deep: for deep testing, shallow: for shallow testing",
required=False,
default=None,
type=click.STRING,
)
@click.option(
"-d",
"--dir",
"dir",
help="Model directory",
required=False,
"--from_dir",
default=None,
type=click.STRING,
help="Local path where the model is stored",
)
@click.option(
"--from_github",
is_flag=True,
default=False,
help="Fetch fetch directly from GitHub",
)
@click.option(
"--from_dockerhub",
is_flag=True,
default=False,
help="Force fetch from DockerHub",
)
@click.option(
"--from_s3", is_flag=True, default=False, help="Force fetch from AWS S3 bucket"
)
@click.option(
"--version",
default=None,
type=click.STRING,
help="Version of the model to fetch, when fetching a model from DockerHub",
)
@click.option(
"--inspect",
help="Inspect the model: More on the docs",
"--shallow",
is_flag=True,
default=False,
help="This flag is used to check shallow checks (such as container size, output consistency..)",
)
@click.option(
"--remote",
help="Test the model from remote git repository",
"--deep",
is_flag=True,
default=False,
help="This flag is used to check deep checks (such as computational performance checks)",
)
@click.option(
"--remove",
help="Remove the model directory after testing",
"--as_json",
is_flag=True,
default=False,
help="This flag is used to save the report as json file)",
)
def test(model, level, dir, inspect, remote, remove):
def test(
model,
level,
from_dir,
from_github,
from_dockerhub,
from_s3,
version,
shallow,
deep,
as_json,
):
mt = ModelTester(
model_id=model,
level=level,
dir=dir,
inspect=inspect,
remote=remote,
remove=remove,
model,
level,
from_dir,
from_github,
from_dockerhub,
from_s3,
version,
shallow,
deep,
as_json,
)
echo("Setting up model tester...")
mt.setup()
echo("Testing model...")
mt.run(output_file=None)
echo(f"Model testing started for: {model}")
mt.run()
1 change: 1 addition & 0 deletions ersilia/default.py
Original file line number Diff line number Diff line change
@@ -7,6 +7,7 @@

# EOS environmental variables
EOS = os.path.join(str(Path.home()), "eos")
EOS_TMP = os.path.join(EOS, "temp")
if not os.path.exists(EOS):
os.makedirs(EOS)
ROOT = os.path.dirname(os.path.realpath(__file__))
90 changes: 70 additions & 20 deletions ersilia/publish/inspect.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import subprocess
import time
from collections import namedtuple
@@ -387,23 +388,42 @@ def validate_repo_structure(self):

def _validate_dockerfile(self, dockerfile_content):
lines, errors = dockerfile_content.splitlines(), []
for line in lines:
if line.startswith("RUN pip install"):
cmd = line.split("RUN ")[-1]
result = subprocess.run(
cmd,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
)
if result.returncode != 0:
errors.append(f"Failed to run {cmd}: {result.stderr.strip()}")

if "WORKDIR /repo" not in dockerfile_content:
errors.append("Missing 'WORKDIR /repo'.")
if "COPY . /repo" not in dockerfile_content:
errors.append("Missing 'COPY . /repo'.")

pip_install_pattern = re.compile(r"pip install (.+)")
version_pin_pattern = re.compile(r"^[a-zA-Z0-9_\-\.]+==[a-zA-Z0-9_\-\.]+$")

for line in lines:
line = line.strip()

match = pip_install_pattern.search(line)
if match:
packages_and_options = match.group(1).split()
skip_next = False

for item in packages_and_options:
if skip_next:
skip_next = False
continue

if item.startswith("--index-url") or item.startswith(
"--extra-index-url"
):
skip_next = True
continue

if item.startswith("git+"):
continue

if not version_pin_pattern.match(item):
errors.append(
f"Package '{item}' in line '{line}' is not version-pinned (e.g., 'package==1.0.0')."
)

return errors

def _validate_yml(self, yml_content):
@@ -417,18 +437,48 @@ def _validate_yml(self, yml_content):
if not python_version:
errors.append("Missing Python version in install.yml.")

version_pin_pattern = re.compile(r"^[a-zA-Z0-9_\-\.]+==[a-zA-Z0-9_\-\.]+$")

commands = yml_data.get("commands", [])
for command in commands:
if not isinstance(command, list) or command[0] != "pip":
if not isinstance(command, list) or len(command) < 2:
errors.append(f"Invalid command format: {command}")
continue
# package: name & version
name = command[1] if len(command) > 1 else None

tool = command[0]
_ = command[1]
version = command[2] if len(command) > 2 else None
if not name:
errors.append(f"Missing package name in command: {command}")
if name and version:
pass

if tool in ("pip", "conda"):
if tool == "pip":
pip_args = command[1:]
skip_next = False

for item in pip_args:
if skip_next:
skip_next = False
continue

if item.startswith("--index-url") or item.startswith(
"--extra-index-url"
):
skip_next = True
continue

if item.startswith("git+"):
continue

if not version_pin_pattern.match(item):
errors.append(
f"Package '{item}' in command '{command}' is not version-pinned (e.g., 'package==1.0.0')."
)

elif tool == "conda" and not version:
errors.append(
f"Package in command '{command}' does not have a valid pinned version "
f"(should be in the format ['conda', 'package_name', 'x.y.z'])."
)

return errors

def _run_performance_check(self, n):
@@ -445,5 +495,5 @@ def _run_performance_check(self, n):
return Result(False, f"Error serving model: {process.stderr.strip()}")
execution_time = time.time() - start_time
return Result(
True, f"{n} predictions executed in {execution_time:.2f} seconds."
True, f"{n} predictions executed in {execution_time:.2f} seconds. \n"
)
Loading

0 comments on commit 6aa8aee

Please sign in to comment.