Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Cuda to ImageSpec #1688

Merged
merged 11 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions flytekit/image_spec/image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,17 @@
import typing
from abc import abstractmethod
from copy import copy
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from functools import lru_cache
from typing import List, Optional

import click
import requests
from dataclasses_json import dataclass_json

DOCKER_HUB = "docker.io"
_F_IMG_ID = "_F_IMG_ID"


@dataclass_json
@dataclass
class ImageSpec:
"""
Expand All @@ -31,6 +29,8 @@ class ImageSpec:
registry: registry of the image.
packages: list of python packages to install.
apt_packages: list of apt packages to install.
cuda: version of cuda to install.
cudnn: version of cudnn to install.
base_image: base image of the image.
platform: Specify the target platforms for the build output (for example, windows/amd64 or linux/amd64,darwin/arm64
pip_index: Specify the custom pip index url
Expand All @@ -44,6 +44,8 @@ class ImageSpec:
registry: Optional[str] = None
packages: Optional[List[str]] = None
apt_packages: Optional[List[str]] = None
cuda: Optional[str] = None
cudnn: Optional[str] = None
base_image: Optional[str] = None
platform: str = "linux/amd64"
pip_index: Optional[str] = None
Expand Down Expand Up @@ -106,7 +108,7 @@ def exist(self) -> bool:
return True

def __hash__(self):
return hash(self.to_json())
return hash(asdict(self).__str__())
eapolinario marked this conversation as resolved.
Show resolved Hide resolved


class ImageSpecBuilder:
Expand Down Expand Up @@ -151,7 +153,7 @@ def calculate_hash_from_image_spec(image_spec: ImageSpec):
# copy the image spec to avoid modifying the original image spec. otherwise, the hash will be different.
spec = copy(image_spec)
spec.source_root = hash_directory(image_spec.source_root) if image_spec.source_root else b""
image_spec_bytes = bytes(spec.to_json(), "utf-8")
image_spec_bytes = asdict(spec).__str__().encode("utf-8")
tag = base64.urlsafe_b64encode(hashlib.md5(image_spec_bytes).digest()).decode("ascii")
# replace "=" with "." and replace "-" with "_" to make it a valid tag
return tag.replace("=", ".").replace("-", "_")
Expand Down
14 changes: 11 additions & 3 deletions plugins/flytekit-envd/flytekitplugins/envd/image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def build_image(self, image_spec: ImageSpec):

def create_envd_config(image_spec: ImageSpec) -> str:
base_image = DefaultImages.default_image() if image_spec.base_image is None else image_spec.base_image
if image_spec.cuda:
if image_spec.python_version is None:
raise Exception("python_version is required when cuda and cudnn are specified")
base_image = "ubuntu20.04"

packages = [] if image_spec.packages is None else image_spec.packages
apt_packages = [] if image_spec.apt_packages is None else image_spec.apt_packages
env = {"PYTHONPATH": "/root", _F_IMG_ID: image_spec.image_name()}
Expand All @@ -52,14 +57,17 @@ def build():
runtime.environ(env={env})
config.pip_index(url = "{pip_index}")
"""
ctx = context_manager.FlyteContextManager.current_context()
cfg_path = ctx.file_access.get_random_local_path("build.envd")
pathlib.Path(cfg_path).parent.mkdir(parents=True, exist_ok=True)

if image_spec.python_version:
# Indentation is required by envd
envd_config += f' install.python(version="{image_spec.python_version}")\n'

ctx = context_manager.FlyteContextManager.current_context()
cfg_path = ctx.file_access.get_random_local_path("build.envd")
pathlib.Path(cfg_path).parent.mkdir(parents=True, exist_ok=True)
if image_spec.cuda:
cudnn = image_spec.cudnn if image_spec.cudnn else ""
envd_config += f' install.cuda(version="{image_spec.cuda}", cudnn="{cudnn}")\n'

if image_spec.source_root:
shutil.copytree(image_spec.source_root, pathlib.Path(cfg_path).parent, dirs_exist_ok=True)
Expand Down
6 changes: 3 additions & 3 deletions plugins/flytekit-envd/tests/test_image_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,24 @@ def test_image_spec():
packages=["pandas"],
apt_packages=["git"],
python_version="3.8",
registry="",
base_image="cr.flyte.org/flyteorg/flytekit:py3.8-latest",
pip_index="https://private-pip-index/simple",
)

EnvdImageSpecBuilder().build_image(image_spec)
config_path = create_envd_config(image_spec)
assert image_spec.platform == "linux/amd64"
image_name = image_spec.image_name()
contents = Path(config_path).read_text()
assert (
contents
== """# syntax=v1
== f"""# syntax=v1

def build():
base(image="cr.flyte.org/flyteorg/flytekit:py3.8-latest", dev=False)
install.python_packages(name = ["pandas"])
install.apt_packages(name = ["git"])
runtime.environ(env={'PYTHONPATH': '/root', '_F_IMG_ID': 'flytekit:46qVNvYHJxppEvVIYrthdA..'})
runtime.environ(env={{'PYTHONPATH': '/root', '_F_IMG_ID': '{image_name}'}})
config.pip_index(url = "https://private-pip-index/simple")
install.python(version="3.8")
"""
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,9 +269,9 @@ def test_list_default_arguments(wf_path):
)

ic_result_4 = ImageConfig(
default_image=Image(name="default", fqn="flytekit", tag="6y6c8ofS_Pwa2FImlcm3Qg.."),
default_image=Image(name="default", fqn="flytekit", tag="eJgTB5QCJDOSksy6gE0lXA.."),
images=[
Image(name="default", fqn="flytekit", tag="6y6c8ofS_Pwa2FImlcm3Qg.."),
Image(name="default", fqn="flytekit", tag="eJgTB5QCJDOSksy6gE0lXA.."),
Image(name="xyz", fqn="docker.io/xyz", tag="latest"),
Image(name="abc", fqn="docker.io/abc", tag=None),
],
Expand Down
17 changes: 15 additions & 2 deletions tests/flytekit/unit/core/image_spec/test_image_spec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import base64
import hashlib
import os
from dataclasses import asdict

import pytest

Expand All @@ -15,20 +18,29 @@ def test_image_spec():
python_version="3.8",
registry="",
base_image="cr.flyte.org/flyteorg/flytekit:py3.8-latest",
cuda="11.2.2",
cudnn="8",
)

assert image_spec.python_version == "3.8"
assert image_spec.base_image == "cr.flyte.org/flyteorg/flytekit:py3.8-latest"
assert image_spec.packages == ["pandas"]
assert image_spec.apt_packages == ["git"]
assert image_spec.registry == ""
assert image_spec.cuda == "11.2.2"
assert image_spec.cudnn == "8"
assert image_spec.name == "flytekit"
assert image_spec.builder == "envd"
assert image_spec.source_root is None
assert image_spec.env is None
assert image_spec.pip_index is None
assert image_spec.is_container() is True
assert image_spec.image_name() == "flytekit:_7s_KKi_73h88RBfRZ8jpQ.."

image_spec.source_root = b""
image_spec_bytes = asdict(image_spec).__str__().encode("utf-8")
tag = base64.urlsafe_b64encode(hashlib.md5(image_spec_bytes).digest()).decode("ascii")
tag = tag.replace("=", ".")
assert image_spec.image_name() == f"flytekit:{tag}"
ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
Expand All @@ -42,8 +54,9 @@ def build_image(self, img):

ImageBuildEngine.register("dummy", DummyImageSpecBuilder())
ImageBuildEngine._REGISTRY["dummy"].build_image(image_spec)

assert "dummy" in ImageBuildEngine._REGISTRY
assert calculate_hash_from_image_spec(image_spec) == "_7s_KKi_73h88RBfRZ8jpQ.."
assert calculate_hash_from_image_spec(image_spec) == tag
assert image_spec.exist() is False

with pytest.raises(Exception):
Expand Down
6 changes: 2 additions & 4 deletions tests/flytekit/unit/core/test_python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,8 @@ def build_image(self, img):
...

ImageBuildEngine.register("test", TestImageSpecBuilder())
assert (
get_registerable_container_image(ImageSpec(builder="test", python_version="3.7"), cfg)
== "flytekit:htjuk6SUglpN7CGTPPtQIA.."
)
image_spec = ImageSpec(builder="test", python_version="3.7", registry="")
assert get_registerable_container_image(image_spec, cfg) == image_spec.image_name()


def test_get_registerable_container_image_no_images():
Expand Down