-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactor Python SDK #568
Refactor Python SDK #568
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,16 +14,13 @@ | |
|
||
|
||
from collections import defaultdict | ||
import copy | ||
import inspect | ||
import re | ||
import string | ||
import tarfile | ||
import tempfile | ||
import yaml | ||
|
||
from .. import dsl | ||
|
||
from ._k8s_helper import K8sHelper | ||
|
||
class Compiler(object): | ||
"""DSL Compiler. | ||
|
@@ -42,9 +39,17 @@ def my_pipeline(a: dsl.PipelineParam, b: dsl.PipelineParam): | |
""" | ||
|
||
def _sanitize_name(self, name): | ||
return re.sub('-+', '-', re.sub('[^-0-9a-z]+', '-', name.lower())).lstrip('-').rstrip('-') #from _make_kubernetes_name | ||
"""From _make_kubernetes_name | ||
_sanitize_name cleans and converts the names in the workflow. | ||
""" | ||
return re.sub('-+', '-', re.sub('[^-0-9a-z]+', '-', name.lower())).lstrip('-').rstrip('-') | ||
|
||
def _param_full_name(self, param): | ||
def _pipelineparam_full_name(self, param): | ||
"""_pipelineparam_full_name | ||
|
||
Args: | ||
param(PipelineParam): pipeline parameter | ||
""" | ||
if param.op_name: | ||
return param.op_name + '-' + param.name | ||
return self._sanitize_name(param.name) | ||
|
@@ -79,12 +84,12 @@ def _op_to_template(self, op): | |
for i, _ in enumerate(processed_args): | ||
if op.argument_inputs: | ||
for param in op.argument_inputs: | ||
full_name = self._param_full_name(param) | ||
full_name = self._pipelineparam_full_name(param) | ||
processed_args[i] = re.sub(str(param), '{{inputs.parameters.%s}}' % full_name, | ||
processed_args[i]) | ||
input_parameters = [] | ||
for param in op.inputs: | ||
one_parameter = {'name': self._param_full_name(param)} | ||
one_parameter = {'name': self._pipelineparam_full_name(param)} | ||
if param.value: | ||
one_parameter['value'] = str(param.value) | ||
input_parameters.append(one_parameter) | ||
|
@@ -94,7 +99,7 @@ def _op_to_template(self, op): | |
output_parameters = [] | ||
for param in op.outputs.values(): | ||
output_parameters.append({ | ||
'name': self._param_full_name(param), | ||
'name': self._pipelineparam_full_name(param), | ||
'valueFrom': {'path': op.file_outputs[param.name]} | ||
}) | ||
output_parameters.sort(key=lambda x: x['name']) | ||
|
@@ -140,9 +145,9 @@ def _op_to_template(self, op): | |
template['nodeSelector'] = op.node_selector | ||
|
||
if op.env_variables: | ||
template['container']['env'] = list(map(self._convert_k8s_obj_to_dic, op.env_variables)) | ||
template['container']['env'] = list(map(K8sHelper.convert_k8s_obj_to_json, op.env_variables)) | ||
if op.volume_mounts: | ||
template['container']['volumeMounts'] = list(map(self._convert_k8s_obj_to_dic, op.volume_mounts)) | ||
template['container']['volumeMounts'] = list(map(K8sHelper.convert_k8s_obj_to_json, op.volume_mounts)) | ||
|
||
if op.pod_annotations or op.pod_labels: | ||
template['metadata'] = {} | ||
|
@@ -222,7 +227,7 @@ def _get_inputs_outputs(self, pipeline, root_group, op_groups): | |
if param.value: | ||
continue | ||
|
||
full_name = self._param_full_name(param) | ||
full_name = self._pipelineparam_full_name(param) | ||
if param.op_name: | ||
upstream_op = pipeline.ops[param.op_name] | ||
upstream_groups, downstream_groups = self._get_uncommon_ancestors( | ||
|
@@ -297,10 +302,16 @@ def _get_dependencies(self, pipeline, root_group, op_groups): | |
dependencies[downstream_groups[0]].add(upstream_groups[0]) | ||
return dependencies | ||
|
||
def _resolve_value_or_reference(self, value_or_reference, inputs): | ||
def _resolve_value_or_reference(self, value_or_reference, potential_references): | ||
"""_resolve_value_or_reference resolves values and PipelineParams, which could be task parameters or input parameters. | ||
|
||
Args: | ||
value_or_reference: value or reference to be resolved. It could be basic python types or PipelineParam | ||
potential_references(iterable of (str, str) pairs): (parameter name, task name) pairs to search for a matching parameter name
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for clarifying this. P.S. The python dictionary type is: There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'll be sending some more refactoring PRs soon to make the DSL more readable. |
||
""" | ||
if isinstance(value_or_reference, dsl.PipelineParam): | ||
parameter_name = self._param_full_name(value_or_reference) | ||
task_names = [task_name for param_name, task_name in inputs if param_name == parameter_name] | ||
parameter_name = self._pipelineparam_full_name(value_or_reference) | ||
task_names = [task_name for param_name, task_name in potential_references if param_name == parameter_name] | ||
if task_names: | ||
task_name = task_names[0] | ||
return '{{tasks.%s.outputs.parameters.%s}}' % (task_name, parameter_name) | ||
|
@@ -381,7 +392,6 @@ def _group_to_template(self, group, inputs, outputs, dependencies): | |
template['dag'] = {'tasks': tasks} | ||
return template | ||
|
||
|
||
def _create_templates(self, pipeline): | ||
"""Create all groups and ops templates in the pipeline.""" | ||
|
||
|
@@ -411,30 +421,36 @@ def _create_volumes(self, pipeline): | |
#TODO: check for duplicates based on the serialized volumes instead of just the name. | ||
if v.name not in volume_name_set: | ||
volume_name_set.add(v.name) | ||
volumes.append(self._convert_k8s_obj_to_dic(v)) | ||
volumes.append(K8sHelper.convert_k8s_obj_to_json(v)) | ||
volumes.sort(key=lambda x: x['name']) | ||
return volumes | ||
|
||
def _create_pipeline_workflow(self, args, pipeline): | ||
"""Create workflow for the pipeline.""" | ||
|
||
# Input Parameters | ||
input_params = [] | ||
for arg in args: | ||
param = {'name': arg.name} | ||
if arg.value is not None: | ||
param['value'] = str(arg.value) | ||
input_params.append(param) | ||
|
||
# Templates | ||
templates = self._create_templates(pipeline) | ||
templates.sort(key=lambda x: x['name']) | ||
|
||
# Exit Handler | ||
exit_handler = None | ||
if pipeline.groups[0].groups: | ||
first_group = pipeline.groups[0].groups[0] | ||
if first_group.type == 'exit_handler': | ||
exit_handler = first_group.exit_op | ||
|
||
# Volumes | ||
volumes = self._create_volumes(pipeline) | ||
|
||
# The whole pipeline workflow | ||
workflow = { | ||
'apiVersion': 'argoproj.io/v1alpha1', | ||
'kind': 'Workflow', | ||
|
@@ -503,54 +519,6 @@ def _compile(self, pipeline_func): | |
workflow = self._create_pipeline_workflow(args_list_with_defaults, p) | ||
return workflow | ||
|
||
def _convert_k8s_obj_to_dic(self, obj): | ||
""" | ||
Builds a JSON K8s object. | ||
|
||
If obj is None, return None. | ||
If obj is str, int, long, float, bool, return directly. | ||
If obj is datetime.datetime, datetime.date | ||
convert to string in iso8601 format. | ||
If obj is list, sanitize each element in the list. | ||
If obj is dict, return the dict. | ||
If obj is swagger model, return the properties dict. | ||
|
||
Args: | ||
obj: The data to serialize. | ||
Returns: The serialized form of data. | ||
""" | ||
|
||
from six import text_type, integer_types, iteritems | ||
PRIMITIVE_TYPES = (float, bool, bytes, text_type) + integer_types | ||
from datetime import date, datetime | ||
if obj is None: | ||
return None | ||
elif isinstance(obj, PRIMITIVE_TYPES): | ||
return obj | ||
elif isinstance(obj, list): | ||
return [self._convert_k8s_obj_to_dic(sub_obj) | ||
for sub_obj in obj] | ||
elif isinstance(obj, tuple): | ||
return tuple(self._convert_k8s_obj_to_dic(sub_obj) | ||
for sub_obj in obj) | ||
elif isinstance(obj, (datetime, date)): | ||
return obj.isoformat() | ||
|
||
if isinstance(obj, dict): | ||
obj_dict = obj | ||
else: | ||
# Convert model obj to dict, skipping the | ||
# `swagger_types` and `attribute_map` attributes | ||
# and any attribute whose value is None. | ||
# Convert attribute name to json key in | ||
# model definition for request. | ||
obj_dict = {obj.attribute_map[attr]: getattr(obj, attr) | ||
for attr, _ in iteritems(obj.swagger_types) | ||
if getattr(obj, attr) is not None} | ||
|
||
return {key: self._convert_k8s_obj_to_dic(val) | ||
for key, val in iteritems(obj_dict)} | ||
|
||
def compile(self, pipeline_func, package_path): | ||
"""Compile the given pipeline function into workflow yaml. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# Copyright 2018 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from kfp.compiler._k8s_helper import K8sHelper | ||
from datetime import datetime | ||
import unittest | ||
|
||
|
||
class TestCompiler(unittest.TestCase): | ||
def test_convert_k8s_obj_to_dic_accepts_dict(self): | ||
now = datetime.now() | ||
converted = K8sHelper.convert_k8s_obj_to_json({ | ||
"ENV": "test", | ||
"number": 3, | ||
"list": [1,2,3], | ||
"time": now | ||
}) | ||
self.assertEqual(converted, { | ||
"ENV": "test", | ||
"number": 3, | ||
"list": [1,2,3], | ||
"time": now.isoformat() | ||
}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The output is not really JSON. It's a Python dict.
P.S. The function (including its name) comes from the K8s client library (the `ApiClient` class).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed the name to follow yang's docstring inside the function, which describes the result as a JSON-serialized dictionary that has a nested structure.