Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Raw Container Task Local Execution #2258

Merged
merged 32 commits into from
Apr 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
1bb74f6
init
Future-Outlier Mar 12, 2024
b860037
v1
Future-Outlier Mar 13, 2024
3864b68
argurments bug fixed and add log when pulling image
Future-Outlier Mar 15, 2024
f982402
change v to k and handle boolean special case
Future-Outlier Mar 15, 2024
ab05253
support blob type and datetime
Future-Outlier Mar 18, 2024
5817838
add unit tests
Future-Outlier Mar 18, 2024
6c2f0b1
add exception
Future-Outlier Mar 18, 2024
c9d1899
nit
Future-Outlier Mar 18, 2024
09ce020
fix test
Future-Outlier Mar 18, 2024
d6482d6
update for flytefile and flytedirectory
Future-Outlier Mar 20, 2024
19bb116
support both file paths and template inputs
Future-Outlier Mar 27, 2024
c5155a5
pytest use sys platform to handle macos and windows case and support …
Future-Outlier Mar 28, 2024
99f8e07
support datetime.timedelta
Future-Outlier Mar 28, 2024
c16fdbd
lint
Future-Outlier Mar 28, 2024
350f934
add tests and change boolean logic
Future-Outlier Mar 28, 2024
963c0c2
support
Future-Outlier Mar 28, 2024
21e3057
change annotations
Future-Outlier Mar 31, 2024
32546b4
add flytefile and flytedir tests
Future-Outlier Apr 9, 2024
8f62c18
lint
Future-Outlier Apr 9, 2024
87f8d38
add more tests
Future-Outlier Apr 9, 2024
902e67a
lint
Future-Outlier Apr 9, 2024
20fb365
change image name
Future-Outlier Apr 9, 2024
9e77cfb
Update pingsu's advice
Future-Outlier Apr 10, 2024
0d2739d
add docker in dev-requirement
Future-Outlier Apr 10, 2024
a9b0e1e
refactor execution
Future-Outlier Apr 10, 2024
78b5be1
use render pattern
Future-Outlier Apr 11, 2024
732bfe1
add back container task object in test
Future-Outlier Apr 11, 2024
cef801d
refactor output in container task execution
Future-Outlier Apr 11, 2024
bbb59e8
update pingsu's render input advice
Future-Outlier Apr 11, 2024
79f73c6
update tests
Future-Outlier Apr 11, 2024
37f8263
add LiteralMap TypeHints
Future-Outlier Apr 12, 2024
b068357
update dev-req
Future-Outlier Apr 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 162 additions & 3 deletions flytekit/core/container_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import typing
from enum import Enum
from typing import Any, Dict, List, Optional, OrderedDict, Type
from typing import Any, Dict, List, Optional, OrderedDict, Tuple, Type

from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask, TaskMetadata
Expand All @@ -11,10 +12,13 @@
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec
from flytekit.image_spec.image_spec import ImageSpec
from flytekit.loggers import logger
from flytekit.models import task as _task_model
from flytekit.models.literals import LiteralMap
from flytekit.models.security import Secret, SecurityContext

_PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
DOCKER_IMPORT_ERROR_MESSAGE = "Docker is not installed. Please install Docker by running `pip install docker`."


class ContainerTask(PythonTask):
Expand Down Expand Up @@ -82,6 +86,7 @@ def __init__(
self._args = arguments
self._input_data_dir = input_data_dir
self._output_data_dir = output_data_dir
self._outputs = outputs
self._md_format = metadata_format
self._io_strategy = io_strategy
self._resources = ResourceSpec(
Expand All @@ -93,8 +98,162 @@ def __init__(
def resources(self) -> ResourceSpec:
return self._resources

def local_execute(self, ctx: FlyteContext, **kwargs) -> Any:
raise RuntimeError("ContainerTask is not supported in local executions.")
def _extract_command_key(self, cmd: str, **kwargs) -> Any:
"""
Extract the key from the command using regex.
"""
import re

input_regex = r"^\{\{\s*\.inputs\.(.*?)\s*\}\}$"
match = re.match(input_regex, cmd)
if match:
return match.group(1)
return None

def _render_command_and_volume_binding(self, cmd: str, **kwargs) -> Tuple[str, Dict[str, Dict[str, str]]]:
"""
We support template-style references to inputs, e.g., "{{.inputs.infile}}".
"""
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile

command = ""
volume_binding = {}
k = self._extract_command_key(cmd)

if k:
input_val = kwargs.get(k)
if type(input_val) in [FlyteFile, FlyteDirectory]:
local_flyte_file_or_dir_path = str(input_val)
remote_flyte_file_or_dir_path = os.path.join(self._input_data_dir, k.replace(".", "/")) # type: ignore
volume_binding[local_flyte_file_or_dir_path] = {
"bind": remote_flyte_file_or_dir_path,
"mode": "rw",
}
command = remote_flyte_file_or_dir_path
else:
command = str(input_val)
else:
command = cmd

return command, volume_binding

def _prepare_command_and_volumes(
self, cmd_and_args: List[str], **kwargs
) -> Tuple[List[str], Dict[str, Dict[str, str]]]:
"""
Prepares the command and volume bindings for the container based on input arguments and command templates.

Parameters:
- cmd_and_args (List[str]): The command and arguments to prepare.
- **kwargs: Keyword arguments representing task inputs.

Returns:
- Tuple[List[str], Dict[str, Dict[str, str]]]: A tuple containing the prepared commands and volume bindings.
"""

commands = []
volume_bindings = {}

for cmd in cmd_and_args:
command, volume_binding = self._render_command_and_volume_binding(cmd, **kwargs)
commands.append(command)
volume_bindings.update(volume_binding)

return commands, volume_bindings

def _pull_image_if_not_exists(self, client, image: str):
try:
if not client.images.list(filters={"reference": image}):
logger.info(f"Pulling image: {image} for container task: {self.name}")
client.images.pull(image)
except Exception as e:
logger.error(f"Failed to pull image {image}: {str(e)}")
raise

def _string_to_timedelta(self, s: str):
import datetime
import re

regex = r"(?:(\d+) days?, )?(?:(\d+):)?(\d+):(\d+)(?:\.(\d+))?"
parts = re.match(regex, s)
if not parts:
raise ValueError("Invalid timedelta string format")

days = int(parts.group(1)) if parts.group(1) else 0
hours = int(parts.group(2)) if parts.group(2) else 0
minutes = int(parts.group(3)) if parts.group(3) else 0
seconds = int(parts.group(4)) if parts.group(4) else 0
microseconds = int(parts.group(5)) if parts.group(5) else 0

return datetime.timedelta(
days=days,
hours=hours,
minutes=minutes,
seconds=seconds,
microseconds=microseconds,
)

def _convert_output_val_to_correct_type(self, output_val: Any, output_type: Any) -> Any:
import datetime

if output_type == bool:
return output_val.lower() != "false"
elif output_type == datetime.datetime:
return datetime.datetime.fromisoformat(output_val)
elif output_type == datetime.timedelta:
return self._string_to_timedelta(output_val)
else:
return output_type(output_val)

def _get_output_dict(self, output_directory: str) -> Dict[str, Any]:
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile

output_dict = {}
if self._outputs:
for k, output_type in self._outputs.items():
output_path = os.path.join(output_directory, k)
if output_type in [FlyteFile, FlyteDirectory]:
output_dict[k] = output_type(path=output_path)
else:
with open(output_path, "r") as f:
output_val = f.read()
output_dict[k] = self._convert_output_val_to_correct_type(output_val, output_type)
return output_dict

def execute(self, **kwargs) -> LiteralMap:
try:
import docker
except ImportError:
raise ImportError(DOCKER_IMPORT_ERROR_MESSAGE)

from flytekit.core.type_engine import TypeEngine

ctx = FlyteContext.current_context()

# Normalize the input and output directories
self._input_data_dir = os.path.normpath(self._input_data_dir) if self._input_data_dir else ""
self._output_data_dir = os.path.normpath(self._output_data_dir) if self._output_data_dir else ""

output_directory = ctx.file_access.get_random_local_directory()
cmd_and_args = (self._cmd or []) + (self._args or [])
commands, volume_bindings = self._prepare_command_and_volumes(cmd_and_args, **kwargs)
volume_bindings[output_directory] = {"bind": self._output_data_dir, "mode": "rw"}

client = docker.from_env()
self._pull_image_if_not_exists(client, self._image)

container = client.containers.run(
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
self._image, command=commands, remove=True, volumes=volume_bindings, detach=True
)
# Wait for the container to finish the task
# TODO: Add a 'timeout' parameter to control the max wait time for the container to finish the task.
container.wait()
eapolinario marked this conversation as resolved.
Show resolved Hide resolved

output_dict = self._get_output_dict(output_directory)
outputs_literal_map = TypeEngine.dict_to_literal_map(ctx, output_dict)
return outputs_literal_map

def get_container(self, settings: SerializationSettings) -> _task_model.Container:
# if pod_template is specified, return None here but in get_k8s_pod, return pod_template merged with container
Expand Down
9 changes: 9 additions & 0 deletions tests/flytekit/unit/core/Dockerfile.raw_container
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
FROM python:3.9-alpine

WORKDIR /root

COPY ./write_flytefile.py /root/write_flytefile.py
COPY ./write_flytedir.py /root/write_flytedir.py
COPY ./return_same_value.py /root/return_same_value.py

CMD ["/bin/sh"]
22 changes: 22 additions & 0 deletions tests/flytekit/unit/core/return_same_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os
import sys


def write_output(output_dir, output_file, v):
# Ensure the output directory exists
os.makedirs(output_dir, exist_ok=True) # This will create the directory if it doesn't exist
with open(f"{output_dir}/{output_file}", "w") as f:
f.write(str(v))


def main(*args, output_dir):
# Generate output files for each input argument
for i, arg in enumerate(args, start=1):
# Using i to generate filenames like 'a', 'b', 'c', ...
output_file = chr(ord("a") + i - 1)
write_output(output_dir, output_file, arg)


if __name__ == "__main__":
*inputs, output_dir = sys.argv[1:] # Unpack all inputs except for the last one for output_dir
main(*inputs, output_dir=output_dir)
87 changes: 73 additions & 14 deletions tests/flytekit/unit/core/test_container_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os
import sys
from collections import OrderedDict
from typing import Tuple

import pytest
from kubernetes.client.models import (
Expand All @@ -13,14 +16,83 @@
V1Toleration,
)

from flytekit import kwtypes
from flytekit import kwtypes, task, workflow
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.container_task import ContainerTask
from flytekit.core.pod_template import PodTemplate
from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpec
from flytekit.tools.translator import get_serializable_task


@pytest.mark.skipif(
sys.platform in ["darwin", "win32"],
reason="Skip if running on windows or macos due to CI Docker environment setup failure",
)
def test_local_execution():
calculate_ellipse_area_python_template_style = ContainerTask(
name="calculate_ellipse_area_python_template_style",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
inputs=kwtypes(a=float, b=float),
outputs=kwtypes(area=float, metadata=str),
image="ghcr.io/flyteorg/rawcontainers-python:v2",
command=[
"python",
"calculate-ellipse-area.py",
"{{.inputs.a}}",
"{{.inputs.b}}",
"/var/outputs",
],
)

area, metadata = calculate_ellipse_area_python_template_style(a=3.0, b=4.0)
assert isinstance(area, float)
assert isinstance(metadata, str)

# Workflow execution with container task
@task
def t1(a: float, b: float) -> Tuple[float, float]:
return a + b, a * b

@workflow
def wf(a: float, b: float) -> Tuple[float, str]:
a, b = t1(a=a, b=b)
area, metadata = calculate_ellipse_area_python_template_style(a=a, b=b)
return area, metadata

area, metadata = wf(a=3.0, b=4.0)
assert isinstance(area, float)
assert isinstance(metadata, str)


@pytest.mark.skipif(
sys.platform == "win32",
reason="Skip if running on windows due to path error",
)
def test_local_execution_special_cases():
# Boolean conversion from string checks
assert all([bool(s) for s in ["False", "false", "True", "true"]])

# Path normalization
input_data_dir = "/var/inputs"
assert os.path.normpath(input_data_dir) == "/var/inputs"
assert os.path.normpath(input_data_dir + "/") == "/var/inputs"

# Datetime and timedelta string conversions
ct = ContainerTask(
name="local-execution",
image="test-image",
command="echo",
)

from datetime import datetime, timedelta

now = datetime.now()
assert datetime.fromisoformat(str(now)) == now
td = timedelta(days=1, hours=1, minutes=1, seconds=1, microseconds=1)
assert td == ct._string_to_timedelta(str(td))


def test_pod_template():
ps = V1PodSpec(
containers=[], tolerations=[V1Toleration(effect="NoSchedule", key="nvidia.com/gpu", operator="Exists")]
Expand Down Expand Up @@ -86,19 +158,6 @@ def test_pod_template():
assert serialized_pod_spec["runtimeClassName"] == "nvidia"


def test_local_execution():
ct = ContainerTask(
name="name",
input_data_dir="/var/inputs",
output_data_dir="/var/outputs",
image="inexistent-image:v42",
command=["some", "command"],
)

with pytest.raises(RuntimeError):
ct()


def test_raw_container_with_image_spec(mock_image_spec_builder):
ImageBuildEngine.register("test-raw-container", mock_image_spec_builder)
image_spec = ImageSpec(registry="flyte", base_image="r-base", builder="test-raw-container")
Expand Down
Loading
Loading