diff --git a/flytekit/image_spec/default_builder.py b/flytekit/image_spec/default_builder.py index a09f445ba9..b2d7c6753c 100644 --- a/flytekit/image_spec/default_builder.py +++ b/flytekit/image_spec/default_builder.py @@ -1,5 +1,6 @@ import json import os +import re import shutil import subprocess import sys @@ -98,11 +99,25 @@ def get_flytekit_for_pypi(): return f"flytekit=={__version__}" +_PACKAGE_NAME_RE = re.compile(r"^[\w-]+") + + +def _is_flytekit(package: str) -> bool: + """Return True if `package` is flytekit. `package` is expected to be a valid version + spec. i.e. `flytekit==1.12.3`, `flytekit`, `flytekit~=1.12.3`. + """ + m = _PACKAGE_NAME_RE.match(package) + if not m: + return False + name = m.group() + return name == "flytekit" + + def create_docker_context(image_spec: ImageSpec, tmp_dir: Path): """Populate tmp_dir with Dockerfile as specified by the `image_spec`.""" base_image = image_spec.base_image or "debian:bookworm-slim" - requirements = [get_flytekit_for_pypi()] + requirements = [] if image_spec.cuda is not None or image_spec.cudnn is not None: msg = ( @@ -124,6 +139,10 @@ def create_docker_context(image_spec: ImageSpec, tmp_dir: Path): if image_spec.packages: requirements.extend(image_spec.packages) + # Adds flytekit if it is not specified + if not any(_is_flytekit(package) for package in requirements): + requirements.append(get_flytekit_for_pypi()) + uv_requirements = [] # uv does not support git + subdirectory, so we use pip to install them instead diff --git a/tests/flytekit/unit/core/image_spec/test_default_builder.py b/tests/flytekit/unit/core/image_spec/test_default_builder.py index 8941da81b9..e541abe9d6 100644 --- a/tests/flytekit/unit/core/image_spec/test_default_builder.py +++ b/tests/flytekit/unit/core/image_spec/test_default_builder.py @@ -2,6 +2,7 @@ import pytest +import flytekit from flytekit.image_spec import ImageSpec from flytekit.image_spec.default_builder import DefaultImageBuilder, create_docker_context @@ -99,6 +100,41 @@ def test_create_docker_context_with_null_entrypoint(tmp_path): assert "ENTRYPOINT []" in dockerfile_content +@pytest.mark.parametrize("flytekit_spec", [None, "flytekit>=1.12.3", "flytekit==1.12.3"]) +def test_create_docker_context_with_flytekit(tmp_path, flytekit_spec, monkeypatch): + + # pretend version is 1.13.0 + mock_version = "1.13.0" + monkeypatch.setattr(flytekit, "__version__", mock_version) + + docker_context_path = tmp_path / "builder_root" + docker_context_path.mkdir() + + if flytekit_spec: + packages = [flytekit_spec] + else: + packages = [] + + image_spec = ImageSpec( + name="FLYTEKIT", packages=packages + ) + + create_docker_context(image_spec, docker_context_path) + + dockerfile_path = docker_context_path / "Dockerfile" + assert dockerfile_path.exists() + + requirements_path = docker_context_path / "requirements_uv.txt" + assert requirements_path.exists() + + requirements_content = requirements_path.read_text() + if flytekit_spec: + flytekit_spec in requirements_content + assert f"flytekit=={mock_version}" not in requirements_content + else: + assert f"flytekit=={mock_version}" in requirements_content + + def test_create_docker_context_cuda(tmp_path): docker_context_path = tmp_path / "builder_root" docker_context_path.mkdir()