diff --git a/kubernetes-tests/tests/kubernetes_tests/conftest.py b/kubernetes-tests/tests/kubernetes_tests/conftest.py index 1a78383f4b0c3..06ad4ff9ca2d4 100644 --- a/kubernetes-tests/tests/kubernetes_tests/conftest.py +++ b/kubernetes-tests/tests/kubernetes_tests/conftest.py @@ -38,3 +38,20 @@ def pod_template() -> Path: @pytest.fixture def basic_pod_template() -> Path: return (DATA_FILES_DIRECTORY / "basic_pod.yaml").resolve(strict=True) + + +@pytest.fixture +def create_connection_without_db(monkeypatch): + """ + Fixture to create connections for tests without using the database. + + This fixture uses monkeypatch to set the appropriate AIRFLOW_CONN_{conn_id} environment variable. + """ + + def _create_conn(connection, session=None): + """Create connection using environment variable.""" + + env_var_name = f"AIRFLOW_CONN_{connection.conn_id.upper()}" + monkeypatch.setenv(env_var_name, connection.as_json()) + + return _create_conn diff --git a/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py b/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py index 6b51627146c14..790db7c7df704 100644 --- a/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py +++ b/kubernetes-tests/tests/kubernetes_tests/test_kubernetes_pod_operator.py @@ -105,14 +105,7 @@ def test_label(request): return label[-63:] -@pytest.fixture -def mock_get_connection(): - with mock.patch(f"{HOOK_CLASS}.get_connection", return_value=Connection(conn_id="kubernetes_default")): - yield - - @pytest.mark.execution_timeout(180) -@pytest.mark.usefixtures("mock_get_connection") class TestKubernetesPodOperatorSystem: @pytest.fixture(autouse=True) def setup_tests(self, test_label): @@ -166,12 +159,21 @@ def setup_tests(self, test_label): client = hook.core_v1_client client.delete_collection_namespaced_pod(namespace="default", grace_period_seconds=0) + @pytest.fixture(autouse=True) + def setup_connections(self, create_connection_without_db): + """Create kubernetes_default connection""" + connection = Connection( + conn_id="kubernetes_default", + conn_type="kubernetes", + ) + create_connection_without_db(connection) + def _get_labels_selector(self) -> str | None: if not self.labels: return None return ",".join([f"{key}={value}" for key, value in enumerate(self.labels)]) - def test_do_xcom_push_defaults_false(self, kubeconfig_path, mock_get_connection, tmp_path): + def test_do_xcom_push_defaults_false(self, kubeconfig_path, tmp_path): new_config_path = tmp_path / "kube_config.cfg" shutil.copy(kubeconfig_path, new_config_path) k = KubernetesPodOperator( @@ -187,7 +189,7 @@ def test_do_xcom_push_defaults_false(self, kubeconfig_path, mock_get_connection, ) assert not k.do_xcom_push - def test_config_path_move(self, kubeconfig_path, mock_get_connection, tmp_path): + def test_config_path_move(self, kubeconfig_path, tmp_path): new_config_path = tmp_path / "kube_config.cfg" shutil.copy(kubeconfig_path, new_config_path) @@ -209,7 +211,7 @@ def test_config_path_move(self, kubeconfig_path, mock_get_connection, tmp_path): actual_pod = self.api_client.sanitize_for_serialization(k.pod) assert actual_pod == expected_pod - def test_working_pod(self, mock_get_connection): + def test_working_pod(self): k = KubernetesPodOperator( namespace="default", image="ubuntu:16.04", @@ -226,7 +228,7 @@ def test_working_pod(self, mock_get_connection): assert self.expected_pod["spec"] == actual_pod["spec"] assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"] - def test_skip_cleanup(self, mock_get_connection): + def test_skip_cleanup(self): k = KubernetesPodOperator( namespace="unknown", image="ubuntu:16.04", @@ -241,7 +243,7 @@ def test_skip_cleanup(self, mock_get_connection): with pytest.raises(ApiException): k.execute(context) - def test_delete_operator_pod(self, mock_get_connection): + def test_delete_operator_pod(self): k = KubernetesPodOperator( namespace="default", image="ubuntu:16.04", @@ -259,7 +261,7 @@ def test_delete_operator_pod(self, mock_get_connection): assert self.expected_pod["spec"] == actual_pod["spec"] assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"] - def test_skip_on_specified_exit_code(self, mock_get_connection): + def test_skip_on_specified_exit_code(self): k = KubernetesPodOperator( namespace="default", image="ubuntu:16.04", @@ -275,7 +277,7 @@ def test_skip_on_specified_exit_code(self, mock_get_connection): with pytest.raises(AirflowSkipException): k.execute(context) - def test_already_checked_on_success(self, mock_get_connection): + def test_already_checked_on_success(self): """ When ``on_finish_action="keep_pod"``, pod should have 'already_checked' label, whether pod is successful or not. @@ -297,7 +299,7 @@ def test_already_checked_on_success(self, mock_get_connection): actual_pod = self.api_client.sanitize_for_serialization(actual_pod) assert actual_pod["metadata"]["labels"]["already_checked"] == "True" - def test_already_checked_on_failure(self, mock_get_connection): + def test_already_checked_on_failure(self): """ When ``on_finish_action="keep_pod"``, pod should have 'already_checked' label, whether pod is successful or not. @@ -322,7 +324,7 @@ def test_already_checked_on_failure(self, mock_get_connection): assert status["state"]["terminated"]["reason"] == "Error" assert actual_pod["metadata"]["labels"]["already_checked"] == "True" - def test_pod_hostnetwork(self, mock_get_connection): + def test_pod_hostnetwork(self): k = KubernetesPodOperator( namespace="default", image="ubuntu:16.04", @@ -341,7 +343,7 @@ def test_pod_hostnetwork(self, mock_get_connection): assert self.expected_pod["spec"] == actual_pod["spec"] assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"] - def test_pod_dnspolicy(self, mock_get_connection): + def test_pod_dnspolicy(self): dns_policy = "ClusterFirstWithHostNet" k = KubernetesPodOperator( namespace="default", @@ -363,7 +365,7 @@ def test_pod_dnspolicy(self, mock_get_connection): assert self.expected_pod["spec"] == actual_pod["spec"] assert self.expected_pod["metadata"]["labels"] == actual_pod["metadata"]["labels"] - def test_pod_schedulername(self, mock_get_connection): + def test_pod_schedulername(self): scheduler_name = "default-scheduler" k = KubernetesPodOperator( namespace="default", @@ -382,7 +384,7 @@ def test_pod_schedulername(self, mock_get_connection): self.expected_pod["spec"]["schedulerName"] = scheduler_name assert self.expected_pod == actual_pod - def test_pod_node_selector(self, mock_get_connection): + def test_pod_node_selector(self): node_selector = {"beta.kubernetes.io/os": "linux"} k = KubernetesPodOperator( namespace="default", @@ -401,7 +403,7 @@ def test_pod_node_selector(self, mock_get_connection): self.expected_pod["spec"]["nodeSelector"] = node_selector assert self.expected_pod == actual_pod - def test_pod_resources(self, mock_get_connection): + def test_pod_resources(self): resources = k8s.V1ResourceRequirements( requests={"memory": "64Mi", "cpu": "250m", "ephemeral-storage": "1Gi"}, limits={"memory": "64Mi", "cpu": 0.25, "nvidia.com/gpu": None, "ephemeral-storage": "2Gi"}, @@ -471,7 +473,7 @@ def test_pod_resources(self, mock_get_connection): ), ], ) - def test_pod_affinity(self, val, mock_get_connection): + def test_pod_affinity(self, val): expected = { "nodeAffinity": { "requiredDuringSchedulingIgnoredDuringExecution": { @@ -502,7 +504,7 @@ def test_pod_affinity(self, val, mock_get_connection): self.expected_pod["spec"]["affinity"] = expected assert self.expected_pod == actual_pod - def test_port(self, mock_get_connection): + def test_port(self): port = k8s.V1ContainerPort( name="http", container_port=80, @@ -525,7 +527,7 @@ def test_port(self, mock_get_connection): self.expected_pod["spec"]["containers"][0]["ports"] = [{"name": "http", "containerPort": 80}] assert self.expected_pod == actual_pod - def test_volume_mount(self, mock_get_connection): + def test_volume_mount(self): with mock.patch.object(PodManager, "log") as mock_logger: volume_mount = k8s.V1VolumeMount( name="test-volume", mount_path="/tmp/test_volume", sub_path=None, read_only=False @@ -565,7 +567,7 @@ def test_volume_mount(self, mock_get_connection): assert self.expected_pod == actual_pod @pytest.mark.parametrize("uid", [0, 1000]) - def test_run_as_user(self, uid, mock_get_connection): + def test_run_as_user(self, uid): security_context = V1PodSecurityContext(run_as_user=uid) name = str(uuid4()) k = KubernetesPodOperator( @@ -590,7 +592,7 @@ def test_run_as_user(self, uid, mock_get_connection): assert pod.to_dict()["spec"]["security_context"]["run_as_user"] == uid @pytest.mark.parametrize("gid", [0, 1000]) - def test_fs_group(self, gid, mock_get_connection): + def test_fs_group(self, gid): security_context = V1PodSecurityContext(fs_group=gid) name = str(uuid4()) k = KubernetesPodOperator( @@ -614,7 +616,7 @@ def test_fs_group(self, gid, mock_get_connection): ) assert pod.to_dict()["spec"]["security_context"]["fs_group"] == gid - def test_disable_privilege_escalation(self, mock_get_connection): + def test_disable_privilege_escalation(self): container_security_context = V1SecurityContext(allow_privilege_escalation=False) k = KubernetesPodOperator( @@ -636,7 +638,7 @@ def test_disable_privilege_escalation(self, mock_get_connection): } assert self.expected_pod == actual_pod - def test_faulty_image(self, mock_get_connection): + def test_faulty_image(self): bad_image_name = "foobar" k = KubernetesPodOperator( namespace="default", @@ -656,7 +658,7 @@ def test_faulty_image(self, mock_get_connection): self.expected_pod["spec"]["containers"][0]["image"] = bad_image_name assert self.expected_pod == actual_pod - def test_faulty_service_account(self, mock_get_connection): + def test_faulty_service_account(self): k = KubernetesPodOperator( namespace="default", image="ubuntu:16.04", @@ -674,7 +676,7 @@ def test_faulty_service_account(self, mock_get_connection): with pytest.raises(ApiException, match="error looking up service account default/foobar"): k.get_or_create_pod(pod, context) - def test_pod_failure(self, mock_get_connection): + def test_pod_failure(self): """ Tests that the task fails when a pod reports a failure """ @@ -696,7 +698,7 @@ def test_pod_failure(self, mock_get_connection): self.expected_pod["spec"]["containers"][0]["args"] = bad_internal_command assert self.expected_pod == actual_pod - def test_xcom_push(self, test_label, mock_get_connection): + def test_xcom_push(self, test_label): expected = {"test_label": test_label, "buzz": 2} args = [f"echo '{json.dumps(expected)}' > /airflow/xcom/return.json"] k = KubernetesPodOperator( @@ -712,7 +714,7 @@ def test_xcom_push(self, test_label, mock_get_connection): context = create_context(k) assert k.execute(context) == expected - def test_env_vars(self, mock_get_connection): + def test_env_vars(self): # WHEN env_vars = [ k8s.V1EnvVar(name="ENV1", value="val1"), @@ -744,7 +746,7 @@ def test_env_vars(self, mock_get_connection): ] assert self.expected_pod == actual_pod - def test_pod_template_file_system(self, mock_get_connection, basic_pod_template): + def test_pod_template_file_system(self, basic_pod_template): """Note: this test requires that you have a namespace ``mem-example`` in your cluster.""" k = KubernetesPodOperator( task_id=str(uuid4()), @@ -766,9 +768,7 @@ def test_pod_template_file_system(self, mock_get_connection, basic_pod_template) pytest.param({"env_name": "value"}, id="backcompat"), # todo: remove? ], ) - def test_pod_template_file_with_overrides_system( - self, env_vars, test_label, mock_get_connection, basic_pod_template - ): + def test_pod_template_file_with_overrides_system(self, env_vars, test_label, basic_pod_template): k = KubernetesPodOperator( task_id=str(uuid4()), labels=self.labels, @@ -794,7 +794,7 @@ def test_pod_template_file_with_overrides_system( assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")] assert result == {"hello": "world"} - def test_pod_template_file_with_full_pod_spec(self, test_label, mock_get_connection, basic_pod_template): + def test_pod_template_file_with_full_pod_spec(self, test_label, basic_pod_template): pod_spec = k8s.V1Pod( metadata=k8s.V1ObjectMeta( labels={"test_label": test_label, "fizz": "buzz"}, @@ -834,7 +834,7 @@ def test_pod_template_file_with_full_pod_spec(self, test_label, mock_get_connect assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")] assert result == {"hello": "world"} - def test_full_pod_spec(self, test_label, mock_get_connection): + def test_full_pod_spec(self, test_label): pod_spec = k8s.V1Pod( metadata=k8s.V1ObjectMeta( labels={"test_label": test_label, "fizz": "buzz"}, namespace="default", name="test-pod" @@ -879,7 +879,7 @@ def test_full_pod_spec(self, test_label, mock_get_connection): assert k.pod.spec.containers[0].env == [k8s.V1EnvVar(name="env_name", value="value")] assert result == {"hello": "world"} - def test_init_container(self, mock_get_connection): + def test_init_container(self): # GIVEN volume_mounts = [ k8s.V1VolumeMount(mount_path="/etc/foo", name="test-volume", sub_path=None, read_only=True) @@ -1078,7 +1078,7 @@ def test_pod_priority_class_name(self, hook_mock, await_pod_completion_mock): self.expected_pod["spec"]["priorityClassName"] = priority_class_name assert self.expected_pod == actual_pod - def test_pod_name(self, mock_get_connection): + def test_pod_name(self): pod_name_too_long = "a" * 221 k = KubernetesPodOperator( namespace="default", @@ -1097,7 +1097,7 @@ def test_pod_name(self, mock_get_connection): with pytest.raises(AirflowException): k.execute(context) - def test_on_kill(self, mock_get_connection): + def test_on_kill(self): hook = KubernetesHook(conn_id=None, in_cluster=False) client = hook.core_v1_client name = "test" @@ -1137,7 +1137,7 @@ class ShortCircuitException(Exception): with pytest.raises(ApiException, match=r'pods \\"test.[a-z0-9]+\\" not found'): client.read_namespaced_pod(name=name, namespace=namespace) - def test_reattach_failing_pod_once(self, mock_get_connection): + def test_reattach_failing_pod_once(self): hook = KubernetesHook(conn_id=None, in_cluster=False) client = hook.core_v1_client name = "test" @@ -1204,7 +1204,7 @@ def get_op(): k.execute(context) create_mock.assert_called_once() - def test_changing_base_container_name_with_get_logs(self, mock_get_connection): + def test_changing_base_container_name_with_get_logs(self): k = KubernetesPodOperator( namespace="default", image="ubuntu:16.04", @@ -1229,7 +1229,7 @@ def test_changing_base_container_name_with_get_logs(self, mock_get_connection): self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce" assert self.expected_pod["spec"] == actual_pod["spec"] - def test_changing_base_container_name_no_logs(self, mock_get_connection): + def test_changing_base_container_name_no_logs(self): """ This test checks BOTH a modified base container name AND the get_logs=False flow, and as a result, also checks that the flow works with fast containers @@ -1259,7 +1259,7 @@ def test_changing_base_container_name_no_logs(self, mock_get_connection): self.expected_pod["spec"]["containers"][0]["name"] = "apple-sauce" assert self.expected_pod["spec"] == actual_pod["spec"] - def test_changing_base_container_name_no_logs_long(self, mock_get_connection): + def test_changing_base_container_name_no_logs_long(self): """ Similar to test_changing_base_container_name_no_logs, but ensures that pods running longer than 1 second work too. @@ -1290,7 +1290,7 @@ def test_changing_base_container_name_no_logs_long(self, mock_get_connection): self.expected_pod["spec"]["containers"][0]["args"] = ["sleep 3"] assert self.expected_pod["spec"] == actual_pod["spec"] - def test_changing_base_container_name_failure(self, mock_get_connection): + def test_changing_base_container_name_failure(self): k = KubernetesPodOperator( namespace="default", image="ubuntu:16.04", @@ -1317,7 +1317,7 @@ class ShortCircuitException(Exception): assert mock_get_container_termination_message.call_args[0][1] == "apple-sauce" - def test_base_container_name_init_precedence(self, mock_get_connection): + def test_base_container_name_init_precedence(self): assert ( KubernetesPodOperator(base_container_name="apple-sauce", task_id=str(uuid4())).base_container_name == "apple-sauce" @@ -1336,7 +1336,7 @@ class MyK8SPodOperator(KubernetesPodOperator): ) assert MyK8SPodOperator(task_id=str(uuid4())).base_container_name == "tomato-sauce" - def test_init_container_logs(self, mock_get_connection): + def test_init_container_logs(self): marker_from_init_container = f"{uuid4()}" marker_from_main_container = f"{uuid4()}" callback = MagicMock() @@ -1367,7 +1367,7 @@ def test_init_container_logs(self, mock_get_connection): assert marker_from_init_container in calls_args assert marker_from_main_container in calls_args - def test_init_container_logs_filtered(self, mock_get_connection): + def test_init_container_logs_filtered(self): marker_from_init_container_to_log_1 = f"{uuid4()}" marker_from_init_container_to_log_2 = f"{uuid4()}" marker_from_init_container_to_ignore = f"{uuid4()}"