Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <pingsutw@apache.org>
  • Loading branch information
pingsutw committed May 5, 2023
1 parent 9cadb80 commit 0357806
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 103 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/pythonpublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,11 @@ jobs:
uses: docker/metadata-action@v3
with:
images: |
ghcr.io/${{ github.repository_owner }}/flytekit
ghcr.io/${{ github.repository_owner }}/external-plugin-service
tags: |
external-plugin-service-latest
external-plugin-service-${{ github.sha }}
external-plugin-service-${{ needs.deploy.outputs.version }}
latest
${{ github.sha }}
${{ needs.deploy.outputs.version }}
- name: Push External Plugin Service Image to GitHub Registry
uses: docker/build-push-action@v2
with:
Expand Down
43 changes: 42 additions & 1 deletion flytekit/extend/backend/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,29 @@
TaskGetResponse,
)

from flytekit import logger
from flytekit.models.literals import LiteralMap


class BackendPluginBase(ABC):
"""
This is the base class for all backend plugins. It defines the interface that all plugins must implement.
The external plugins service will be run either locally or in a pod, and will be responsible for
invoking backend plugins. The propeller will communicate with the external plugins service
to create tasks, get the status of tasks, and delete tasks.
All the backend plugins should be registered in the BackendPluginRegistry. External plugins service
will look up the plugin based on the task type. Every task type can only have one plugin.
"""

def __init__(self, task_type: str):
self._task_type = task_type

@property
def task_type(self) -> str:
"""
task_type is the name of the task type that this plugin supports.
"""
return self._task_type

@abstractmethod
Expand All @@ -32,30 +46,57 @@ def create(
task_template: TaskTemplate,
inputs: typing.Optional[LiteralMap] = None,
) -> TaskCreateResponse:
"""
Return a Unique ID for the task that was created. It should return error code if the task creation failed.
"""
pass

@abstractmethod
def get(self, context: grpc.ServicerContext, job_id: str) -> TaskGetResponse:
"""
Return the status of the task, and return the outputs in some cases. For example, bigquery job
can't write the structured dataset to the output location, so it returns the output literals to the propeller,
and the propeller will write the structured dataset to the blob store.
"""
pass

@abstractmethod
def delete(self, context: grpc.ServicerContext, job_id: str) -> TaskDeleteResponse:
"""
Delete the task. This call should be idempotent.
"""
pass


class BackendPluginRegistry(object):
"""
This is the registry for all backend plugins. The external plugins service will look up the plugin
based on the task type.
"""

_REGISTRY: typing.Dict[str, BackendPluginBase] = {}

@staticmethod
def register(plugin: BackendPluginBase):
if plugin.task_type in BackendPluginRegistry._REGISTRY:
raise ValueError(f"Duplicate plugin for task type {plugin.task_type}")
BackendPluginRegistry._REGISTRY[plugin.task_type] = plugin
logger.info(f"Registering backend plugin for task type {plugin.task_type}")

@staticmethod
def get_plugin(task_type: str) -> BackendPluginBase:
def get_plugin(context: grpc.ServicerContext, task_type: str) -> typing.Optional[BackendPluginBase]:
if task_type not in BackendPluginRegistry._REGISTRY:
logger.error(f"Cannot find backend plugin for task type [{task_type}]")
context.set_code(grpc.StatusCode.NOT_FOUND)
context.set_details(f"Cannot find backend plugin for task type [{task_type}]")
return None
return BackendPluginRegistry._REGISTRY[task_type]


def convert_to_flyte_state(state: str) -> State:
"""
Convert the state from the backend plugin to the state in flyte.
"""
state = state.lower()
if state in ["failed"]:
return RETRYABLE_FAILURE
Expand Down
27 changes: 19 additions & 8 deletions flytekit/extend/backend/external_plugin_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import grpc
from flyteidl.service.external_plugin_service_pb2 import (
PERMANENT_FAILURE,
TaskCreateRequest,
TaskCreateResponse,
TaskDeleteRequest,
Expand All @@ -9,34 +10,44 @@
)
from flyteidl.service.external_plugin_service_pb2_grpc import ExternalPluginServiceServicer

from flytekit.extend.backend import model
from flytekit import logger
from flytekit.extend.backend.base_plugin import BackendPluginRegistry
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate


class BackendPluginServer(ExternalPluginServiceServicer):
def CreateTask(self, request: TaskCreateRequest, context: grpc.ServicerContext) -> TaskCreateResponse:
try:
req = model.TaskCreateRequest.from_flyte_idl(request)
plugin = BackendPluginRegistry.get_plugin(req.template.type)
return plugin.create(
context=context, inputs=req.inputs, output_prefix=req.output_prefix, task_template=req.template
)
tmp = TaskTemplate.from_flyte_idl(request.template)
inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None
plugin = BackendPluginRegistry.get_plugin(context, tmp.type)
if plugin is None:
return TaskCreateResponse()
return plugin.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp)
except Exception as e:
logger.error(f"failed to create task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to create task with error {e}")

def GetTask(self, request: TaskGetRequest, context: grpc.ServicerContext) -> TaskGetResponse:
try:
plugin = BackendPluginRegistry.get_plugin(request.task_type)
plugin = BackendPluginRegistry.get_plugin(context, request.task_type)
if plugin is None:
return TaskGetResponse(state=PERMANENT_FAILURE)
return plugin.get(context=context, job_id=request.job_id)
except Exception as e:
logger.error(f"failed to get task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to get task with error {e}")

def DeleteTask(self, request: TaskDeleteRequest, context: grpc.ServicerContext) -> TaskDeleteResponse:
try:
plugin = BackendPluginRegistry.get_plugin(request.task_type)
plugin = BackendPluginRegistry.get_plugin(context, request.task_type)
if plugin is None:
return TaskDeleteResponse()
return plugin.delete(context=context, job_id=request.job_id)
except Exception as e:
logger.error(f"failed to delete task with error {e}")
context.set_code(grpc.StatusCode.INTERNAL)
context.set_details(f"failed to delete task with error {e}")
40 changes: 0 additions & 40 deletions flytekit/extend/backend/model.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from google.cloud import bigquery

from flytekit import FlyteContextManager, StructuredDataset
from flytekit import FlyteContextManager, StructuredDataset, logger
from flytekit.core.type_engine import TypeEngine
from flytekit.extend.backend.base_plugin import BackendPluginBase, BackendPluginRegistry, convert_to_flyte_state
from flytekit.models import literals
Expand Down Expand Up @@ -49,6 +49,7 @@ def create(
}
native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs)

logger.info(f"Create BigQuery job config with inputs: {native_inputs}")
job_config = bigquery.QueryJobConfig(
query_parameters=[
bigquery.ScalarQueryParameter(name, pythonTypeToBigQueryType[python_interface_inputs[name]], val)
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-bigquery/tests/test_backend_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def __init__(self):
mock_instance.query.return_value = MockJob()
mock_instance.cancel_job.return_value = MockJob()

p = BackendPluginRegistry.get_plugin("bigquery_query_job_task")
ctx = MagicMock(spec=grpc.ServicerContext)
p = BackendPluginRegistry.get_plugin(ctx, "bigquery_query_job_task")

task_id = Identifier(
resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version"
Expand Down
8 changes: 3 additions & 5 deletions tests/flytekit/unit/extend/test_backend_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import grpc
from flyteidl.service.external_plugin_service_pb2 import (
PERMANENT_FAILURE,
SUCCEEDED,
TaskCreateRequest,
TaskCreateResponse,
Expand All @@ -28,9 +29,6 @@ class DummyPlugin(BackendPluginBase):
def __init__(self):
super().__init__(task_type="dummy")

def initialize(self):
pass

def create(
self,
context: grpc.ServicerContext,
Expand Down Expand Up @@ -85,8 +83,8 @@ def delete(self, context: grpc.ServicerContext, job_id) -> TaskDeleteResponse:


def test_dummy_plugin():
p = BackendPluginRegistry.get_plugin("dummy")
ctx = MagicMock(spec=grpc.ServicerContext)
p = BackendPluginRegistry.get_plugin(ctx, "dummy")
assert p.create(ctx, "/tmp", dummy_template, task_inputs).job_id == dummy_id
assert p.get(ctx, dummy_id).state == SUCCEEDED
assert p.delete(ctx, dummy_id) == TaskDeleteResponse()
Expand All @@ -104,4 +102,4 @@ def test_backend_plugin_server():
assert server.DeleteTask(TaskDeleteRequest(task_type="dummy", job_id=dummy_id), ctx) == TaskDeleteResponse()

res = server.GetTask(TaskGetRequest(task_type="fake", job_id=dummy_id), ctx)
assert res is None
assert res.state == PERMANENT_FAILURE
43 changes: 0 additions & 43 deletions tests/flytekit/unit/extend/test_model.py

This file was deleted.

0 comments on commit 0357806

Please sign in to comment.