Skip to content

Commit

Permalink
fix: Create PipelineJobSchedule in same project and location as assoc…
Browse files Browse the repository at this point in the history
…iated PipelineJob by default

PiperOrigin-RevId: 568706789
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Sep 27, 2023
1 parent 0c1c129 commit c22220e
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 2 deletions.
6 changes: 4 additions & 2 deletions google/cloud/aiplatform/pipeline_job_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,17 @@ def __init__(
Overrides credentials set in aiplatform.init.
project (str):
Optional. The project that you want to run this PipelineJobSchedule in.
If not set, the project set in aiplatform.init will be used.
If not set, the project used for the PipelineJob will be used.
location (str):
Optional. Location to create PipelineJobSchedule. If not set,
location set in aiplatform.init will be used.
location used for the PipelineJob will be used.
"""
if not display_name:
display_name = self.__class__._generate_display_name()
utils.validate_display_name(display_name)

project = project or pipeline_job.project
location = location or pipeline_job.location
super().__init__(credentials=credentials, project=project, location=location)

self._parent = initializer.global_config.common_location_path(
Expand Down
129 changes: 129 additions & 0 deletions tests/unit/aiplatform/test_pipeline_job_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,91 @@ def test_call_schedule_service_create(
gca_schedule.Schedule.State.COMPLETED
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_call_schedule_service_create_uses_pipeline_job_project_location(
self,
mock_schedule_service_create,
mock_schedule_service_get,
mock_schedule_bucket_exists,
job_spec,
mock_load_yaml_and_json,
):
"""Creates a PipelineJobSchedule.
Tests that the PipelineJobSchedule is created in the same project and location as the PipelineJob.
"""
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
enable_caching=True,
project="managed-pipeline-test",
location="europe-west4",
)

pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
pipeline_job=job,
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
)

assert pipeline_job_schedule.project == "managed-pipeline-test"
assert pipeline_job_schedule.location == "europe-west4"

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_call_schedule_service_create_uses_specified_project_location(
self,
mock_schedule_service_create,
mock_schedule_service_get,
mock_schedule_bucket_exists,
job_spec,
mock_load_yaml_and_json,
):
"""Creates a PipelineJobSchedule.
Tests that PipelineJobSchedule is created in the specified project and location over the PipelineJob's.
"""
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
enable_caching=True,
)

pipeline_job_schedule = pipeline_job_schedules.PipelineJobSchedule(
pipeline_job=job,
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
project="managed-pipeline-test",
location="europe-west4",
)

assert job.project == _TEST_PROJECT
assert job.location == _TEST_LOCATION

assert pipeline_job_schedule.project == "managed-pipeline-test"
assert pipeline_job_schedule.location == "europe-west4"

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
Expand Down Expand Up @@ -1148,6 +1233,50 @@ def test_call_pipeline_job_create_schedule(
gca_schedule.Schedule.State.COMPLETED
)

@pytest.mark.parametrize(
"job_spec",
[_TEST_PIPELINE_SPEC_JSON, _TEST_PIPELINE_SPEC_YAML, _TEST_PIPELINE_JOB],
)
def test_call_pipeline_job_create_schedule_uses_pipeline_job_project_location(
self,
mock_schedule_service_create,
mock_schedule_service_get,
job_spec,
mock_load_yaml_and_json,
):
"""Creates a PipelineJobSchedule via PipelineJob.create_schedule().
Tests that the PipelineJobSchedule is created in the same project and location as the PipelineJob.
"""
aiplatform.init(
project=_TEST_PROJECT,
staging_bucket=_TEST_GCS_BUCKET_NAME,
location=_TEST_LOCATION,
credentials=_TEST_CREDENTIALS,
)

job = pipeline_jobs.PipelineJob(
display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME,
template_path=_TEST_TEMPLATE_PATH,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
enable_caching=True,
project="managed-pipeline-test",
location="europe-west4",
)

pipeline_job_schedule = job.create_schedule(
display_name=_TEST_PIPELINE_JOB_SCHEDULE_DISPLAY_NAME,
cron=_TEST_PIPELINE_JOB_SCHEDULE_CRON,
max_concurrent_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_CONCURRENT_RUN_COUNT,
max_run_count=_TEST_PIPELINE_JOB_SCHEDULE_MAX_RUN_COUNT,
service_account=_TEST_SERVICE_ACCOUNT,
network=_TEST_NETWORK,
)

assert pipeline_job_schedule.project == "managed-pipeline-test"
assert pipeline_job_schedule.location == "europe-west4"

@pytest.mark.usefixtures("mock_schedule_service_get")
def test_get_schedule(self, mock_schedule_service_get):
aiplatform.init(project=_TEST_PROJECT)
Expand Down

0 comments on commit c22220e

Please sign in to comment.