Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(components): Addressing Review comments on Trainer component for PyTorch - KFP #5814

Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
"""Pipeline Base component class."""

import abc
from six import with_metaclass
from pytorch_kfp_components.types import standard_component_specs


class BaseComponent(with_metaclass(abc.ABCMeta, object)): # pylint: disable=R0903
class BaseComponent(metaclass=abc.ABCMeta): # pylint: disable=R0903
"""Pipeline Base component class."""

def __init__(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _generate_mar_file(

print("Running Archiver cmd: ", archiver_cmd)

proc = subprocess.Popen(
proc = subprocess.Popen( #pylint: disable=consider-using-with
archiver_cmd,
shell=True,
stdout=subprocess.PIPE,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
# limitations under the License.
"""Minio Executor Module."""
import os
from pytorch_kfp_components.components.base.base_executor import BaseExecutor
from pytorch_kfp_components.types import standard_component_specs
import urllib3
from minio import Minio #pylint: disable=no-name-in-module
from pytorch_kfp_components.components.base.base_executor import BaseExecutor
from pytorch_kfp_components.types import standard_component_specs


class Executor(BaseExecutor):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pytorch_kfp_components.types import standard_component_specs


class Trainer(BaseComponent):
class Trainer(BaseComponent): #pylint: disable=too-few-public-methods
"""Initializes the Trainer class."""

def __init__( # pylint: disable=R0913
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Executor(GenericExecutor):
def __init__(self): # pylint:disable=useless-super-delegation
super().__init__()

def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict): #pylint: disable=too-many-locals
"""This function of the Executor invokes the PyTorch Lightning training
loop.

Expand Down Expand Up @@ -63,7 +63,7 @@ def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
trainer_args,
module_file_args,
data_module_args,
) = self._GetFnArgs(
) = self._get_fn_args(
input_dict=input_dict,
output_dict=output_dict,
execution_properties=exec_properties,
Expand All @@ -75,6 +75,11 @@ def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
) = self.derive_model_and_data_module_class(
module_file=module_file, data_module_file=data_module_file
)
if not data_module_class :
raise NotImplementedError(
"Data module class is mandatory. "
"User defined training module is yet to be supported."
)
if data_module_class:
data_module = data_module_class(
**data_module_args if data_module_args else {}
Expand All @@ -93,8 +98,8 @@ def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
parser = Namespace(**trainer_args)
trainer = pl.Trainer.from_argparse_args(parser)

trainer.fit(model, data_module)
trainer.test()
trainer.fit(model, data_module) #pylint: disable=no-member
trainer.test() #pylint: disable=no-member

if "checkpoint_dir" in module_file_args:
model_save_path = module_file_args["checkpoint_dir"]
Expand All @@ -114,9 +119,3 @@ def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
output_dict[standard_component_specs.TRAINER_MODEL_SAVE_PATH
] = model_save_path
output_dict[standard_component_specs.PTL_TRAINER_OBJ] = trainer

else:
raise NotImplementedError(
"Data module class is mandatory. "
"User defined training module is yet to be supported."
)
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ class GenericExecutor(BaseExecutor):
"""Generic Executor Class that does nothing."""

def Do(self, input_dict: dict, output_dict: dict, exec_properties: dict):
# TODO: Code to train pretrained model
#TODO: Code to train pretrained model #pylint: disable=fixme
pass

def _GetFnArgs(
self, input_dict: dict, output_dict: dict, execution_properties: dict
def _get_fn_args( #pylint: disable=no-self-use
self, input_dict: dict, output_dict: dict, execution_properties: dict #pylint: disable=unused-argument
):
"""Gets the input/output/execution properties from the dictionary.

Expand Down Expand Up @@ -68,7 +68,7 @@ def _GetFnArgs(
data_module_args,
)

def derive_model_and_data_module_class(
def derive_model_and_data_module_class( #pylint: disable=no-self-use
self, module_file: str, data_module_file: str
):
"""Derives the model file and data modul file.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module for defining standard specifications and validation of parameter
type."""

#pylint: disable=duplicate-code
TRAINER_MODULE_FILE = "module_file"
TRAINER_DATA_MODULE_FILE = "data_module_file"
TRAINER_DATA_MODULE_ARGS = "data_module_args"
Expand Down Expand Up @@ -112,7 +111,7 @@ class MarGenerationSpec: # pylint: disable=R0903
}


class VisualizationSpec:
class VisualizationSpec: #pylint: disable=too-few-public-methods
"""Visualization Specification class.
For validating the parameter 'type'
"""
Expand Down Expand Up @@ -142,7 +141,7 @@ class VisualizationSpec:
}


class MinIoSpec:
class MinIoSpec: #pylint: disable=too-few-public-methods
"""MinIO Specification class.
For validating the parameter 'type'
"""
Expand Down
5 changes: 3 additions & 2 deletions components/PyTorch/pytorch-kfp-components/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def detect_version(base_path):
name="pytorch-kfp-components",
version=version,
description="PyTorch Kubeflow Pipeline",
url="https://github.com/kubeflow/pipelines/tree/master/components",
url="https://github.com/kubeflow/pipelines/tree/master/components/PyTorch/pytorch-kfp-components/",
author="The PyTorch Kubeflow Pipeline Components authors",
author_email="pytorch-kfp-components@fb.com",
license="Apache License 2.0",
Expand All @@ -79,7 +79,8 @@ def detect_version(base_path):
install_requires=make_required_install_packages(),
dependency_links=make_dependency_links(),
keywords=[
"Kubeflow",
"Kubeflow Pipelines",
"KFP",
"ML workflow",
"PyTorch",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_unreachable_endpoint(minio_inputs):
"""Testing unreachable minio endpoint with invalid minio creds."""
os.environ["MINIO_ACCESS_KEY"] = "dummy"
os.environ["MINIO_SECRET_KEY"] = "dummy"
with pytest.raises(Exception, match="Max retries exceeded with url*"):
with pytest.raises(Exception, match="Max retries exceeded with url: "):
upload_to_minio(minio_inputs)


Expand Down