Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK - Components - Split load_component functions into loading the spec and creating task factory #3614

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 38 additions & 13 deletions sdk/python/kfp/components/_component_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
]

from pathlib import Path
import copy
import requests
from typing import Callable
from . import _components as comp
Expand Down Expand Up @@ -52,11 +53,38 @@ def load_component(self, name, digest=None, tag=None):
#This function should be called load_task_factory since it returns a factory function.
#The real load_component function should produce an object with component properties (e.g. name, description, inputs/outputs).
#TODO: Change this function to return component spec object but it should be callable to construct tasks.
component_ref = ComponentReference(name=name, digest=digest, tag=tag)
component_ref = self._load_component_spec_in_component_ref(component_ref)
return comp._create_task_factory_from_component_spec(
component_spec=component_ref.spec,
component_ref=component_ref,
)

def _load_component_spec_in_component_ref(
self,
component_ref: ComponentReference,
) -> ComponentReference:
'''Takes component_ref, finds the component spec and returns component_ref with .spec set to the component spec.

See ComponentStore.load_component for the details of the search logic.
'''
if component_ref.spec:
return component_ref

component_ref = copy.copy(component_ref)
if component_ref.url:
component_ref.spec = comp._load_component_spec_from_url(component_ref.url)
return component_ref

name = component_ref.name
if not name:
raise TypeError("name is required")
if name.startswith('/') or name.endswith('/'):
raise ValueError('Component name should not start or end with slash: "{}"'.format(name))

digest = component_ref.digest
tag = component_ref.tag

tried_locations = []

if digest is not None and tag is not None:
Expand All @@ -75,8 +103,10 @@ def load_component(self, name, digest=None, tag=None):
component_path = Path(local_search_path, path_suffix)
tried_locations.append(str(component_path))
if component_path.is_file():
component_ref = ComponentReference(name=name, digest=digest, tag=tag)
return comp.load_component_from_file(str(component_path))
# TODO: Verify that the content matches the digest (if specified).
component_ref._local_path = str(component_path)
component_ref.spec = comp._load_component_spec_from_file(str(component_path))
return component_ref

#Trying URL prefixes
for url_search_prefix in self.url_search_prefixes:
Expand All @@ -88,21 +118,16 @@ def load_component(self, name, digest=None, tag=None):
except:
continue
if response.content:
component_ref = ComponentReference(name=name, digest=digest, tag=tag, url=url)
return comp._load_component_from_yaml_or_zip_bytes(response.content, url, component_ref)
# TODO: Verify that the content matches the digest (if specified).
component_ref.url = url
component_ref.spec = comp._load_component_spec_from_yaml_or_zip_bytes(response.content)
return component_ref

raise RuntimeError('Component {} was not found. Tried the following locations:\n{}'.format(name, '\n'.join(tried_locations)))

def _load_component_from_ref(self, component_ref: ComponentReference) -> Callable:
if component_ref.spec:
return comp._create_task_factory_from_component_spec(component_spec=component_ref.spec, component_ref=component_ref)
if component_ref.url:
return self.load_component_from_url(component_ref.url)
return self.load_component(
name=component_ref.name,
digest=component_ref.digest,
tag=component_ref.tag,
)
component_ref = self._load_component_spec_in_component_ref(component_ref)
return comp._create_task_factory_from_component_spec(component_spec=component_ref.spec, component_ref=component_ref)


ComponentStore.default_store = ComponentStore(
Expand Down
81 changes: 51 additions & 30 deletions sdk/python/kfp/components/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,19 +74,14 @@ def load_component_from_url(url):
A factory function with a strongly-typed signature.
Once called with the required arguments, the factory constructs a pipeline task instance (ContainerOp).
'''
if url is None:
raise TypeError

#Handling Google Cloud Storage URIs
if url.startswith('gs://'):
#Replacing the gs:// URI with https:// URI (works for public objects)
url = 'https://storage.googleapis.com/' + url[len('gs://'):]

import requests
resp = requests.get(url)
resp.raise_for_status()
component_spec = _load_component_spec_from_url(url)
url = _fix_component_uri(url)
component_ref = ComponentReference(url=url)
return _load_component_from_yaml_or_zip_bytes(resp.content, url, component_ref)
return _create_task_factory_from_component_spec(
component_spec=component_spec,
component_filename=url,
component_ref=component_ref,
)


def load_component_from_file(filename):
Expand All @@ -100,10 +95,11 @@ def load_component_from_file(filename):
A factory function with a strongly-typed signature.
Once called with the required arguments, the factory constructs a pipeline task instance (ContainerOp).
'''
if filename is None:
raise TypeError
with open(filename, 'rb') as component_stream:
return _load_component_from_yaml_or_zip_stream(component_stream, filename)
component_spec = _load_component_spec_from_file(path=filename)
return _create_task_factory_from_component_spec(
component_spec=component_spec,
component_filename=filename,
)


def load_component_from_text(text):
Expand All @@ -119,20 +115,47 @@ def load_component_from_text(text):
'''
if text is None:
raise TypeError
return _create_task_factory_from_component_text(text, None)
component_spec = _load_component_spec_from_component_text(text)
return _create_task_factory_from_component_spec(component_spec=component_spec)


def _fix_component_uri(uri: str) -> str:
#Handling Google Cloud Storage URIs
if uri.startswith('gs://'):
#Replacing the gs:// URI with https:// URI (works for public objects)
uri = 'https://storage.googleapis.com/' + uri[len('gs://'):]
return uri


def _load_component_spec_from_file(path) -> ComponentSpec:
with open(path, 'rb') as component_stream:
return _load_component_spec_from_yaml_or_zip_stream(component_stream)


def _load_component_spec_from_url(url: str):
if url is None:
raise TypeError

url = _fix_component_uri(url)

import requests
resp = requests.get(url)
resp.raise_for_status()
return _load_component_spec_from_yaml_or_zip_bytes(resp.content)


_COMPONENT_FILE_NAME_IN_ARCHIVE = 'component.yaml'


def _load_component_from_yaml_or_zip_bytes(bytes, component_filename=None, component_ref: ComponentReference = None):
def _load_component_spec_from_yaml_or_zip_bytes(data: bytes):
import io
component_stream = io.BytesIO(bytes)
return _load_component_from_yaml_or_zip_stream(component_stream, component_filename, component_ref)
component_stream = io.BytesIO(data)
return _load_component_spec_from_yaml_or_zip_stream(component_stream)


def _load_component_spec_from_yaml_or_zip_stream(stream) -> ComponentSpec:
'''Loads component spec from a stream.

def _load_component_from_yaml_or_zip_stream(stream, component_filename=None, component_ref: ComponentReference = None):
'''Loads component from a stream and creates a task factory function.
The stream can be YAML or a zip file with a component.yaml file inside.
'''
import zipfile
Expand All @@ -141,20 +164,18 @@ def _load_component_from_yaml_or_zip_stream(stream, component_filename=None, com
stream.seek(0)
with zipfile.ZipFile(stream) as zip_obj:
with zip_obj.open(_COMPONENT_FILE_NAME_IN_ARCHIVE) as component_stream:
return _create_task_factory_from_component_text(component_stream, component_filename, component_ref)
return _load_component_spec_from_component_text(
text_or_file=component_stream,
)
else:
stream.seek(0)
return _create_task_factory_from_component_text(stream, component_filename, component_ref)
return _load_component_spec_from_component_text(stream)


def _create_task_factory_from_component_text(text_or_file, component_filename=None, component_ref: ComponentReference = None):
def _load_component_spec_from_component_text(text_or_file) -> ComponentSpec:
component_dict = load_yaml(text_or_file)
return _create_task_factory_from_component_dict(component_dict, component_filename, component_ref)


def _create_task_factory_from_component_dict(component_dict, component_filename=None, component_ref: ComponentReference = None):
component_spec = ComponentSpec.from_dict(component_dict)
return _create_task_factory_from_component_spec(component_spec, component_filename, component_ref)
return component_spec


_inputs_dir = '/tmp/inputs'
Expand Down