diff --git a/prefect_gcp/workers/cloud_run_v2.py b/prefect_gcp/workers/cloud_run_v2.py index 6237bed..956beb4 100644 --- a/prefect_gcp/workers/cloud_run_v2.py +++ b/prefect_gcp/workers/cloud_run_v2.py @@ -54,6 +54,7 @@ def _get_default_job_body_template() -> Dict[str, Any]: "serviceAccount": "{{ service_account_name }}", "maxRetries": "{{ max_retries }}", "timeout": "{{ timeout }}", + "vpcAccess": "{{ vpc_connector_name }}", "containers": [ { "env": [], @@ -183,6 +184,7 @@ def prepare_for_flow_run( self._format_args_if_present() self._populate_image_if_not_present() self._populate_timeout() + self._populate_vpc_if_present() def _populate_timeout(self): """ @@ -233,6 +235,15 @@ def _format_args_if_present(self): "args" ] = shlex.split(args) + def _populate_vpc_if_present(self): + """ + Populates the job body with the VPC connector if present. + """ + if self.job_body["template"]["template"].get("vpcAccess") is not None: + self.job_body["template"]["template"]["vpcAccess"] = { + "connector": self.job_body["template"]["template"]["vpcAccess"], + } + # noinspection PyMethodParameters @validator("job_body") def _ensure_job_includes_all_required_components(cls, value: Dict[str, Any]): diff --git a/tests/test_cloud_run_worker_v2.py b/tests/test_cloud_run_worker_v2.py index e7025ed..5422430 100644 --- a/tests/test_cloud_run_worker_v2.py +++ b/tests/test_cloud_run_worker_v2.py @@ -15,6 +15,7 @@ def job_body(): "template": { "maxRetries": None, "timeout": None, + "vpcAccess": "projects/my_project/locations/us-central1/connectors/my-connector", # noqa: E501 "containers": [ { "env": [], @@ -120,3 +121,13 @@ def test_format_args_if_present(self, cloud_run_worker_v2_job_config): assert cloud_run_worker_v2_job_config.job_body["template"]["template"][ "containers" ][0]["args"] == ["-m", "prefect.engine"] + + def test_populate_vpc_if_present(self, cloud_run_worker_v2_job_config): + cloud_run_worker_v2_job_config._populate_vpc_if_present() + + assert ( + cloud_run_worker_v2_job_config.job_body["template"]["template"][ + "vpcAccess" + ]["connector"] + == "projects/my_project/locations/us-central1/connectors/my-connector" + )