Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding missing fields to FlyteTask remote entity #3093

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions flytekit/remote/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,11 @@ def __init__(
custom,
container=None,
task_type_version: int = 0,
security_context=None,
config=None,
k8s_pod=None,
sql=None,
extended_resources=None,
should_register: bool = False,
):
super(FlyteTask, self).__init__(
Expand All @@ -61,7 +65,11 @@ def __init__(
custom,
container=container,
task_type_version=task_type_version,
security_context=security_context,
config=config,
k8s_pod=k8s_pod,
sql=sql,
extended_resources=extended_resources,
)
)
self._should_register = should_register
Expand Down Expand Up @@ -146,6 +154,10 @@ def k8s_pod(self):
def sql(self):
return self.template.sql

@property
def extended_resources(self):
return self.template.extended_resources

@property
def should_register(self) -> bool:
return self._should_register
Expand All @@ -172,6 +184,11 @@ def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> FlyteTask:
custom=base_model.custom,
container=base_model.container,
task_type_version=base_model.task_type_version,
security_context=base_model.security_context,
config=base_model.config,
k8s_pod=base_model.k8s_pod,
sql=base_model.sql,
extended_resources=base_model.extended_resources,
)
# Override the newly generated name if one exists in the base model
if not base_model.id.is_empty:
Expand Down
2 changes: 2 additions & 0 deletions plugins/flytekit-spark/tests/test_remote_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def my_python_task(a: str) -> int:
mock_client = MagicMock()
remote._client = mock_client
remote._client_initialized = True
remote._client.get_task.return_value.closure.compiled_task.template.sql = None
remote._client.get_task.return_value.closure.compiled_task.template.k8s_pod = None

mock_image_config = MagicMock(default_image=MagicMock(full="fake-cr.io/image-name:tag"))
remote.register_task(
Expand Down
38 changes: 20 additions & 18 deletions tests/flytekit/unit/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,14 +547,21 @@ def wf1(name: str = "union") -> float:
flyte_remote.register_script(wf1)


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_local_server(mock_client):
@pytest.fixture()
def mock_flyte_remote_client():
with patch("flytekit.remote.remote.FlyteRemote.client") as mock_flyte_remote_client:
mock_flyte_remote_client.get_task.return_value.closure.compiled_task.template.sql = None
mock_flyte_remote_client.get_task.return_value.closure.compiled_task.template.k8s_pod = None
yield mock_flyte_remote_client


def test_local_server(mock_flyte_remote_client):
ctx = FlyteContextManager.current_context()
lt = TypeEngine.to_literal_type(typing.Dict[str, int])
lm = TypeEngine.to_literal(ctx, {"hello": 55}, typing.Dict[str, int], lt)
lm = lm.map.to_flyte_idl()

mock_client.get_data.return_value = dataproxy_pb2.GetDataResponse(literal_map=lm)
mock_flyte_remote_client.get_data.return_value = dataproxy_pb2.GetDataResponse(literal_map=lm)

rr = FlyteRemote(
Config.for_sandbox(),
Expand All @@ -566,8 +573,7 @@ def test_local_server(mock_client):


@mock.patch("flytekit.remote.remote.uuid")
@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_execution_name(mock_client, mock_uuid):
def test_execution_name(mock_uuid, mock_flyte_remote_client):
test_uuid = uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da")
mock_uuid.uuid4.return_value = test_uuid
remote = FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain")
Expand Down Expand Up @@ -597,7 +603,7 @@ def test_execution_name(mock_client, mock_uuid):
entity=ft,
inputs={"t": datetime.now(), "v": 0},
)
mock_client.create_execution.assert_has_calls(
mock_flyte_remote_client.create_execution.assert_has_calls(
[
mock.call(ANY, ANY, "execution-test", ANY, ANY),
mock.call(ANY, ANY, "execution-test-" + test_uuid.hex[:19], ANY, ANY),
Expand Down Expand Up @@ -688,9 +694,8 @@ def test_register_wf_script_mode(compress_scripts_mock, upload_file_mock, regist
)


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_fetch_active_launchplan_not_found(mock_client, remote):
mock_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found")
def test_fetch_active_launchplan_not_found(mock_flyte_remote_client, remote):
mock_flyte_remote_client.get_active_launch_plan.side_effect = FlyteEntityNotExistException("not found")
assert remote.fetch_active_launchplan(name="basic.list_float_wf.fake_wf") is None


Expand Down Expand Up @@ -785,8 +790,7 @@ async def eager_wf(a: int) -> int:
_get_pickled_target_dict(eager_wf)


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_launchplan_auto_activate(mock_client):
def test_launchplan_auto_activate(mock_flyte_remote_client):
@workflow
def wf() -> int:
return 1
Expand All @@ -804,15 +808,14 @@ def wf() -> int:

# The first one should not update the launchplan
rr.register_launch_plan(lp1, version="1", serialization_settings=ss)
mock_client.update_launch_plan.assert_not_called()
mock_flyte_remote_client.update_launch_plan.assert_not_called()

# the second one should
rr.register_launch_plan(lp2, version="1", serialization_settings=ss)
mock_client.update_launch_plan.assert_called()
mock_flyte_remote_client.update_launch_plan.assert_called()


@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_register_task_with_node_dependency_hints(mock_client):
def test_register_task_with_node_dependency_hints(mock_flyte_remote_client):
@task
def task0():
return None
Expand Down Expand Up @@ -858,8 +861,7 @@ def workflow1():
@mock.patch("flytekit.remote.remote.FlyteRemote.fetch_launch_plan")
@mock.patch("flytekit.remote.remote.FlyteRemote.raw_register")
@mock.patch("flytekit.remote.remote.FlyteRemote._serialize_and_register")
@mock.patch("flytekit.remote.remote.FlyteRemote.client")
def test_register_launch_plan(mock_client, mock_serialize_and_register, mock_raw_register,mock_fetch_launch_plan, mock_get_serializable):
def test_register_launch_plan(mock_serialize_and_register, mock_raw_register,mock_fetch_launch_plan, mock_get_serializable, mock_flyte_remote_client):
serialization_settings = SerializationSettings(
image_config=ImageConfig.auto_default_image(),
version="dummy_version",
Expand All @@ -883,7 +885,7 @@ def hello_world_wf() -> str:
lp = LaunchPlan.get_or_create(workflow=hello_world_wf, name="additional_lp_for_hello_world", default_inputs={})

mock_get_serializable.return_value = MagicMock()
mock_client.get_workflow.return_value = MagicMock()
mock_flyte_remote_client.get_workflow.return_value = MagicMock()

mock_remote_lp = MagicMock()
mock_fetch_launch_plan.return_value = mock_remote_lp
Expand Down
Loading