diff --git a/tests/unit/challenges/test_admin_filters.py b/tests/unit/challenges/test_admin_filters.py new file mode 100644 index 0000000000..7a9a93c3e2 --- /dev/null +++ b/tests/unit/challenges/test_admin_filters.py @@ -0,0 +1,73 @@ +from django.test import TestCase +from django.utils import timezone +from challenges.models import Challenge +from challenges.admin import ChallengeFilter +from hosts.models import ChallengeHostTeam +from django.contrib.auth.models import User + + +class ChallengeFilterTest(TestCase): + def setUp(self): + # Create a user + self.user = User.objects.create_user(username='testuser', password='12345') + + # Create a challenge host team + self.challenge_host_team = ChallengeHostTeam.objects.create( + team_name="Test Challenge Host Team", created_by=self.user + ) + + # Create test data + self.past_challenge = Challenge.objects.create( + title="Past Challenge", + start_date=timezone.now() - timezone.timedelta(days=10), + end_date=timezone.now() - timezone.timedelta(days=5), + published=True, + approved_by_admin=True, + is_disabled=False, + creator=self.challenge_host_team, + ) + self.present_challenge = Challenge.objects.create( + title="Present Challenge", + start_date=timezone.now() - timezone.timedelta(days=5), + end_date=timezone.now() + timezone.timedelta(days=5), + published=True, + approved_by_admin=True, + is_disabled=False, + creator=self.challenge_host_team, + ) + self.future_challenge = Challenge.objects.create( + title="Future Challenge", + start_date=timezone.now() + timezone.timedelta(days=5), + end_date=timezone.now() + timezone.timedelta(days=10), + published=True, + approved_by_admin=True, + is_disabled=False, + creator=self.challenge_host_team, + ) + + def test_past_challenge_filter(self): + request = None # Mock request if needed + filter_instance = ChallengeFilter(request, {}, Challenge, Challenge.objects.all()) + filter_instance.value = lambda: "past" + queryset = filter_instance.queryset(request, Challenge.objects.all()) + self.assertIn(self.past_challenge, queryset) + self.assertNotIn(self.present_challenge, queryset) + self.assertNotIn(self.future_challenge, queryset) + + def test_present_challenge_filter(self): + request = None # Mock request if needed + filter_instance = ChallengeFilter(request, {}, Challenge, Challenge.objects.all()) + filter_instance.value = lambda: "present" + queryset = filter_instance.queryset(request, Challenge.objects.all()) + self.assertNotIn(self.past_challenge, queryset) + self.assertIn(self.present_challenge, queryset) + self.assertNotIn(self.future_challenge, queryset) + + def test_future_challenge_filter(self): + request = None # Mock request if needed + filter_instance = ChallengeFilter(request, {}, Challenge, Challenge.objects.all()) + filter_instance.value = lambda: "future" + queryset = filter_instance.queryset(request, Challenge.objects.all()) + self.assertNotIn(self.past_challenge, queryset) + self.assertNotIn(self.present_challenge, queryset) + self.assertIn(self.future_challenge, queryset) diff --git a/tests/unit/challenges/test_apps.py b/tests/unit/challenges/test_apps.py new file mode 100644 index 0000000000..1432f482d9 --- /dev/null +++ b/tests/unit/challenges/test_apps.py @@ -0,0 +1,12 @@ +from django.test import TestCase +from django.apps import apps +from challenges.apps import ChallengesConfig + + +class ChallengesConfigTest(TestCase): + def test_app_name(self): + self.assertEqual(ChallengesConfig.name, "challenges") + + def test_app_config(self): + app_config = apps.get_app_config('challenges') 
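+        # the registry entry fetched by label should report the expected app name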
+ self.assertEqual(app_config.name, "challenges") diff --git a/tests/unit/challenges/test_aws_utils.py b/tests/unit/challenges/test_aws_utils.py new file mode 100644 index 0000000000..cff5f1cc07 --- /dev/null +++ b/tests/unit/challenges/test_aws_utils.py @@ -0,0 +1,2315 @@ +from unittest import TestCase, mock +import unittest +from django.core import serializers +from challenges.aws_utils import ( + create_ec2_instance, + create_eks_nodegroup, + delete_log_group, + delete_service_by_challenge_pk, + delete_workers, + describe_ec2_instance, + get_code_upload_setup_meta_for_challenge, + create_service_by_challenge_pk, + get_logs_from_cloudwatch, + restart_ec2_instance, + restart_workers, + restart_workers_signal_callback, + scale_resources, + scale_workers, + service_manager, + setup_ec2, + setup_eks_cluster, + start_ec2_instance, + start_workers, + stop_ec2_instance, + stop_workers, + terminate_ec2_instance, + update_service_by_challenge_pk, + update_sqs_retention_period, + update_sqs_retention_period_task +) +import pytest +from unittest.mock import MagicMock, mock_open, patch +from botocore.exceptions import ClientError +from http import HTTPStatus +from hosts.models import ChallengeHostTeam +from django.contrib.auth.models import User +from challenges.models import Challenge + + +class AWSUtilsTestCase(TestCase): + @mock.patch('challenges.models.ChallengeEvaluationCluster.objects.get') + @mock.patch('challenges.utils.get_challenge_model') + def test_get_code_upload_setup_meta_for_challenge_with_host_credentials(self, mock_get_challenge_model, mock_get_cluster): + # Mock the return value of get_challenge_model + mock_challenge = mock_get_challenge_model.return_value + mock_challenge.use_host_credentials = True + + mock_challenge_evaluation_cluster = mock_get_cluster.return_value + mock_challenge_evaluation_cluster.subnet_1_id = 'subnet1' + mock_challenge_evaluation_cluster.subnet_2_id = 'subnet2' + mock_challenge_evaluation_cluster.security_group_id = 'sg' + mock_challenge_evaluation_cluster.node_group_arn_role = 'node_group_arn_role' + mock_challenge_evaluation_cluster.eks_arn_role = 'eks_arn_role' + + # Call the function under test + result = get_code_upload_setup_meta_for_challenge(1) + + # Expected result + expected_result = { + "SUBNET_1": 'subnet1', + "SUBNET_2": 'subnet2', + "SUBNET_SECURITY_GROUP": 'sg', + "EKS_NODEGROUP_ROLE_ARN": 'node_group_arn_role', + "EKS_CLUSTER_ROLE_ARN": 'eks_arn_role', + } + + # Assertions + self.assertEqual(result, expected_result) + mock_get_cluster.assert_called_once_with(challenge=mock_challenge) + + @mock.patch('challenges.utils.get_challenge_model') + @mock.patch('challenges.aws_utils.VPC_DICT', { + "SUBNET_1": "vpc_subnet1", + "SUBNET_2": "vpc_subnet2", + "SUBNET_SECURITY_GROUP": "vpc_sg" + }) + @mock.patch('challenges.aws_utils.settings') + def test_get_code_upload_setup_meta_for_challenge_without_host_credentials(self, mock_settings, mock_get_challenge_model): + # Mock the return value of get_challenge_model + mock_challenge = mock_get_challenge_model.return_value + mock_challenge.use_host_credentials = False + + # Mock settings for the else case + mock_settings.EKS_NODEGROUP_ROLE_ARN = 'vpc_node_group_arn_role' + mock_settings.EKS_CLUSTER_ROLE_ARN = 'vpc_eks_arn_role' + + # Call the function under test + result = get_code_upload_setup_meta_for_challenge(1) + + # Expected result + expected_result = { + "SUBNET_1": 'vpc_subnet1', + "SUBNET_2": 'vpc_subnet2', + "SUBNET_SECURITY_GROUP": 'vpc_sg', + "EKS_NODEGROUP_ROLE_ARN": 
'vpc_node_group_arn_role', + "EKS_CLUSTER_ROLE_ARN": 'vpc_eks_arn_role', + } + + # Assertions + self.assertEqual(result, expected_result) + + +@pytest.fixture +def mock_client(): + return MagicMock() + + +@pytest.fixture +def mock_challenge(): + return MagicMock() + + +@pytest.fixture +def client_token(): + return "dummy_client_token" + + +@pytest.fixture +def num_of_tasks(): + return 3 + + +class TestCreateServiceByChallengePk: + def test_create_service_success(self, mock_client, mock_challenge, client_token): + mock_challenge.workers = None + mock_challenge.task_def_arn = "valid_task_def_arn" + + response_metadata = {"HTTPStatusCode": HTTPStatus.OK} + mock_client.create_service.return_value = {"ResponseMetadata": response_metadata} + + with patch('challenges.aws_utils.register_task_def_by_challenge_pk', return_value={"ResponseMetadata": response_metadata}): + response = create_service_by_challenge_pk(mock_client, mock_challenge, client_token) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.OK + mock_challenge.save.assert_called_once() + assert mock_challenge.workers == 1 + + def test_create_service_client_error(self, mock_client, mock_challenge, client_token): + mock_challenge.workers = None + mock_challenge.task_def_arn = "valid_task_def_arn" + + mock_client.create_service.side_effect = ClientError( + error_response={"Error": {"Code": "SomeError"}, "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}}, + operation_name='CreateService' + ) + + with patch('challenges.aws_utils.register_task_def_by_challenge_pk', return_value={"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}}): + response = create_service_by_challenge_pk(mock_client, mock_challenge, client_token) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.BAD_REQUEST + + def test_service_already_exists(self, mock_client, mock_challenge, client_token): + mock_challenge.workers = 1 + + response = create_service_by_challenge_pk(mock_client, mock_challenge, client_token) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.BAD_REQUEST + assert "Worker service for challenge" in response["Error"] + + def test_register_task_def_fails(self, mock_client, mock_challenge, client_token): + mock_challenge.workers = None + mock_challenge.task_def_arn = None # Simulate task definition is not yet registered + + register_task_response = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}} + + with patch('challenges.aws_utils.register_task_def_by_challenge_pk', return_value=register_task_response): + response = create_service_by_challenge_pk(mock_client, mock_challenge, client_token) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.BAD_REQUEST + + +def test_update_service_success(mock_client, mock_challenge, num_of_tasks): + mock_challenge.queue = "dummy_queue" + mock_challenge.task_def_arn = "valid_task_def_arn" + + response_metadata = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + mock_client.update_service.return_value = response_metadata + + response = update_service_by_challenge_pk(mock_client, mock_challenge, num_of_tasks) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.OK + mock_challenge.save.assert_called_once() + assert mock_challenge.workers == num_of_tasks + + +def test_update_service_client_error(mock_client, mock_challenge, num_of_tasks): + mock_challenge.queue = "dummy_queue" + mock_challenge.task_def_arn = "valid_task_def_arn" + + mock_client.update_service.side_effect = ClientError( + 
error_response={"Error": {"Code": "ServiceError"}, "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}}, + operation_name='UpdateService' + ) + + response = update_service_by_challenge_pk(mock_client, mock_challenge, num_of_tasks) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.BAD_REQUEST + assert "ServiceError" in response["Error"]["Code"] + + +def test_update_service_force_new_deployment(mock_client, mock_challenge, num_of_tasks): + mock_challenge.queue = "dummy_queue" + mock_challenge.task_def_arn = "valid_task_def_arn" + + response_metadata = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + mock_client.update_service.return_value = response_metadata + + response = update_service_by_challenge_pk(mock_client, mock_challenge, num_of_tasks, force_new_deployment=True) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.OK + mock_challenge.save.assert_called_once() + assert mock_challenge.workers == num_of_tasks + + +def test_delete_service_success_when_workers_zero(mock_challenge, mock_client): + mock_challenge.workers = 0 + mock_challenge.task_def_arn = "valid_task_def_arn" # Ensure task_def_arn is set to a valid string + response_metadata_ok = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + + with patch('challenges.aws_utils.get_boto3_client', return_value=mock_client): + mock_client.delete_service.return_value = response_metadata_ok + + response = delete_service_by_challenge_pk(mock_challenge) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.OK + mock_challenge.save.assert_called() + mock_client.deregister_task_definition.assert_called_once_with(taskDefinition="valid_task_def_arn") + + +def test_delete_service_success_when_workers_not_zero(mock_challenge, mock_client): + mock_challenge.workers = 3 + mock_challenge.task_def_arn = "valid_task_def_arn" + response_metadata_ok = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + + with patch('challenges.aws_utils.get_boto3_client', return_value=mock_client): + with patch('challenges.aws_utils.update_service_by_challenge_pk', return_value=response_metadata_ok): + mock_client.delete_service.return_value = response_metadata_ok + + response = delete_service_by_challenge_pk(mock_challenge) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.OK + mock_challenge.save.assert_called() + mock_client.deregister_task_definition.assert_called_once_with(taskDefinition="valid_task_def_arn") + + +def test_update_service_failure(mock_challenge, mock_client): + mock_challenge.workers = 3 + response_metadata_error = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}} + + with patch('challenges.aws_utils.get_boto3_client', return_value=mock_client): + with patch('challenges.aws_utils.update_service_by_challenge_pk', return_value=response_metadata_error): + response = delete_service_by_challenge_pk(mock_challenge) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.BAD_REQUEST + mock_client.delete_service.assert_not_called() + + +def test_delete_service_failure(mock_challenge, mock_client): + mock_challenge.workers = 0 + response_metadata_error = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}} + + with patch('challenges.aws_utils.get_boto3_client', return_value=mock_client): + mock_client.delete_service.return_value = response_metadata_error + + response = delete_service_by_challenge_pk(mock_challenge) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.BAD_REQUEST + 
mock_challenge.save.assert_not_called() + + +def test_deregister_task_definition_failure(mock_challenge, mock_client): + mock_challenge.workers = 0 + response_metadata_ok = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + + with patch('challenges.aws_utils.get_boto3_client', return_value=mock_client): + mock_client.delete_service.return_value = response_metadata_ok + mock_client.deregister_task_definition.side_effect = ClientError( + error_response={"Error": {"Code": "DeregisterError"}, "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}}, + operation_name='DeregisterTaskDefinition' + ) + + response = delete_service_by_challenge_pk(mock_challenge) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.BAD_REQUEST + mock_client.deregister_task_definition.assert_called_once_with(taskDefinition=mock_challenge.task_def_arn) + + +def test_delete_service_client_error(mock_challenge, mock_client): + mock_challenge.workers = 0 + + with patch('challenges.aws_utils.get_boto3_client', return_value=mock_client): + mock_client.delete_service.side_effect = ClientError( + error_response={"Error": {"Code": "DeleteServiceError"}, "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}}, + operation_name='DeleteService' + ) + + response = delete_service_by_challenge_pk(mock_challenge) + + assert response["ResponseMetadata"]["HTTPStatusCode"] == HTTPStatus.BAD_REQUEST + mock_challenge.save.assert_not_called() + mock_client.deregister_task_definition.assert_not_called() + + +class TestServiceManager: + + @pytest.fixture + def mock_client(self): + return MagicMock() + + @pytest.fixture + def mock_challenge(self): + return MagicMock() + + def test_service_manager_updates_service(self, mock_client, mock_challenge): + # Setup + mock_challenge.workers = 1 + num_of_tasks = 5 + force_new_deployment = False + response_metadata_ok = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + + # Mock the update_service_by_challenge_pk to return a mock response + with patch('challenges.aws_utils.update_service_by_challenge_pk', return_value=response_metadata_ok) as mock_update: + # Call the function + response = service_manager(mock_client, mock_challenge, num_of_tasks=num_of_tasks, force_new_deployment=force_new_deployment) + + # Verify + assert response == response_metadata_ok + mock_update.assert_called_once_with(mock_client, mock_challenge, num_of_tasks, force_new_deployment) + + def test_service_manager_creates_service(self, mock_client, mock_challenge): + # Setup + mock_challenge.workers = None + response_metadata_ok = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + + # Mock client_token_generator and create_service_by_challenge_pk to return a mock response + with patch('challenges.aws_utils.client_token_generator', return_value='mock_client_token'): + with patch('challenges.aws_utils.create_service_by_challenge_pk', return_value=response_metadata_ok) as mock_create: + # Call the function + response = service_manager(mock_client, mock_challenge) + + # Verify + assert response == response_metadata_ok + mock_create.assert_called_once_with(mock_client, mock_challenge, 'mock_client_token') + + +class TestStopEc2Instance(unittest.TestCase): + @patch('challenges.aws_utils.get_boto3_client') + def test_stop_instance_success(self, mock_get_boto3_client): + # Mocking the EC2 client + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + # Mocking describe_instance_status response + mock_ec2.describe_instance_status.return_value = { + 
"InstanceStatuses": [{ + "SystemStatus": {"Status": "ok"}, + "InstanceStatus": {"Status": "ok"}, + "InstanceState": {"Name": "running"} + }] + } + # Mocking stop_instances response + mock_ec2.stop_instances.return_value = {"StoppingInstances": [{"InstanceId": "i-1234567890abcdef0", "CurrentState": {"Name": "stopping"}}]} + + # Creating a mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + + # Calling the function + result = stop_ec2_instance(challenge) + + # Checking the response + self.assertEqual(result['response'], {"StoppingInstances": [{"InstanceId": "i-1234567890abcdef0", "CurrentState": {"Name": "stopping"}}]}) + self.assertEqual(result['message'], "Instance for challenge 1 successfully stopped.") + + @patch('challenges.aws_utils.get_boto3_client') + def test_instance_not_running(self, mock_get_boto3_client): + # Mocking the EC2 client + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + # Mocking describe_instance_status response + mock_ec2.describe_instance_status.return_value = { + "InstanceStatuses": [{ + "SystemStatus": {"Status": "ok"}, + "InstanceStatus": {"Status": "ok"}, + "InstanceState": {"Name": "stopped"} + }] + } + + # Creating a mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + + # Calling the function + result = stop_ec2_instance(challenge) + + # Checking the response + self.assertEqual(result['error'], "Instance for challenge 1 is not running. Please ensure the instance is running.") + + @patch('challenges.aws_utils.get_boto3_client') + def test_status_checks_not_ready(self, mock_get_boto3_client): + # Mocking the EC2 client + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + # Mocking describe_instance_status response + mock_ec2.describe_instance_status.return_value = { + "InstanceStatuses": [{ + "SystemStatus": {"Status": "impaired"}, + "InstanceStatus": {"Status": "ok"}, + "InstanceState": {"Name": "running"} + }] + } + + # Creating a mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + + # Calling the function + result = stop_ec2_instance(challenge) + + # Checking the response + self.assertEqual(result['error'], "Instance status checks are not ready for challenge 1. Please wait for the status checks to pass.") + + @patch('challenges.aws_utils.get_boto3_client') + def test_instance_not_found(self, mock_get_boto3_client): + # Mocking the EC2 client + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + + # Mocking describe_instance_status response + mock_ec2.describe_instance_status.return_value = { + "InstanceStatuses": [] + } + + # Creating a mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + + # Calling the function + result = stop_ec2_instance(challenge) + + # Checking the response + self.assertEqual(result['error'], "Instance for challenge 1 not found. 
Please ensure the instance exists.") + + +class TestDescribeEC2Instance(unittest.TestCase): + @patch('challenges.aws_utils.get_boto3_client') # Mock the `get_boto3_client` function + def test_describe_ec2_instance_success(self, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = { + "Reservations": [ + { + "Instances": [ + {"InstanceId": "i-1234567890abcdef0", "State": {"Name": "running"}} + ] + } + ] + } + mock_ec2.describe_instances.return_value = mock_response + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + # Call the function + result = describe_ec2_instance(challenge) + # Assert the result + self.assertEqual(result, {"message": {"InstanceId": "i-1234567890abcdef0", "State": {"Name": "running"}}}) + mock_ec2.describe_instances.assert_called_once_with(InstanceIds=["i-1234567890abcdef0"]) + + @patch('challenges.aws_utils.get_boto3_client') + @patch('builtins.open', new_callable=mock_open, read_data="ec2_worker_script_content") + def test_set_ec2_storage(self, mock_open, mock_get_boto3_client): + # Mock setup + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"Instances": [{"InstanceId": "i-1234567890abcdef0"}]} + mock_ec2.run_instances.return_value = mock_response + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = None + challenge.pk = 1 + challenge.ec2_storage = None + challenge.worker_instance_type = None + challenge.worker_image_url = None + challenge.queue = 'test_queue' # Add this to avoid None issues + + # Call the function with ec2_storage + result = create_ec2_instance(challenge, ec2_storage=100) + + # Assert the expected output + expected_message = "Instance for challenge 1 successfully created." + self.assertEqual(result['response'], mock_response) + self.assertEqual(result['message'], expected_message) + + # Ensure the challenge was updated and saved + self.assertEqual(challenge.ec2_storage, 100) + challenge.save.assert_called_once() + + # Ensure the worker script was read correctly + mock_open.assert_called_once_with('/code/scripts/deployment/deploy_ec2_worker.sh') + self.assertEqual(mock_open().read(), "ec2_worker_script_content") + + @patch('challenges.aws_utils.get_boto3_client') + @patch('builtins.open', new_callable=mock_open, read_data="ec2_worker_script_content") + def test_set_worker_instance_type(self, mock_open, mock_get_boto3_client): + # Mock setup + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"Instances": [{"InstanceId": "i-1234567890abcdef0"}]} + mock_ec2.run_instances.return_value = mock_response + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = None + challenge.pk = 1 + challenge.ec2_storage = None + challenge.worker_instance_type = None + challenge.worker_image_url = None + challenge.queue = 'test_queue' # Add this to avoid None issues + + # Call the function with worker_instance_type + result = create_ec2_instance(challenge, worker_instance_type='t3.medium') + + # Assert the expected output + expected_message = "Instance for challenge 1 successfully created." 
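+        # both the raw run_instances response and the success message should be surfaced unchanged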
+ self.assertEqual(result['response'], mock_response) + self.assertEqual(result['message'], expected_message) + + # Ensure the challenge was updated and saved + self.assertEqual(challenge.worker_instance_type, 't3.medium') + challenge.save.assert_called_once() + + # Ensure the worker script was read correctly + mock_open.assert_called_once_with('/code/scripts/deployment/deploy_ec2_worker.sh') + self.assertEqual(mock_open().read(), "ec2_worker_script_content") + + @patch('challenges.aws_utils.get_boto3_client') + @patch('builtins.open', new_callable=mock_open, read_data="ec2_worker_script_content") + def test_set_worker_image_url(self, mock_open, mock_get_boto3_client): + # Mock setup + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"Instances": [{"InstanceId": "i-1234567890abcdef0"}]} + mock_ec2.run_instances.return_value = mock_response + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = None + challenge.pk = 1 + challenge.ec2_storage = None + challenge.worker_instance_type = None + challenge.worker_image_url = None + challenge.queue = 'test_queue' # Add this to avoid None issues + + # Call the function with worker_image_url + result = create_ec2_instance(challenge, worker_image_url='ami-12345678') + + # Assert the expected output + expected_message = "Instance for challenge 1 successfully created." + self.assertEqual(result['response'], mock_response) + self.assertEqual(result['message'], expected_message) + + # Ensure the challenge was updated and saved + self.assertEqual(challenge.worker_image_url, 'ami-12345678') + challenge.save.assert_called_once() + + # Ensure the worker script was read correctly + mock_open.assert_called_once_with('/code/scripts/deployment/deploy_ec2_worker.sh') + self.assertEqual(mock_open().read(), "ec2_worker_script_content") + + @patch('challenges.aws_utils.get_boto3_client') + def test_multiple_instances(self, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = { + "Reservations": [ + { + "Instances": [ + {"InstanceId": "i-1234567890abcdef0", "State": {"Name": "running"}}, + {"InstanceId": "i-0987654321fedcba0", "State": {"Name": "stopped"}} + ] + } + ] + } + mock_ec2.describe_instances.return_value = mock_response + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + + # Call the function + result = describe_ec2_instance(challenge) + + # Assert the result + self.assertEqual(result, {"message": {"InstanceId": "i-1234567890abcdef0", "State": {"Name": "running"}}}) + + +class TestStartEC2Instance(unittest.TestCase): + @patch('challenges.aws_utils.get_boto3_client') # Mock the `get_boto3_client` function + def test_start_ec2_instance_success(self, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = { + "Reservations": [ + { + "Instances": [ + {"InstanceId": "i-1234567890abcdef0", "State": {"Name": "stopped"}} + ] + } + ] + } + mock_ec2.describe_instances.return_value = mock_response + mock_start_response = {"StartingInstances": [{"InstanceId": "i-1234567890abcdef0", "CurrentState": {"Name": "pending"}}]} + mock_ec2.start_instances.return_value = mock_start_response + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + # Call the function + result = start_ec2_instance(challenge) + # Assert the result + 
self.assertEqual(result, { + "response": mock_start_response, + "message": "Instance for challenge 1 successfully started." + }) + mock_ec2.describe_instances.assert_called_once_with(InstanceIds=["i-1234567890abcdef0"]) + mock_ec2.start_instances.assert_called_once_with(InstanceIds=["i-1234567890abcdef0"]) + + @patch('challenges.aws_utils.get_boto3_client') + def test_start_ec2_instance_already_running(self, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + + mock_response = { + "Reservations": [ + { + "Instances": [ + {"InstanceId": "i-1234567890abcdef0", "State": {"Name": "running"}} + ] + } + ] + } + mock_ec2.describe_instances.return_value = mock_response + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + # Call the function + result = start_ec2_instance(challenge) + + # Assert the result + self.assertEqual(result, { + "error": "Instance for challenge 1 is running. Please ensure the instance is stopped." + }) + mock_ec2.describe_instances.assert_called_once_with(InstanceIds=["i-1234567890abcdef0"]) + mock_ec2.start_instances.assert_not_called() + + @patch('challenges.aws_utils.get_boto3_client') + def test_start_ec2_instance_no_instances(self, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + + mock_response = { + "Reservations": [] + } + mock_ec2.describe_instances.return_value = mock_response + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + + # Call the function + result = start_ec2_instance(challenge) + + # Assert the result + self.assertEqual(result, { + "error": "Instance for challenge 1 not found. Please ensure the instance exists." 
+ }) + mock_ec2.describe_instances.assert_called_once_with(InstanceIds=["i-1234567890abcdef0"]) + mock_ec2.start_instances.assert_not_called() + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.logger') + def test_start_ec2_instance_exception(self, mock_logger, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_ec2.describe_instances.return_value = { + "Reservations": [ + { + "Instances": [ + {"InstanceId": "i-1234567890abcdef0", "State": {"Name": "stopped"}} + ] + } + ] + } + mock_ec2.start_instances.side_effect = ClientError({"Error": {"Message": "Test Exception"}}, "StartInstances") + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + + # Call the function + result = start_ec2_instance(challenge) + + # Assert the result + self.assertIn("error", result) + mock_logger.exception.assert_called_once() + + +class TestRestartEC2Instance(unittest.TestCase): + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.logger') + def test_restart_ec2_instance_success(self, mock_logger, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"RebootingInstances": [{"InstanceId": "i-1234567890abcdef0"}]} + mock_ec2.reboot_instances.return_value = mock_response + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + + # Call the function + result = restart_ec2_instance(challenge) + + # Assert the expected output + expected_message = "Instance for challenge 1 successfully restarted." + self.assertEqual(result['response'], mock_response) + self.assertEqual(result['message'], expected_message) + + # Ensure no exceptions were logged + mock_logger.exception.assert_not_called() + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.logger') + def test_restart_ec2_instance_client_error(self, mock_logger, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + + # Simulate ClientError + error_response = {'Error': {'Code': 'InvalidInstanceID.NotFound', 'Message': 'The instance ID does not exist'}} + mock_ec2.reboot_instances.side_effect = ClientError(error_response, 'RebootInstances') + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + + # Call the function + result = restart_ec2_instance(challenge) + + # Assert the expected output + self.assertEqual(result['error'], error_response) + + # Ensure the exception was logged + mock_logger.exception.assert_called_once_with(mock_ec2.reboot_instances.side_effect) + + +class TestTerminateEC2Instance(unittest.TestCase): + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.logger') + def test_terminate_ec2_instance_success(self, mock_logger, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"TerminatingInstances": [{"InstanceId": "i-1234567890abcdef0"}]} + mock_ec2.terminate_instances.return_value = mock_response + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + + # Call the function + result = terminate_ec2_instance(challenge) + + # Assert the expected output + expected_message = "Instance for challenge 1 
successfully terminated." + self.assertEqual(result['response'], mock_response) + self.assertEqual(result['message'], expected_message) + + # Ensure the EC2 instance ID was cleared and the challenge was saved + self.assertEqual(challenge.ec2_instance_id, "") + challenge.save.assert_called_once() + + # Ensure no exceptions were logged + mock_logger.exception.assert_not_called() + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.logger') + def test_terminate_ec2_instance_client_error(self, mock_logger, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + + # Simulate ClientError + error_response = {'Error': {'Code': 'InvalidInstanceID.NotFound', 'Message': 'The instance ID does not exist'}} + mock_ec2.terminate_instances.side_effect = ClientError(error_response, 'TerminateInstances') + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + + # Call the function + result = terminate_ec2_instance(challenge) + + # Assert the expected output + self.assertEqual(result['error'], error_response) + + # Ensure the exception was logged + mock_logger.exception.assert_called_once_with(mock_ec2.terminate_instances.side_effect) + + # Ensure the EC2 instance ID was not cleared and the challenge was not saved + self.assertNotEqual(challenge.ec2_instance_id, "") + challenge.save.assert_not_called() + + +class TestCreateEC2Instance(unittest.TestCase): + @patch('challenges.aws_utils.get_boto3_client') + def test_existing_ec2_instance_id(self, mock_get_boto3_client): + # Mock challenge object with existing EC2 instance ID + challenge = MagicMock() + challenge.ec2_instance_id = "i-1234567890abcdef0" + challenge.pk = 1 + + # Call the function + result = create_ec2_instance(challenge) + + # Assert the expected output + expected_error = ( + "Challenge 1 has existing EC2 instance ID. " + "Please ensure there is no existing associated instance before trying to create one." + ) + self.assertEqual(result['error'], expected_error) + + @patch('challenges.aws_utils.get_boto3_client') + @patch('builtins.open', new_callable=mock_open, read_data="ec2_worker_script_content") + def test_create_ec2_instance_success(self, mock_open, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"Instances": [{"InstanceId": "i-1234567890abcdef0"}]} + mock_ec2.run_instances.return_value = mock_response + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = None + challenge.pk = 1 + challenge.ec2_storage = 50 + challenge.worker_instance_type = 't2.micro' + challenge.worker_image_url = 'ami-0747bdcabd34c712a' + challenge.queue = 'some_queue' + + # Mock settings + with patch('challenges.aws_utils.settings', ENVIRONMENT='test'): + # Call the function + result = create_ec2_instance(challenge) + + # Assert the expected output + expected_message = "Instance for challenge 1 successfully created." 
+ self.assertEqual(result['response'], mock_response) + self.assertEqual(result['message'], expected_message) + + # Ensure the challenge was updated and saved + self.assertTrue(challenge.uses_ec2_worker) + self.assertEqual(challenge.ec2_instance_id, "i-1234567890abcdef0") + challenge.save.assert_called_once() + + # Ensure the worker script was read correctly + mock_open.assert_called_once_with('/code/scripts/deployment/deploy_ec2_worker.sh') + self.assertEqual(mock_open().read(), "ec2_worker_script_content") + + @patch('challenges.aws_utils.get_boto3_client') + @patch('builtins.open', new_callable=mock_open, read_data="ec2_worker_script_content") + @patch('challenges.aws_utils.logger') + def test_create_ec2_instance_client_error(self, mock_logger, mock_open, mock_get_boto3_client): + # Setup mock + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + + # Simulate ClientError + error_response = {'Error': {'Code': 'InvalidParameterValue', 'Message': 'The parameter value is invalid'}} + client_error = ClientError(error_response, 'RunInstances') + mock_ec2.run_instances.side_effect = client_error + + # Mock challenge object + challenge = MagicMock() + challenge.ec2_instance_id = None + challenge.pk = 1 + challenge.queue = 'test_queue' + challenge.worker_image_url = 'worker_image_url' + + # Mock aws_keys and settings + with patch('challenges.aws_utils.aws_keys', { + "AWS_ACCOUNT_ID": "123456789012", + "AWS_ACCESS_KEY_ID": "access_key", + "AWS_SECRET_ACCESS_KEY": "secret_key", + "AWS_REGION": "us-west-1" + }): + with patch('challenges.aws_utils.settings', ENVIRONMENT='test'): + # Call the function + result = create_ec2_instance(challenge) + + # Assert the expected output + self.assertEqual(result['error'], error_response) + + mock_logger.exception.assert_called_once() + logged_exception = mock_logger.exception.call_args[0][0] + self.assertIsInstance(logged_exception, ClientError) + self.assertEqual(str(logged_exception), str(client_error)) + + +class TestUpdateSQSRetentionPeriod(unittest.TestCase): + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.logger') + def test_update_sqs_retention_period_success(self, mock_logger, mock_get_boto3_client): + # Setup mock SQS client + mock_sqs = MagicMock() + mock_get_boto3_client.return_value = mock_sqs + mock_sqs.get_queue_url.return_value = {'QueueUrl': 'https://sqs.us-west-1.amazonaws.com/123456789012/test_queue'} + mock_sqs.set_queue_attributes.return_value = {'ResponseMetadata': {'HTTPStatusCode': 200}} + + # Mock challenge object + challenge = MagicMock() + challenge.queue = 'test_queue' + challenge.sqs_retention_period = 86400 # 1 day in seconds + + # Call the function + result = update_sqs_retention_period(challenge) + + # Assert the expected output + expected_response = {'ResponseMetadata': {'HTTPStatusCode': 200}} + self.assertEqual(result, {"message": expected_response}) + + # Ensure methods were called with expected arguments + mock_sqs.get_queue_url.assert_called_once_with(QueueName=challenge.queue) + mock_sqs.set_queue_attributes.assert_called_once_with( + QueueUrl='https://sqs.us-west-1.amazonaws.com/123456789012/test_queue', + Attributes={'MessageRetentionPeriod': '86400'} + ) + mock_logger.exception.assert_not_called() + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.logger') + def test_update_sqs_retention_period_failure(self, mock_logger, mock_get_boto3_client): + # Setup mock SQS client + mock_sqs = MagicMock() + mock_get_boto3_client.return_value = 
mock_sqs + mock_sqs.get_queue_url.side_effect = ClientError( + {'Error': {'Code': 'QueueDoesNotExist', 'Message': 'The queue does not exist'}}, + 'GetQueueUrl' + ) + + # Mock challenge object + challenge = MagicMock() + challenge.queue = 'test_queue' + challenge.sqs_retention_period = 86400 # 1 day in seconds + + # Call the function + result = update_sqs_retention_period(challenge) + + # Assert the expected output + self.assertEqual(result, {"error": "An error occurred (QueueDoesNotExist) when calling the GetQueueUrl operation: The queue does not exist"}) + + # Ensure methods were called with expected arguments + mock_sqs.get_queue_url.assert_called_once_with(QueueName=challenge.queue) + mock_sqs.set_queue_attributes.assert_not_called() + mock_logger.exception.assert_called_once() + + +class TestStartWorkers(unittest.TestCase): + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + @patch('challenges.aws_utils.settings', DEBUG=True) + def test_start_workers_debug_mode(self, mock_settings, mock_service_manager, mock_get_boto3_client): + # Mock queryset + challenge = MagicMock() + challenge.pk = 1 + challenge.workers = None + queryset = [challenge] + + # Call the function + result = start_workers(queryset) + + # Assert the expected output + expected_result = { + "count": 0, + "failures": [ + { + "message": "Workers cannot be started on AWS ECS service in development environment", + "challenge_pk": 1 + } + ] + } + self.assertEqual(result, expected_result) + mock_get_boto3_client.assert_not_called() + mock_service_manager.assert_not_called() + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_start_workers_success(self, mock_settings, mock_service_manager, mock_get_boto3_client): + # Setup mock ECS client + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + mock_service_manager.return_value = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK} + } + + # Mock queryset + challenge = MagicMock() + challenge.pk = 1 + challenge.workers = 0 + queryset = [challenge] + + # Call the function + result = start_workers(queryset) + + # Assert the expected output + expected_result = {"count": 1, "failures": []} + self.assertEqual(result, expected_result) + + # Ensure methods were called with expected arguments + aws_keys = {'AWS_ACCOUNT_ID': 'x', 'AWS_ACCESS_KEY_ID': 'x', 'AWS_SECRET_ACCESS_KEY': 'x', 'AWS_REGION': 'us-east-1', 'AWS_STORAGE_BUCKET_NAME': 'evalai-s3-bucket'} + mock_get_boto3_client.assert_called_once_with("ecs", aws_keys) + mock_service_manager.assert_called_once_with( + mock_client, challenge=challenge, num_of_tasks=1 + ) + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_start_workers_failure(self, mock_settings, mock_service_manager, mock_get_boto3_client): + # Setup mock ECS client + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + mock_service_manager.return_value = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}, + "Error": "Some error occurred" + } + + # Mock queryset + challenge = MagicMock() + challenge.pk = 1 + challenge.workers = 0 + queryset = [challenge] + + # Call the function + result = start_workers(queryset) + + # Assert the expected output + expected_result = { + "count": 0, + "failures": [ + { + "message": "Some error 
occurred", + "challenge_pk": 1 + } + ] + } + self.assertEqual(result, expected_result) + + # Ensure methods were called with expected arguments + aws_keys = {'AWS_ACCOUNT_ID': 'x', 'AWS_ACCESS_KEY_ID': 'x', 'AWS_SECRET_ACCESS_KEY': 'x', 'AWS_REGION': 'us-east-1', 'AWS_STORAGE_BUCKET_NAME': 'evalai-s3-bucket'} + mock_get_boto3_client.assert_called_once_with("ecs", aws_keys) + mock_service_manager.assert_called_once_with( + mock_client, challenge=challenge, num_of_tasks=1 + ) + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_start_workers_with_active_workers(self, mock_settings, mock_service_manager, mock_get_boto3_client): + # Mock queryset with active workers + challenge = MagicMock() + challenge.pk = 1 + challenge.workers = 5 + queryset = [challenge] + + # Call the function + result = start_workers(queryset) + + # Assert the expected output + expected_result = { + "count": 0, + "failures": [ + { + "message": "Please select challenge with inactive workers only.", + "challenge_pk": 1 + } + ] + } + self.assertEqual(result, expected_result) + + @patch('challenges.aws_utils.settings', DEBUG=True) + def test_stop_workers_debug_mode(self, mock_settings): + # Mock queryset with challenges + challenge1 = MagicMock() + challenge1.pk = 1 + challenge2 = MagicMock() + challenge2.pk = 2 + queryset = [challenge1, challenge2] + + # Call the function + result = stop_workers(queryset) + + # Assert the expected output + expected_failures = [ + {"message": "Workers cannot be stopped on AWS ECS service in development environment", "challenge_pk": 1}, + {"message": "Workers cannot be stopped on AWS ECS service in development environment", "challenge_pk": 2} + ] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + def test_stop_workers_success(self, mock_service_manager, mock_get_boto3_client, mock_settings): + # Mock client and service manager response + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + mock_service_manager.return_value = mock_response + + # Mock queryset with active workers + challenge1 = MagicMock() + challenge1.pk = 1 + challenge1.workers = 5 + challenge2 = MagicMock() + challenge2.pk = 2 + challenge2.workers = 3 + queryset = [challenge1, challenge2] + + # Call the function + result = stop_workers(queryset) + + # Assert the expected output + self.assertEqual(result, {"count": 2, "failures": []}) + + # Ensure the service manager was called correctly + mock_service_manager.assert_called_with(mock_ec2, challenge=challenge2, num_of_tasks=0) + self.assertEqual(mock_service_manager.call_count, 2) + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + def test_stop_workers_failure(self, mock_service_manager, mock_get_boto3_client, mock_settings): + # Mock client and service manager response with error + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}, "Error": "Error stopping worker"} + mock_service_manager.return_value = mock_response + + # Mock queryset with active workers + challenge1 = MagicMock() + 
challenge1.pk = 1 + challenge1.workers = 5 + challenge2 = MagicMock() + challenge2.pk = 2 + challenge2.workers = 0 # No active workers + queryset = [challenge1, challenge2] + + # Call the function + result = stop_workers(queryset) + + # Assert the expected output + expected_failures = [ + {"message": "Error stopping worker", "challenge_pk": 1}, + {"message": "Please select challenges with active workers only.", "challenge_pk": 2} + ] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # Ensure the service manager was called correctly + mock_service_manager.assert_called_with(mock_ec2, challenge=challenge1, num_of_tasks=0) + self.assertEqual(mock_service_manager.call_count, 1) + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + def test_stop_workers_no_active_workers(self, mock_service_manager, mock_get_boto3_client, mock_settings): + # Mock client and service manager response with success + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + mock_service_manager.return_value = mock_response + + # Mock queryset with no active workers + challenge1 = MagicMock() + challenge1.pk = 1 + challenge1.workers = 0 + challenge2 = MagicMock() + challenge2.pk = 2 + challenge2.workers = 0 + queryset = [challenge1, challenge2] + + # Call the function + result = stop_workers(queryset) + + # Assert the expected output + expected_failures = [ + {"message": "Please select challenges with active workers only.", "challenge_pk": 1}, + {"message": "Please select challenges with active workers only.", "challenge_pk": 2} + ] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # Ensure the service manager was not called + mock_service_manager.assert_not_called() + + +class TestScaleWorkers(unittest.TestCase): + @patch('challenges.aws_utils.settings', DEBUG=True) + def test_scale_workers_debug_mode(self, mock_settings): + # Mock queryset with challenges + challenge1 = MagicMock() + challenge1.pk = 1 + challenge2 = MagicMock() + challenge2.pk = 2 + queryset = [challenge1, challenge2] + + # Call the function + result = scale_workers(queryset, num_of_tasks=5) + + # Assert the expected output + expected_failures = [ + {"message": "Workers cannot be scaled on AWS ECS service in development environment", "challenge_pk": 1}, + {"message": "Workers cannot be scaled on AWS ECS service in development environment", "challenge_pk": 2} + ] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + def test_scale_workers_no_current_workers(self, mock_service_manager, mock_get_boto3_client, mock_settings): + # Mock client and service manager response + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + + # Mock queryset with no current workers + challenge1 = MagicMock() + challenge1.pk = 1 + challenge1.workers = None + challenge2 = MagicMock() + challenge2.pk = 2 + challenge2.workers = None + queryset = [challenge1, challenge2] + + # Call the function + result = scale_workers(queryset, num_of_tasks=5) + + # Assert the expected output + expected_failures = [ + {"message": "Please start worker(s) before scaling.", "challenge_pk": 1}, + {"message": "Please start worker(s) before scaling.", 
"challenge_pk": 2} + ] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # Ensure the service manager was not called + mock_service_manager.assert_not_called() + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + def test_scale_workers_same_num_of_tasks(self, mock_service_manager, mock_get_boto3_client, mock_settings): + # Mock client and service manager response + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + + # Mock queryset where num_of_tasks is the same as current workers + challenge1 = MagicMock() + challenge1.pk = 1 + challenge1.workers = 5 + challenge2 = MagicMock() + challenge2.pk = 2 + challenge2.workers = 5 + queryset = [challenge1, challenge2] + + # Call the function + result = scale_workers(queryset, num_of_tasks=5) + + # Assert the expected output + expected_failures = [ + {"message": "Please scale to a different number. Challenge has 5 worker(s).", "challenge_pk": 1}, + {"message": "Please scale to a different number. Challenge has 5 worker(s).", "challenge_pk": 2} + ] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # Ensure the service manager was not called + mock_service_manager.assert_not_called() + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + def test_scale_workers_success(self, mock_service_manager, mock_get_boto3_client, mock_settings): + # Mock client and service manager response + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + mock_service_manager.return_value = mock_response + + # Mock queryset with current workers and valid scaling + challenge1 = MagicMock() + challenge1.pk = 1 + challenge1.workers = 5 + challenge2 = MagicMock() + challenge2.pk = 2 + challenge2.workers = 10 + queryset = [challenge1, challenge2] + + # Call the function + result = scale_workers(queryset, num_of_tasks=7) + + # Assert the expected output + self.assertEqual(result, {"count": 2, "failures": []}) + + # Ensure the service manager was called correctly + mock_service_manager.assert_called_with(mock_ec2, challenge=challenge2, num_of_tasks=7) + self.assertEqual(mock_service_manager.call_count, 2) + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + def test_scale_workers_failure(self, mock_service_manager, mock_get_boto3_client, mock_settings): + # Mock client and service manager response with error + mock_ec2 = MagicMock() + mock_get_boto3_client.return_value = mock_ec2 + mock_response = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}, "Error": "Error scaling workers"} + mock_service_manager.return_value = mock_response + + # Mock queryset with current workers and valid scaling + challenge1 = MagicMock() + challenge1.pk = 1 + challenge1.workers = 5 + challenge2 = MagicMock() + challenge2.pk = 2 + challenge2.workers = 10 + queryset = [challenge1, challenge2] + + # Call the function + result = scale_workers(queryset, num_of_tasks=7) + + # Assert the expected output + expected_failures = [ + {"message": "Error scaling workers", "challenge_pk": 1}, + {"message": "Error scaling workers", "challenge_pk": 2} + ] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # 
Ensure the service manager was called correctly + mock_service_manager.assert_called_with(mock_ec2, challenge=challenge2, num_of_tasks=7) + self.assertEqual(mock_service_manager.call_count, 2) + + +class TestScaleResources(unittest.TestCase): + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + def test_scale_resources_no_changes(self, mock_get_boto3_client, mock_settings): + # Mock client + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + # Mock challenge + challenge = MagicMock() + challenge.worker_cpu_cores = 2 + challenge.worker_memory = 4096 + # Call the function with no changes + result = scale_resources(challenge, worker_cpu_cores=2, worker_memory=4096) + # Assert the expected output + expected_result = { + "Success": True, + "Message": "Worker not modified", + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK} + } + self.assertEqual(result, expected_result) + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + def test_scale_resources_no_task_def_arn(self, mock_get_boto3_client, mock_settings): + # Mock client + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + # Mock challenge with no task definition ARN + challenge = MagicMock() + challenge.task_def_arn = None + challenge.worker_cpu_cores = 2 + challenge.worker_memory = 4096 + # Call the function + result = scale_resources(challenge, worker_cpu_cores=4, worker_memory=8192) + # Assert the expected output + expected_result = { + "Error": "Error. No active task definition registered for the challenge {}.".format(challenge.pk), + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST} + } + self.assertEqual(result, expected_result) + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + def test_scale_resources_deregister_success(self, mock_get_boto3_client, mock_settings): + # Mock client and response + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + mock_client.deregister_task_definition.return_value = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + # Mock challenge + challenge = MagicMock() + challenge.task_def_arn = "some_task_def_arn" + challenge.worker_cpu_cores = 2 + challenge.worker_memory = 4096 + # Mock other dependencies + with patch('challenges.utils.get_aws_credentials_for_challenge') as mock_get_aws_credentials_for_challenge, \ + patch('challenges.aws_utils.task_definition', new_callable=MagicMock) as mock_task_definition, \ + patch('challenges.aws_utils.eval') as mock_eval: + + mock_get_aws_credentials_for_challenge.return_value = {} + mock_task_definition.return_value = {'some_key': 'some_value'} # Use a dictionary here + mock_eval.return_value = {'some_key': 'some_value'} # Use a dictionary here + + # Mock register_task_definition response + mock_client.register_task_definition.return_value = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}, + "taskDefinition": {"taskDefinitionArn": "new_task_def_arn"} + } + + # Mock update_service response + mock_client.update_service.return_value = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + + # Call the function + result = scale_resources(challenge, worker_cpu_cores=4, worker_memory=8192) + + expected_result = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK} + } + self.assertEqual(result, expected_result) + + @patch('challenges.aws_utils.settings', DEBUG=False) + 
@patch('challenges.aws_utils.get_boto3_client') + def test_scale_resources_deregister_failure(self, mock_get_boto3_client, mock_settings): + # Mock client and response + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + mock_client.deregister_task_definition.side_effect = ClientError( + {"Error": {"Message": "Scaling inactive workers not supported"}}, 'DeregisterTaskDefinition' + ) + + # Mock challenge + challenge = MagicMock() + challenge.task_def_arn = "some_task_def_arn" + challenge.worker_cpu_cores = 2 + challenge.worker_memory = 4096 + + # Call the function + result = scale_resources(challenge, worker_cpu_cores=4, worker_memory=8192) + + # Assert the expected output + expected_result = { + "Error": True, + "Message": "Scaling inactive workers not supported" + } + self.assertEqual(result, expected_result) + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + def test_scale_resources_register_task_def_success(self, mock_get_boto3_client, mock_settings): + # Mock client and response + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + + mock_client.deregister_task_definition.return_value = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK} + } + + # Mock challenge + challenge = MagicMock() + challenge.task_def_arn = "some_task_def_arn" + challenge.worker_cpu_cores = 2 + challenge.worker_memory = 4096 + challenge.worker_image_url = "some_image_url" + challenge.queue = "queue_name" + challenge.ephemeral_storage = 50 + challenge.pk = 123 + challenge.workers = 10 + + # Mock other dependencies + with patch('challenges.utils.get_aws_credentials_for_challenge') as mock_get_aws_credentials_for_challenge, \ + patch('challenges.aws_utils.task_definition', new_callable=MagicMock) as mock_task_definition, \ + patch('challenges.aws_utils.eval') as mock_eval: + + mock_get_aws_credentials_for_challenge.return_value = {} + mock_task_definition.return_value = {'some_key': 'some_value'} # Use a dictionary here + mock_eval.return_value = {'some_key': 'some_value'} # Use a dictionary here + + # Mock register_task_definition response + mock_client.register_task_definition.return_value = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}, + "taskDefinition": {"taskDefinitionArn": "new_task_def_arn"} + } + + # Mock update_service response + mock_client.update_service.return_value = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + + # Call the function + result = scale_resources(challenge, worker_cpu_cores=4, worker_memory=8192) + + expected_result = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK} + } + self.assertEqual(result, expected_result) + + @patch('challenges.aws_utils.settings', DEBUG=False) + @patch('challenges.aws_utils.get_boto3_client') + def test_scale_resources_register_task_def_failure(self, mock_get_boto3_client, mock_settings): + # Mock client and response + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + + mock_client.deregister_task_definition.return_value = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK} + } + + # Mock challenge + challenge = MagicMock() + challenge.task_def_arn = "some_task_def_arn" + challenge.worker_cpu_cores = 2 + challenge.worker_memory = 4096 + challenge.worker_image_url = "some_image_url" + challenge.queue = "queue_name" + challenge.ephemeral_storage = 50 + challenge.pk = 123 + challenge.workers = 10 + + # Mock other dependencies + with 
patch('challenges.utils.get_aws_credentials_for_challenge') as mock_get_aws_credentials_for_challenge, \ + patch('challenges.aws_utils.task_definition', new_callable=MagicMock) as mock_task_definition, \ + patch('challenges.aws_utils.eval') as mock_eval: + + mock_get_aws_credentials_for_challenge.return_value = {} + mock_task_definition.return_value = {'some_key': 'some_value'} # Use a dictionary here + mock_eval.return_value = {'some_key': 'some_value'} # Use a dictionary here + + # Mock register_task_definition response with error + mock_client.register_task_definition.return_value = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}, + "Error": "Failed to register task definition" + } + + # Call the function + result = scale_resources(challenge, worker_cpu_cores=4, worker_memory=8192) + + # Expected result + expected_result = { + "Error": "Failed to register task definition", + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST} + } + self.assertEqual(result, expected_result) + + +class TestDeleteWorkers(TestCase): + @patch('challenges.aws_utils.delete_service_by_challenge_pk') + @patch('challenges.aws_utils.get_log_group_name') + @patch('challenges.aws_utils.delete_log_group') + @patch('challenges.aws_utils.settings', DEBUG=True) + def test_delete_workers_in_dev_environment(self, mock_settings, mock_delete_log_group, mock_get_log_group_name, mock_delete_service_by_challenge_pk): + # Mock a queryset + mock_queryset = [MagicMock(pk=1), MagicMock(pk=2)] + + # Call the function + result = delete_workers(mock_queryset) + + # Assertions + expected_failures = [ + {"message": "Workers cannot be deleted on AWS ECS service in development environment", "challenge_pk": 1}, + {"message": "Workers cannot be deleted on AWS ECS service in development environment", "challenge_pk": 2}, + ] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # Ensure the delete_service_by_challenge_pk method was never called + mock_delete_service_by_challenge_pk.assert_not_called() + + @patch('challenges.aws_utils.delete_service_by_challenge_pk') + @patch('challenges.aws_utils.get_log_group_name') + @patch('challenges.aws_utils.delete_log_group') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_delete_workers_no_workers(self, mock_settings, mock_delete_log_group, mock_get_log_group_name, mock_delete_service_by_challenge_pk): + # Mock a queryset with no workers + challenge_with_no_workers = MagicMock(pk=1, workers=None) + mock_queryset = [challenge_with_no_workers] + + # Call the function + result = delete_workers(mock_queryset) + + # Assertions + expected_failures = [{"message": "Please select challenges with active workers only.", "challenge_pk": 1}] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # Ensure the delete_service_by_challenge_pk method was never called + mock_delete_service_by_challenge_pk.assert_not_called() + + @patch('challenges.aws_utils.delete_service_by_challenge_pk') + @patch('challenges.aws_utils.get_log_group_name') + @patch('challenges.aws_utils.delete_log_group') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_delete_workers_success(self, mock_settings, mock_delete_log_group, mock_get_log_group_name, mock_delete_service_by_challenge_pk): + # Mock a challenge with workers and successful deletion + challenge_with_workers = MagicMock(pk=1, workers=5) + mock_queryset = [challenge_with_workers] + + # Mock the delete_service_by_challenge_pk response + 
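+ # Returning HTTP 200 here is what the assertions below treat as a successful deletion,
+ # which should also trigger the log-group lookup and cleanup calls.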
mock_delete_service_by_challenge_pk.return_value = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + + # Mock the log group name + mock_get_log_group_name.return_value = "log_group_name" + + # Call the function + result = delete_workers(mock_queryset) + + # Assertions + self.assertEqual(result, {"count": 1, "failures": []}) + + # Ensure the delete_service_by_challenge_pk, get_log_group_name, and delete_log_group methods were called + mock_delete_service_by_challenge_pk.assert_called_once_with(challenge=challenge_with_workers) + mock_get_log_group_name.assert_called_once_with(challenge_with_workers.pk) + mock_delete_log_group.assert_called_once_with("log_group_name") + + @patch('challenges.aws_utils.delete_service_by_challenge_pk') + @patch('challenges.aws_utils.get_log_group_name') + @patch('challenges.aws_utils.delete_log_group') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_delete_workers_failure(self, mock_settings, mock_delete_log_group, mock_get_log_group_name, mock_delete_service_by_challenge_pk): + # Mock a challenge with workers and failed deletion + challenge_with_workers = MagicMock(pk=1, workers=5) + mock_queryset = [challenge_with_workers] + + # Mock the delete_service_by_challenge_pk response to simulate a failure + mock_delete_service_by_challenge_pk.return_value = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}, + "Error": "An error occurred" + } + + # Call the function + result = delete_workers(mock_queryset) + + # Assertions + expected_failures = [{"message": "An error occurred", "challenge_pk": 1}] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # Ensure the delete_service_by_challenge_pk was called + mock_delete_service_by_challenge_pk.assert_called_once_with(challenge=challenge_with_workers) + + # Ensure get_log_group_name and delete_log_group were not called + mock_get_log_group_name.assert_not_called() + mock_delete_log_group.assert_not_called() + + +class TestRestartWorkers(TestCase): + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + @patch('challenges.aws_utils.settings', DEBUG=True) + def test_restart_workers_in_dev_environment(self, mock_settings, mock_service_manager, mock_get_boto3_client): + # Mock a queryset + mock_queryset = [MagicMock(pk=1), MagicMock(pk=2)] + + # Call the function + result = restart_workers(mock_queryset) + + # Assertions + expected_failures = [ + {"message": "Workers cannot be restarted on AWS ECS service in development environment", "challenge_pk": 1}, + {"message": "Workers cannot be restarted on AWS ECS service in development environment", "challenge_pk": 2}, + ] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # Ensure the service_manager method was never called + mock_service_manager.assert_not_called() + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_restart_workers_docker_based_challenge(self, mock_settings, mock_service_manager, mock_get_boto3_client): + # Mock a Docker-based challenge queryset + challenge_docker_based = MagicMock(pk=1, is_docker_based=True, is_static_dataset_code_upload=False) + mock_queryset = [challenge_docker_based] + + # Call the function + result = restart_workers(mock_queryset) + + # Assertions + expected_failures = [ + {"message": "Sorry. 
This feature is not available for code upload/docker based challenges.", "challenge_pk": 1} + ] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # Ensure the service_manager method was never called + mock_service_manager.assert_not_called() + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_restart_workers_no_workers(self, mock_settings, mock_service_manager, mock_get_boto3_client): + # Mock a challenge with no workers + challenge_no_workers = MagicMock(pk=1, workers=0, is_docker_based=False, is_static_dataset_code_upload=False) + mock_queryset = [challenge_no_workers] + + # Call the function + result = restart_workers(mock_queryset) + + # Assertions + # Expect a failure message indicating no active workers + expected_result = { + "count": 0, + "failures": [ + { + "message": "Please select challenges with active workers only.", + "challenge_pk": 1 + } + ] + } + self.assertEqual(result, expected_result) + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_restart_workers_success(self, mock_settings, mock_service_manager, mock_get_boto3_client): + # Mock a challenge with active workers + challenge_with_workers = MagicMock(pk=1, workers=2, is_docker_based=False, is_static_dataset_code_upload=False) + mock_queryset = [challenge_with_workers] + + # Mock the service_manager response + mock_service_manager.return_value = {"ResponseMetadata": {"HTTPStatusCode": HTTPStatus.OK}} + + # Call the function + result = restart_workers(mock_queryset) + + # Assertions + self.assertEqual(result, {"count": 1, "failures": []}) + + # Ensure the service_manager method was called + mock_service_manager.assert_called_once_with( + mock_get_boto3_client.return_value, + challenge=challenge_with_workers, + num_of_tasks=challenge_with_workers.workers, + force_new_deployment=True, + ) + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.service_manager') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_restart_workers_failure(self, mock_settings, mock_service_manager, mock_get_boto3_client): + # Mock a challenge with active workers + challenge_with_workers = MagicMock(pk=1, workers=2, is_docker_based=False, is_static_dataset_code_upload=False) + mock_queryset = [challenge_with_workers] + + # Mock the service_manager response to simulate a failure + mock_service_manager.return_value = { + "ResponseMetadata": {"HTTPStatusCode": HTTPStatus.BAD_REQUEST}, + "Error": "An error occurred" + } + + # Call the function + result = restart_workers(mock_queryset) + + # Assertions + expected_failures = [{"message": "An error occurred", "challenge_pk": 1}] + self.assertEqual(result, {"count": 0, "failures": expected_failures}) + + # Ensure the service_manager method was called + mock_service_manager.assert_called_once_with( + mock_get_boto3_client.return_value, + challenge=challenge_with_workers, + num_of_tasks=challenge_with_workers.workers, + force_new_deployment=True, + ) + + +class TestRestartWorkersSignalCallback(TestCase): + @patch('challenges.aws_utils.settings', DEBUG=True) + def test_restart_workers_signal_callback_debug_mode(self, mock_settings, *args): + # Mock the sender and instance + mock_sender = MagicMock() + mock_instance = MagicMock() + + # Call the signal callback + result = 
restart_workers_signal_callback(sender=mock_sender, instance=mock_instance, field_name="evaluation_script") + + # Assert that the function returns None when DEBUG is True + self.assertIsNone(result) + + +class TestGetLogsFromCloudwatch(TestCase): + @patch('challenges.aws_utils.settings', DEBUG=True) + def test_get_logs_from_cloudwatch_debug_mode(self, mock_settings, *args): + # Test when DEBUG is True + log_group_name = "dummy_group" + log_stream_prefix = "dummy_prefix" + start_time = 123456789 + end_time = 123456999 + pattern = "" + limit = 10 + + logs = get_logs_from_cloudwatch(log_group_name, log_stream_prefix, start_time, end_time, pattern, limit) + + expected_logs = [ + "The worker logs in the development environment are available on the terminal. Please use docker-compose logs -f worker to view the logs." + ] + + self.assertEqual(logs, expected_logs) + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_get_logs_from_cloudwatch_success(self, mock_settings, mock_get_boto3_client): + # Test when DEBUG is False and logs are retrieved successfully + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + + mock_client.filter_log_events.return_value = { + "events": [{"message": "Log message 1"}, {"message": "Log message 2"}], + "nextToken": None + } + + log_group_name = "dummy_group" + log_stream_prefix = "dummy_prefix" + start_time = 123456789 + end_time = 123456999 + pattern = "" + limit = 10 + + logs = get_logs_from_cloudwatch(log_group_name, log_stream_prefix, start_time, end_time, pattern, limit) + + expected_logs = ["Log message 1", "Log message 2"] + + self.assertEqual(logs, expected_logs) + mock_client.filter_log_events.assert_called_once_with( + logGroupName=log_group_name, + logStreamNamePrefix=log_stream_prefix, + startTime=start_time, + endTime=end_time, + filterPattern=pattern, + limit=limit + ) + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_get_logs_from_cloudwatch_resource_not_found(self, mock_settings, mock_get_boto3_client): + # Test when DEBUG is False and ResourceNotFoundException is raised + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + + exception = Exception() + exception.response = {"Error": {"Code": "ResourceNotFoundException"}} + mock_client.filter_log_events.side_effect = exception + + log_group_name = "dummy_group" + log_stream_prefix = "dummy_prefix" + start_time = 123456789 + end_time = 123456999 + pattern = "" + limit = 10 + + logs = get_logs_from_cloudwatch(log_group_name, log_stream_prefix, start_time, end_time, pattern, limit) + + self.assertEqual(logs, []) + + @patch('challenges.aws_utils.logger') + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.settings', DEBUG=False) + def test_get_logs_from_cloudwatch_other_exception(self, mock_settings, mock_get_boto3_client, mock_logger): + # Test when DEBUG is False and a different exception is raised + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + + exception = Exception("Some error") + exception.response = {"Error": {"Code": "OtherException"}} + mock_client.filter_log_events.side_effect = exception + + log_group_name = "dummy_group" + log_stream_prefix = "dummy_prefix" + start_time = 123456789 + end_time = 123456999 + pattern = "" + limit = 10 + + logs = get_logs_from_cloudwatch(log_group_name, log_stream_prefix, start_time, end_time, pattern, limit) + + 
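+ # For errors other than ResourceNotFoundException, the assertions below expect the helper
+ # to catch the exception, log it via logger.exception, and return a readable error message
+ # instead of raising.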
expected_logs = [ + "There is an error in displaying logs. Please find the full error traceback here Some error" + ] + + self.assertEqual(logs, expected_logs) + mock_logger.exception.assert_called_once_with(exception) + + @patch('challenges.aws_utils.settings') # Adjust the path to the actual module + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.logger') + def test_delete_log_group_debug_mode(self, mock_logger, mock_get_boto3_client, mock_settings): + # When settings.DEBUG is True + mock_settings.DEBUG = True + + # Call the function + delete_log_group('test-log-group') + + # Assert that get_boto3_client and logger were not called + mock_get_boto3_client.assert_not_called() + mock_logger.assert_not_called() + + @patch('challenges.aws_utils.settings') # Adjust the path to the actual module + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.logger') + def test_delete_log_group_non_debug_mode(self, mock_logger, mock_get_boto3_client, mock_settings): + # When settings.DEBUG is False + mock_settings.DEBUG = False + + # Mock boto3 client and its methods + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + + # Call the function + delete_log_group('test-log-group') + + aws_keys = {'AWS_ACCOUNT_ID': 'x', 'AWS_ACCESS_KEY_ID': 'x', 'AWS_SECRET_ACCESS_KEY': 'x', 'AWS_REGION': 'us-east-1', 'AWS_STORAGE_BUCKET_NAME': 'evalai-s3-bucket'} + # Assert that get_boto3_client was called with the correct arguments + mock_get_boto3_client.assert_called_once_with("logs", aws_keys) + + # Assert that delete_log_group was called on the client with the correct argument + mock_client.delete_log_group.assert_called_once_with(logGroupName='test-log-group') + + # Assert that logger.exception was not called + mock_logger.exception.assert_not_called() + + @patch('challenges.aws_utils.settings') # Adjust the path to the actual module + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.logger') + def test_delete_log_group_with_exception(self, mock_logger, mock_get_boto3_client, mock_settings): + # When settings.DEBUG is False and an exception occurs + mock_settings.DEBUG = False + + # Mock boto3 client and its methods to raise an exception + mock_client = MagicMock() + mock_client.delete_log_group.side_effect = Exception('Delete failed') + mock_get_boto3_client.return_value = mock_client + + # Call the function + delete_log_group('test-log-group') + + aws_keys = {'AWS_ACCOUNT_ID': 'x', 'AWS_ACCESS_KEY_ID': 'x', 'AWS_SECRET_ACCESS_KEY': 'x', 'AWS_REGION': 'us-east-1', 'AWS_STORAGE_BUCKET_NAME': 'evalai-s3-bucket'} + # Assert that get_boto3_client was called with the correct arguments + mock_get_boto3_client.assert_called_once_with("logs", aws_keys) + + # Assert that delete_log_group was called on the client with the correct argument + mock_client.delete_log_group.assert_called_once_with(logGroupName='test-log-group') + + # Retrieve the actual arguments passed to logger.exception + args, kwargs = mock_logger.exception.call_args + + # Check if the first argument of logger.exception contains the correct message + self.assertTrue('Delete failed' in str(args[0]), f"Expected 'Delete failed' in {args[0]}") + + +class TestCreateEKSNodegroup(unittest.TestCase): + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.get_code_upload_setup_meta_for_challenge') + @patch('challenges.utils.get_aws_credentials_for_challenge') + @patch('challenges.aws_utils.serializers.deserialize') + 
@patch('challenges.aws_utils.settings') + @patch('challenges.aws_utils.logger') + @patch('challenges.aws_utils.construct_and_send_eks_cluster_creation_mail') + @patch('challenges.aws_utils.create_service_by_challenge_pk') + @patch('challenges.aws_utils.client_token_generator') + def test_create_eks_nodegroup_success(self, mock_client_token_generator, mock_create_service_by_challenge_pk, mock_construct_and_send_eks_cluster_creation_mail, mock_logger, mock_settings, mock_deserialize, mock_get_aws_credentials_for_challenge, mock_get_code_upload_setup_meta_for_challenge, mock_get_boto3_client): + + # Setup mock objects and functions + mock_settings.ENVIRONMENT = 'test-env' + mock_challenge = MagicMock() + mock_deserialize.return_value = [MagicMock(object=mock_challenge)] + + mock_challenge.pk = 1 + mock_challenge.title = 'Test Challenge' + mock_challenge.min_worker_instance = 1 + mock_challenge.max_worker_instance = 2 + mock_challenge.desired_worker_instance = 1 + mock_challenge.worker_disk_size = 50 + mock_challenge.worker_instance_type = 't2.medium' + mock_challenge.worker_ami_type = 'AL2_x86_64' + + mock_cluster_meta = { + "SUBNET_1": "subnet-123", + "SUBNET_2": "subnet-456", + "EKS_NODEGROUP_ROLE_ARN": "arn:aws:iam::123456789012:role/eks-nodegroup-role" + } + mock_get_code_upload_setup_meta_for_challenge.return_value = mock_cluster_meta + mock_aws_credentials = {'AWS_ACCOUNT_ID': 'x', 'AWS_ACCESS_KEY_ID': 'x', 'AWS_SECRET_ACCESS_KEY': 'x', 'AWS_REGION': 'us-east-1', 'AWS_STORAGE_BUCKET_NAME': 'evalai-s3-bucket'} + mock_get_aws_credentials_for_challenge.return_value = mock_aws_credentials + + mock_client = MagicMock() + mock_get_boto3_client.side_effect = [mock_client, mock_client] + + # Mocking the create_nodegroup method and the waiter + mock_client.create_nodegroup.return_value = {"nodegroup": "created"} + mock_waiter = MagicMock() + mock_client.get_waiter.return_value = mock_waiter + + # Call the function + create_eks_nodegroup(mock_challenge, 'test-cluster') + + # Assertions + mock_get_boto3_client.assert_called_with("ecs", mock_aws_credentials) + + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.get_code_upload_setup_meta_for_challenge') + @patch('challenges.utils.get_aws_credentials_for_challenge') + @patch('challenges.aws_utils.serializers.deserialize') + @patch('challenges.aws_utils.settings') + @patch('challenges.aws_utils.logger') + def test_create_eks_nodegroup_client_error(self, mock_logger, mock_settings, mock_deserialize, mock_get_aws_credentials_for_challenge, mock_get_code_upload_setup_meta_for_challenge, mock_get_boto3_client): + + # Setup mock objects and functions + mock_settings.ENVIRONMENT = 'test-env' + mock_challenge = MagicMock() + mock_deserialize.return_value = [MagicMock(object=mock_challenge)] + + mock_challenge.pk = 1 + mock_challenge.title = 'Test Challenge' + mock_challenge.min_worker_instance = 1 + mock_challenge.max_worker_instance = 2 + mock_challenge.desired_worker_instance = 1 + mock_challenge.worker_disk_size = 50 + mock_challenge.worker_instance_type = 't2.medium' + mock_challenge.worker_ami_type = 'AL2_x86_64' + + mock_cluster_meta = { + "SUBNET_1": "subnet-123", + "SUBNET_2": "subnet-456", + "EKS_NODEGROUP_ROLE_ARN": "arn:aws:iam::123456789012:role/eks-nodegroup-role" + } + mock_get_code_upload_setup_meta_for_challenge.return_value = mock_cluster_meta + mock_aws_credentials = {'AWS_ACCESS_KEY_ID': 'x', 'AWS_SECRET_ACCESS_KEY': 'x', 'AWS_REGION': 'us-east-1'} + mock_get_aws_credentials_for_challenge.return_value = 
mock_aws_credentials + + mock_client = MagicMock() + mock_get_boto3_client.return_value = mock_client + + # Mocking the create_nodegroup method to raise a ClientError + mock_client.create_nodegroup.side_effect = ClientError( + {"Error": {"Code": "SomeError", "Message": "Create failed"}}, + "CreateNodegroup" + ) + + # Call the function + create_eks_nodegroup(mock_challenge, 'test-cluster') + + # Assertions + mock_get_boto3_client.assert_called_once_with("eks", mock_aws_credentials) + mock_client.create_nodegroup.assert_called_once_with( + clusterName='test-cluster', + nodegroupName='Test-Challenge-1-test-env-nodegroup', + scalingConfig={ + "minSize": 1, + "maxSize": 2, + "desiredSize": 1 + }, + diskSize=50, + subnets=["subnet-123", "subnet-456"], + instanceTypes=['t2.medium'], + amiType='AL2_x86_64', + nodeRole="arn:aws:iam::123456789012:role/eks-nodegroup-role" + ) + + # Retrieve the actual arguments passed to logger.exception + args, kwargs = mock_logger.exception.call_args + + # Extract the ClientError object from the actual call + actual_error = args[0] + + # Check if the actual error message contains the expected message + expected_message = "An error occurred (SomeError) when calling the CreateNodegroup operation: Create failed" + self.assertIn(expected_message, str(actual_error), f"Expected message '{expected_message}' in {str(actual_error)}") + + +class TestSetupEksCluster(TestCase): + @patch('challenges.utils.get_aws_credentials_for_challenge') + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.serializers.deserialize') + @patch('challenges.models.ChallengeEvaluationCluster.objects.get') + @patch('challenges.serializers.ChallengeEvaluationClusterSerializer') + @patch('challenges.aws_utils.logger') + def test_setup_eks_cluster_success(self, mock_logger, mock_serializer, mock_get_cluster, mock_deserialize, mock_boto3, mock_get_aws): + """ Test case to ensure success path of the setup_eks_cluster function """ + + # Mocks + mock_client = MagicMock() + mock_boto3.return_value = mock_client + mock_serializer.return_value.is_valid.return_value = True + mock_serializer.return_value.save.return_value = None + mock_get_cluster.return_value = MagicMock() + + # Simulate valid deserialization + mock_obj = MagicMock() + mock_obj.object = MagicMock() + mock_deserialize.return_value = [mock_obj] + + # Call the function + setup_eks_cluster('{"some": "data"}') + + # Assertions for role creation and policy attachment + self.assertTrue(mock_client.create_role.called) + self.assertTrue(mock_client.attach_role_policy.called) + self.assertTrue(mock_client.create_policy.called) + self.assertTrue(mock_serializer.return_value.save.called) + + # Ensure an exception was logged + mock_logger.exception.assert_called_once() + + @patch('challenges.utils.get_aws_credentials_for_challenge') + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.serializers.deserialize') + @patch('challenges.aws_utils.logger') + def test_setup_eks_cluster_create_role_failure(self, mock_logger, mock_deserialize, mock_boto3, mock_get_aws): + """ Test case to simulate failure during EKS role creation """ + + # Mocks + mock_client = MagicMock() + mock_client.create_role.side_effect = ClientError({"Error": {"Code": "SomeError"}}, "CreateRole") + mock_boto3.return_value = mock_client + + # Simulate valid deserialization + mock_obj = MagicMock() + mock_obj.object = MagicMock() + mock_deserialize.return_value = [mock_obj] + + # Call the function + setup_eks_cluster('{"some": "data"}') 
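+ # create_role raised a ClientError above; the checks below expect setup_eks_cluster to
+ # catch it and record it with logger.exception rather than propagate it.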
+ + # Assertions for exception handling + mock_logger.exception.assert_called_once() + self.assertTrue(mock_client.create_role.called) + + @patch('challenges.utils.get_aws_credentials_for_challenge') + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.serializers.deserialize') + @patch('challenges.aws_utils.logger') + def test_setup_eks_cluster_attach_role_policy_failure(self, mock_logger, mock_deserialize, mock_boto3, mock_get_aws): + """ Test case to simulate failure during policy attachment """ + + # Mocks + mock_client = MagicMock() + mock_client.attach_role_policy.side_effect = ClientError({"Error": {"Code": "SomeError"}}, "AttachRolePolicy") + mock_boto3.return_value = mock_client + + # Simulate valid deserialization + mock_obj = MagicMock() + mock_obj.object = MagicMock() + mock_deserialize.return_value = [mock_obj] + + # Call the function + setup_eks_cluster('{"some": "data"}') + + # Assertions for exception handling + mock_logger.exception.assert_called_once() + self.assertTrue(mock_client.attach_role_policy.called) + + @patch('challenges.utils.get_aws_credentials_for_challenge') + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.serializers.deserialize') + @patch('challenges.aws_utils.logger') + def test_setup_eks_cluster_create_policy_failure(self, mock_logger, mock_deserialize, mock_boto3, mock_get_aws): + """ Test case to simulate failure during custom ECR policy creation """ + + # Mocks + mock_client = MagicMock() + mock_client.create_policy.side_effect = ClientError({"Error": {"Code": "SomeError"}}, "CreatePolicy") + mock_boto3.return_value = mock_client + + # Simulate valid deserialization + mock_obj = MagicMock() + mock_obj.object = MagicMock() + mock_deserialize.return_value = [mock_obj] + + # Call the function + setup_eks_cluster('{"some": "data"}') + + # Assertions for exception handling + mock_logger.exception.assert_called_once() + self.assertTrue(mock_client.create_policy.called) + + @patch('challenges.utils.get_aws_credentials_for_challenge') + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.serializers.deserialize') + @patch('challenges.aws_utils.logger') + @patch('challenges.serializers.ChallengeEvaluationClusterSerializer') + @patch('challenges.models.ChallengeEvaluationCluster.objects.get') + def test_setup_eks_cluster_serialization_failure(self, mock_get_cluster, mock_serializer, mock_logger, mock_deserialize, mock_boto3, mock_get_aws): + """ Test case to simulate failure during serialization """ + + # Mocks + mock_client = MagicMock() + mock_boto3.return_value = mock_client + + # Simulate valid deserialization + mock_obj = MagicMock() + mock_obj.object = MagicMock() + mock_deserialize.return_value = [mock_obj] + + # Simulate invalid serializer + mock_serializer.return_value.is_valid.return_value = False + mock_get_cluster.return_value = MagicMock() + + # Call the function + setup_eks_cluster('{"some": "data"}') + + # Ensure serializer failure doesn't cause errors + self.assertTrue(mock_serializer.return_value.is_valid.called) + mock_logger.exception.assert_called_once() + + @patch('challenges.utils.get_aws_credentials_for_challenge') + @patch('challenges.aws_utils.get_boto3_client') + @patch('challenges.aws_utils.serializers.deserialize') + @patch('challenges.aws_utils.logger') + @patch('challenges.aws_utils.create_eks_cluster_subnets.delay') + @patch('challenges.serializers.ChallengeEvaluationClusterSerializer') + 
@patch('challenges.models.ChallengeEvaluationCluster.objects.get') + def test_setup_eks_cluster_subnets_creation(self, mock_get_cluster, mock_serializer, mock_create_subnets, mock_logger, mock_deserialize, mock_boto3, mock_get_aws): + """ Test case to ensure EKS cluster subnets creation is triggered """ + + # Mocks + mock_client = MagicMock() + mock_boto3.return_value = mock_client + mock_serializer.return_value.is_valid.return_value = True + mock_get_cluster.return_value = MagicMock() + + # Simulate valid deserialization + mock_obj = MagicMock() + mock_obj.object = MagicMock() + mock_deserialize.return_value = [mock_obj] + + # Call the function + setup_eks_cluster('{"some": "data"}') + + # Ensure subnets creation task is triggered + self.assertTrue(mock_create_subnets.called) + + +@pytest.mark.django_db +class TestSetupEC2(TestCase): + def setUp(self): + self.user = User.objects.create( + username="someuser", + email="user@test.com", + password="secret_password", + ) + self.challenge_host_team = ChallengeHostTeam.objects.create( + team_name="Test Challenge Host Team", created_by=self.user + ) + self.challenge = Challenge.objects.create( + title="Test Challenge", + ec2_instance_id=None, + creator=self.challenge_host_team, + ) + self.serialized_challenge = serializers.serialize("json", [self.challenge]) + + @patch("challenges.aws_utils.start_ec2_instance") + @patch("challenges.aws_utils.create_ec2_instance") + @patch("django.core.serializers.deserialize") + def test_setup_ec2_with_existing_instance(self, mock_deserialize, mock_create_ec2, mock_start_ec2): + # Setup mock behavior + mock_obj = MagicMock() + mock_obj.object = self.challenge + mock_deserialize.return_value = [mock_obj] + # Update the challenge to have an ec2_instance_id + self.challenge.ec2_instance_id = "i-1234567890abcdef0" + self.challenge.save() + # Call the function + setup_ec2(self.serialized_challenge) + # Check if start_ec2_instance was called since the EC2 instance already exists + mock_start_ec2.assert_called_once_with(self.challenge) + mock_create_ec2.assert_not_called() + + @patch("challenges.aws_utils.start_ec2_instance") + @patch("challenges.aws_utils.create_ec2_instance") + @patch("django.core.serializers.deserialize") + def test_setup_ec2_without_existing_instance(self, mock_deserialize, mock_create_ec2, mock_start_ec2): + # Setup mock behavior + mock_obj = MagicMock() + mock_obj.object = self.challenge + mock_deserialize.return_value = [mock_obj] + # Ensure ec2_instance_id is None + self.challenge.ec2_instance_id = None + self.challenge.save() + # Call the function + setup_ec2(self.serialized_challenge) + # Check if create_ec2_instance was called since the EC2 instance doesn't exist + mock_create_ec2.assert_called_once_with(self.challenge) + mock_start_ec2.assert_not_called() + + @patch('challenges.aws_utils.update_sqs_retention_period') + @patch('django.core.serializers.deserialize') + def test_update_sqs_retention_period_task(self, mock_deserialize, mock_update_sqs_retention_period): + challenge_json = '{"model": "app.challenge", "pk": 1, "fields": {}}' + mock_challenge_obj = MagicMock() + + mock_deserialized_object = MagicMock() + mock_deserialized_object.object = mock_challenge_obj + mock_deserialize.return_value = [mock_deserialized_object] + + update_sqs_retention_period_task(challenge_json) + + mock_deserialize.assert_called_once_with("json", challenge_json) + mock_update_sqs_retention_period.assert_called_once_with(mock_challenge_obj) diff --git a/tests/unit/challenges/test_challenge_config_utils.py 
b/tests/unit/challenges/test_challenge_config_utils.py new file mode 100644 index 0000000000..f2d182aa5d --- /dev/null +++ b/tests/unit/challenges/test_challenge_config_utils.py @@ -0,0 +1,780 @@ +import unittest +import zipfile +import io +from unittest.mock import Mock, patch as mockpatch +import requests +import yaml +from django.contrib.auth.models import User +from os.path import join +from challenges.challenge_config_utils import ValidateChallengeConfigUtil, download_and_write_file, get_yaml_files_from_challenge_config, get_yaml_read_error, is_challenge_config_yaml_html_field_valid, is_challenge_phase_split_mapping_valid, validate_challenge_config_util +from challenges.models import ChallengePhase, ChallengePhaseSplit, DatasetSplit, Leaderboard +from hosts.models import ChallengeHostTeam +import pytest + + +class TestGetYamlFilesFromChallengeConfig(unittest.TestCase): + def test_no_yaml_files_in_zip(self): + # Create a zip file in memory with no YAML files + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, 'a', zipfile.ZIP_DEFLATED) as zip_file: + zip_file.writestr('some_other_file.txt', 'This is some other file content') + + zip_buffer.seek(0) + with zipfile.ZipFile(zip_buffer, 'r') as zip_file: + yaml_file_count, yaml_file_name, extracted_folder_name = get_yaml_files_from_challenge_config(zip_file) + + self.assertEqual(yaml_file_count, 0) + self.assertIsNone(yaml_file_name) + self.assertIsNone(extracted_folder_name) + + +class TestGetYamlReadError(unittest.TestCase): + def test_get_yaml_read_error_with_problem_and_problem_mark(self): + # Create a mock exception with problem and problem_mark attributes + class MockException: + def __init__(self, problem, line, column): + self.problem = problem + self.problem_mark = self.Mark(line, column) + + class Mark: + def __init__(self, line, column): + self.line = line + self.column = column + + exc = MockException("mock problem", 10, 20) + error_description, line_number, column_number = get_yaml_read_error(exc) + + self.assertEqual(error_description, "Mock problem") + self.assertEqual(line_number, 11) + self.assertEqual(column_number, 21) + + def test_get_yaml_read_error_without_problem_and_problem_mark(self): + # Create a mock exception without problem and problem_mark attributes + class MockException: + pass + + exc = MockException() + error_description, line_number, column_number = get_yaml_read_error(exc) + + self.assertIsNone(error_description) + self.assertIsNone(line_number) + self.assertIsNone(column_number) + + +class TestIsChallengeConfigYamlHtmlFieldValid(unittest.TestCase): + def setUp(self): + self.base_location = "/path/to/extracted/config" + + @mockpatch('challenges.challenge_config_utils.isfile', return_value=False) + def test_file_not_found(self, mock_isfile): + yaml_file_data = {'html_field': 'non_existent_file.html'} + key = 'html_field' + + is_valid, message = is_challenge_config_yaml_html_field_valid(yaml_file_data, key, self.base_location) + + self.assertFalse(is_valid) + self.assertEqual(message, "File at path html_field not found. Please specify a valid file path") + + @mockpatch('challenges.challenge_config_utils.isfile', return_value=True) + def test_file_not_html(self, mock_isfile): + yaml_file_data = {'html_field': 'file.txt'} + key = 'html_field' + + is_valid, message = is_challenge_config_yaml_html_field_valid(yaml_file_data, key, self.base_location) + + self.assertFalse(is_valid) + self.assertEqual(message, "File html_field is not a HTML file. 
Please specify a valid HTML file") + + +class TestIsChallengePhaseSplitMappingValid(unittest.TestCase): + def test_invalid_dataset_split_id(self): + phase_ids = [1, 2, 3] + leaderboard_ids = [10, 20, 30] + dataset_split_ids = [100, 200, 300] + phase_split = { + "challenge_phase_id": 1, + "leaderboard_id": 10, + "dataset_split_id": 400 # Invalid dataset split id + } + challenge_phase_split_index = 0 + + is_success, error_messages = is_challenge_phase_split_mapping_valid( + phase_ids, leaderboard_ids, dataset_split_ids, phase_split, challenge_phase_split_index + ) + + self.assertFalse(is_success) + self.assertIn("ERROR: Invalid dataset split id 400 found in challenge phase split 0.", error_messages) + + +class TestDownloadAndWriteFile(unittest.TestCase): + def setUp(self): + self.url = "http://example.com/file.zip" + self.output_path = "/path/to/output/file.zip" + self.mode = "wb" + + @mockpatch('challenges.challenge_config_utils.requests.get') + @mockpatch('challenges.challenge_config_utils.write_file') + def test_io_error(self, mock_write_file, mock_requests_get): + mock_requests_get.return_value = Mock(status_code=200, content=b'file content') + mock_write_file.side_effect = IOError + + is_success, message = download_and_write_file(self.url, True, self.output_path, self.mode) + + self.assertFalse(is_success) + self.assertEqual(message, "Unable to process the uploaded zip file. Please try again!") + + @mockpatch('challenges.challenge_config_utils.requests.get') + def test_request_exception(self, mock_requests_get): + mock_requests_get.side_effect = requests.exceptions.RequestException + + is_success, message = download_and_write_file(self.url, True, self.output_path, self.mode) + + self.assertFalse(is_success) + self.assertEqual(message, "A server error occured while processing zip file. 
Please try again!") + + +@pytest.mark.django_db +class TestValidateChallengeConfigUtil0(unittest.TestCase): + def setUp(self): + self.request = Mock() + # Create a challenge host team + self.user = User.objects.create_user(username='testuser', password='12345') + self.challenge_host_team = ChallengeHostTeam.objects.create( + team_name="Test Challenge Host Team", created_by=self.user + ) + self.base_location = "/path/to/base" + self.unique_folder_name = "unique_folder" + self.extracted_folder_name = "extracted_folder" + self.zip_ref = Mock() + self.zip_ref.namelist.return_value = [] # Mock the namelist method to return an empty list + self.current_challenge = Mock() + self.error_message_dict = { + "no_yaml_file": "No YAML file found in the zip.", + "multiple_yaml_files": "Multiple YAML files found: {}.", + "yaml_file_read_error": "YAML file read error: {} at line {}, column {}.", + "missing_challenge_description": "Challenge description is missing.", + "missing_evaluation_details": "Evaluation details are missing.", + "missing_terms_and_conditions": "Terms and conditions are missing.", + "missing_submission_guidelines": "Submission guidelines are missing.", + "evaluation_script_not_zip": "Evaluation script is not a zip file.", + "missing_evaluation_script": "Evaluation script file is missing.", + "missing_evaluation_script_key": "Evaluation script key is missing.", + "missing_date": "Start date or end date is missing.", + "start_date_greater_than_end_date": "Start date is greater than end date.", + "challenge_metadata_schema_errors": "Schema errors: {}" + } + + self.util = ValidateChallengeConfigUtil( + self.request, self.challenge_host_team, self.base_location, + self.unique_folder_name, self.zip_ref, self.current_challenge + ) + self.util.error_messages_dict = self.error_message_dict + self.util.yaml_file_data = {} + self.util.extracted_folder_name = self.extracted_folder_name + self.util.challenge_image_file = Mock() # Initialize the missing attribute + self.util.challenge_evaluation_script_file = Mock() # Initialize if needed + self.util.extracted_folder_name = "extracted_folder" + self.util.error_messages = [] + self.request.data = {"GITHUB_REPOSITORY": "some_repo"} + + def test_no_yaml_file(self): + self.util.yaml_file_count = 0 + + result = self.util.read_and_validate_yaml() + + self.assertFalse(result) + self.assertIn("No YAML file found in the zip.", self.util.error_messages) + + def test_multiple_yaml_files(self): + self.util.yaml_file_count = 2 + + result = self.util.read_and_validate_yaml() + + self.assertFalse(result) + self.assertIn("Multiple YAML files found: 2.", self.util.error_messages) + + @mockpatch('challenges.challenge_config_utils.read_yaml_file') + def test_yaml_read_error(self, mock_read_yaml_file): + self.util.yaml_file_count = 1 + self.util.yaml_file = "config.yaml" + mock_read_yaml_file.side_effect = yaml.YAMLError("None") + + result = self.util.read_and_validate_yaml() + + self.assertFalse(result) + self.assertIn("YAML file read error: None at line None, column None.", self.util.error_messages) + + @mockpatch('challenges.challenge_config_utils.isfile') + @mockpatch('challenges.challenge_config_utils.get_file_content') + def test_validate_challenge_logo_valid_image(self, mock_get_file_content, mock_isfile): + self.util.yaml_file_data = {"image": "logo.png"} + mock_isfile.return_value = True + mock_get_file_content.return_value = b"image content" + + self.util.validate_challenge_logo() + + expected_path = join(self.base_location, self.unique_folder_name, 
self.extracted_folder_name, "logo.png") + self.assertEqual(self.util.challenge_image_path, expected_path) + self.assertIsNotNone(self.util.challenge_image_file) + self.assertEqual(self.util.files["challenge_image_file"].name, "logo.png") + + @mockpatch('challenges.challenge_config_utils.isfile') + def test_validate_challenge_logo_invalid_image(self, mock_isfile): + self.util.yaml_file_data = {"image": "logo.txt"} + mock_isfile.return_value = False + + self.util.validate_challenge_logo() + + self.assertIsNone(self.util.challenge_image_file) + self.assertIsNone(self.util.files["challenge_image_file"]) + + def test_validate_challenge_logo_no_image_key(self): + self.util.yaml_file_data = {} + + self.util.validate_challenge_logo() + + self.assertIsNone(self.util.challenge_image_file) + self.assertIsNone(self.util.files["challenge_image_file"]) + + @pytest.mark.django_db + @mockpatch('challenges.challenge_config_utils.ValidateChallengeConfigUtil') + def test_validate_challenge_config_util_invalid_yaml(self, MockValidateChallengeConfigUtil): + # Arrange + mock_instance = MockValidateChallengeConfigUtil.return_value + mock_instance.valid_yaml = False + mock_instance.error_messages = ["Error message"] + mock_instance.yaml_file_data = {"key": "value"} + mock_instance.files = {"file_key": "file_value"} + + request = Mock() + challenge_host_team = 1 + BASE_LOCATION = "/path/to/base" + unique_folder_name = "unique_folder" + zip_ref = Mock() + current_challenge = Mock() + + # Act + result = validate_challenge_config_util( + request, + challenge_host_team, + BASE_LOCATION, + unique_folder_name, + zip_ref, + current_challenge + ) + + # Assert + self.assertEqual(result, ( + ["Error message"], + {"key": "value"}, + {"file_key": "file_value"} + )) + + @pytest.mark.django_db + @mockpatch('challenges.challenge_config_utils.ValidateChallengeConfigUtil') + def test_validate_challenge_config_util_with_current_challenge(self, MockValidateChallengeConfigUtil): + # Arrange + mock_instance = MockValidateChallengeConfigUtil.return_value + mock_instance.valid_yaml = True + mock_instance.error_messages = [] + mock_instance.yaml_file_data = {"key": "value"} + mock_instance.files = {"file_key": "file_value"} + + request = Mock() + challenge_host_team = 1 + BASE_LOCATION = "/path/to/base" + unique_folder_name = "unique_folder" + zip_ref = Mock() + current_challenge = Mock(id=1) + + # Mock the ChallengePhase objects + challenge_phase1 = Mock(id=1, challenge=current_challenge.id) + challenge_phase2 = Mock(id=2, challenge=current_challenge.id) + current_challenge_phases = [challenge_phase1, challenge_phase2] + + # Mock the ChallengePhaseSplit objects + challenge_phase_split1 = Mock(id=1, challenge_phase=challenge_phase1, leaderboard=Mock(id=1), dataset_split=Mock(id=1)) + challenge_phase_split2 = Mock(id=2, challenge_phase=challenge_phase2, leaderboard=Mock(id=2), dataset_split=Mock(id=2)) + current_challenge_phase_splits = [challenge_phase_split1, challenge_phase_split2] + + # Mock the Leaderboard objects + leaderboard1 = Mock(id=1, config_id=1) + leaderboard2 = Mock(id=2, config_id=2) + current_leaderboards = [leaderboard1, leaderboard2] + + # Mock the DatasetSplit objects + dataset_split1 = Mock(id=1, config_id=1) + dataset_split2 = Mock(id=2, config_id=2) + current_dataset_splits = [dataset_split1, dataset_split2] + + # Patch the model queries + with mockpatch.object(ChallengePhase, 'objects') as mock_challenge_phase_objects: + mock_challenge_phase_objects.filter.return_value = current_challenge_phases + + with 
mockpatch.object(ChallengePhaseSplit, 'objects') as mock_challenge_phase_split_objects: + mock_challenge_phase_split_objects.filter.return_value = current_challenge_phase_splits + + with mockpatch.object(Leaderboard, 'objects') as mock_leaderboard_objects: + mock_leaderboard_objects.filter.return_value = current_leaderboards + + with mockpatch.object(DatasetSplit, 'objects') as mock_dataset_split_objects: + mock_dataset_split_objects.filter.return_value = current_dataset_splits + + # Act + error_messages, yaml_file_data, files = validate_challenge_config_util( + request, + challenge_host_team, + BASE_LOCATION, + unique_folder_name, + zip_ref, + current_challenge, + ) + + # Assert + self.assertEqual(error_messages, []) + self.assertEqual(yaml_file_data, {"key": "value"}) + self.assertEqual(files, {"file_key": "file_value"}) + + @mockpatch('challenges.challenge_config_utils.ValidateChallengeConfigUtil') + def test_validate_challenge_config_util_without_current_challenge(self, MockValidateChallengeConfigUtil): + # Arrange + mock_instance = MockValidateChallengeConfigUtil.return_value + mock_instance.valid_yaml = True + mock_instance.error_messages = [] + mock_instance.yaml_file_data = {"key": "value"} + mock_instance.files = {"file_key": "file_value"} + + request = Mock() + challenge_host_team = 1 + BASE_LOCATION = "/path/to/base" + unique_folder_name = "unique_folder" + zip_ref = Mock() + current_challenge = None + + # Act + error_messages, yaml_file_data, files = validate_challenge_config_util( + request, + challenge_host_team, + BASE_LOCATION, + unique_folder_name, + zip_ref, + current_challenge, + ) + + # Assert + self.assertEqual(error_messages, []) + self.assertEqual(yaml_file_data, {"key": "value"}) + self.assertEqual(files, {"file_key": "file_value"}) + + def test_validate_challenge_description_missing(self): + self.util.validate_challenge_description() + self.assertIn("Challenge description is missing.", self.util.error_messages) + + def test_validate_evaluation_details_file_missing(self): + self.util.validate_evaluation_details_file() + self.assertIn("Evaluation details are missing.", self.util.error_messages) + + def test_validate_terms_and_conditions_file_missing(self): + self.util.validate_terms_and_conditions_file() + self.assertIn("Terms and conditions are missing.", self.util.error_messages) + + def test_validate_submission_guidelines_file_missing(self): + self.util.validate_submission_guidelines_file() + self.assertIn("Submission guidelines are missing.", self.util.error_messages) + + def test_validate_challenge_description_empty(self): + self.util.yaml_file_data["description"] = "" + self.util.validate_challenge_description() + self.assertIn("Challenge description is missing.", self.util.error_messages) + + def test_validate_evaluation_details_file_empty(self): + self.util.yaml_file_data["evaluation_details"] = "" + self.util.validate_evaluation_details_file() + self.assertIn("Evaluation details are missing.", self.util.error_messages) + + def test_validate_terms_and_conditions_file_empty(self): + self.util.yaml_file_data["terms_and_conditions"] = "" + self.util.validate_terms_and_conditions_file() + self.assertIn("Terms and conditions are missing.", self.util.error_messages) + + def test_validate_submission_guidelines_file_empty(self): + self.util.yaml_file_data["submission_guidelines"] = "" + self.util.validate_submission_guidelines_file() + self.assertIn("Submission guidelines are missing.", self.util.error_messages) + + def 
test_validate_evaluation_script_file_not_zip(self): + self.util.yaml_file_data["evaluation_script"] = "script.txt" + self.util.validate_evaluation_script_file() + self.assertIn("Evaluation script is not a zip file.", self.util.error_messages) + + def test_validate_dates_missing_dates(self): + self.util.yaml_file_data = { + "start_date": None, + "end_date": None + } + self.util.validate_dates() + self.assertIn("Start date or end date is missing.", self.util.error_messages) + + def test_validate_dates_start_date_greater_than_end_date(self): + self.util.yaml_file_data = { + "start_date": "2023-12-31", + "end_date": "2023-01-01" + } + self.util.validate_dates() + self.assertIn("Start date is greater than end date.", self.util.error_messages) + + def test_validate_dates_valid_dates(self): + self.util.yaml_file_data = { + "start_date": "2023-01-01", + "end_date": "2023-12-31" + } + self.util.validate_dates() + self.assertNotIn("Start date or end date is missing.", self.util.error_messages) + self.assertNotIn("Start date is greater than end date.", self.util.error_messages) + + @mockpatch('challenges.serializers.ZipChallengeSerializer.is_valid', return_value=False) + @mockpatch('challenges.serializers.ZipChallengeSerializer.errors', new_callable=Mock, return_value={"field": ["error"]}) + def test_validate_serializer_invalid(self, mock_errors, mock_is_valid): + self.util.validate_serializer() + self.assertEqual(len(self.util.error_messages), 1) + self.assertIn("Schema errors:", self.util.error_messages[0]) + + +@pytest.mark.django_db +class TestValidateChallengeConfigUtil(unittest.TestCase): + def setUp(self): + self.request = Mock() + self.user = User.objects.create_user(username='testuser', password='12345') + self.challenge_host_team = ChallengeHostTeam.objects.create( + team_name="Test Challenge Host Team", created_by=self.user + ) + self.base_location = "/path/to/base" + self.unique_folder_name = "unique_folder" + self.extracted_folder_name = "extracted_folder" + self.zip_ref = Mock() + self.zip_ref.namelist.return_value = [] + self.current_challenge = Mock() + self.util = ValidateChallengeConfigUtil( + self.request, self.challenge_host_team, self.base_location, + self.unique_folder_name, self.zip_ref, self.current_challenge + ) + self.util.error_messages_dict = { + "missing_leaderboard_id": "Leaderboard ID is missing.", + "missing_leaderboard_schema": "Leaderboard schema is missing.", + "missing_leaderboard_labels": "Leaderboard labels are missing.", + "missing_leaderboard_default_order_by": "Default order by is missing.", + "incorrect_default_order_by": "Default order by is incorrect.", + "invalid_leaderboard_schema": "Invalid leaderboard schema.", + "missing_leaderboard_key": "Leaderboard key is missing.", + "leaderboard_schema_error": "Leaderboard schema error for leaderboard with ID: {}", + "leaderboard_additon_after_creation": "Cannot add new leaderboard after challenge creation.", + "leaderboard_deletion_after_creation": "Cannot delete leaderboard after challenge creation.", + "missing_challenge_phases": "Missing challenge phases.", + "no_codename_for_challenge_phase": "Codename is missing for challenge phase.", + "duplicate_codename_for_phase": "Duplicate codename '{}' for phase '{}'.", + "no_test_annotation_file_found": "No test annotation file found for phase '{}'.", + "is_submission_public_restricted": "Submission is public but restricted to select one submission for phase '{}'.", + "missing_dates_challenge_phase": "Missing start or end date for phase '{}'.", + 
"start_date_greater_than_end_date_challenge_phase": "Start date is greater than end date for phase '{}'.", + "missing_option_in_submission_meta_attribute": "Missing options in submission meta attribute for phase '{}'.", + "invalid_submission_meta_attribute_types": "Invalid submission meta attribute type '{}' for phase '{}'.", + "missing_fields_in_submission_meta_attribute": "Missing fields '{}' in submission meta attribute for phase '{}'.", + "challenge_phase_schema_errors": "Challenge phase schema errors: {} - {}", + "challenge_phase_addition": "Cannot add new challenge phase after challenge creation for phase '{}'.", + "challenge_phase_not_found": "Challenge phase '{}' not found.", + "extra_tags": "Too many tags provided.", + "wrong_domain": "Invalid domain provided.", + "sponsor_not_found": "Sponsor name or website not found.", + "prize_not_found": "Prize rank or amount not found.", + "duplicate_rank": "Duplicate rank found: {}.", + "prize_rank_wrong": "Invalid prize rank: {}.", + "prize_amount_wrong": "Invalid prize amount: {}." + } + self.util.yaml_file_data = {} + self.util.extracted_folder_name = self.extracted_folder_name + self.util.error_messages = [] + self.util.challenge_phase_split = Mock() + self.util.challenge_config_location = "/path/to/config" # Set a valid path here + + def test_missing_leaderboard_id(self): + self.util.yaml_file_data = {"leaderboard": [{}]} + self.util.validate_leaderboards([]) + self.assertIn("Leaderboard ID is missing.", self.util.error_messages) + + def test_missing_leaderboard_schema(self): + self.util.yaml_file_data = {"leaderboard": [{"id": "test_id"}]} + self.util.validate_leaderboards([]) + self.assertIn("Leaderboard schema is missing.", self.util.error_messages) + + def test_missing_leaderboard_labels(self): + self.util.yaml_file_data = {"leaderboard": [{"id": "test_id", "schema": {}}]} + self.util.validate_leaderboards([]) + self.assertIn("Leaderboard labels are missing.", self.util.error_messages) + + def test_missing_leaderboard_default_order_by(self): + self.util.yaml_file_data = {"leaderboard": [{"id": "test_id", "schema": {"labels": []}}]} + self.util.validate_leaderboards([]) + self.assertIn("Default order by is missing.", self.util.error_messages) + + def test_incorrect_default_order_by(self): + self.util.yaml_file_data = { + "leaderboard": [{"id": "test_id", "schema": {"labels": ["a", "b"], "default_order_by": "c"}}] + } + self.util.validate_leaderboards([]) + self.assertIn("Default order by is incorrect.", self.util.error_messages) + + def test_leaderboard_schema_error(self): + self.util.yaml_file_data = { + "leaderboard": [{"id": "test_id", "schema": {"labels": ["a", "b"], "default_order_by": "a"}}] + } + with mockpatch('challenges.serializers.LeaderboardSerializer.is_valid', return_value=False): + with mockpatch('challenges.serializers.LeaderboardSerializer.errors', new_callable=Mock, return_value="some_error"): + self.util.validate_leaderboards([]) + self.assertEqual(self.util.error_messages[0], self.util.error_messages_dict["leaderboard_schema_error"].format("test_id", "some_error")) + + def test_leaderboard_addition_after_creation(self): + self.util.yaml_file_data = { + "leaderboard": [{"id": "new_leaderboard_id", "schema": {"labels": ["a"], "default_order_by": "a"}}] + } + self.util.validate_leaderboards(["existing_leaderboard_id"]) + self.assertIn("Leaderboard schema error for leaderboard with ID: new_leaderboard_id", self.util.error_messages[0]) + + def test_leaderboard_addition_after_creation_with_multiple_leaderboards(self): 
+ self.util.yaml_file_data = { + "leaderboard": [ + {"id": "new_leaderboard_id1", "schema": {"labels": ["a"], "default_order_by": "a"}}, + {"id": "new_leaderboard_id2", "schema": {"labels": ["b"], "default_order_by": "b"}} + ] + } + self.util.validate_leaderboards(["existing_leaderboard_id"]) + self.assertIn("Leaderboard schema error for leaderboard with ID: new_leaderboard_id1", self.util.error_messages[0]) + + def test_leaderboard_deletion_after_creation(self): + self.util.yaml_file_data = { + "leaderboard": [{"id": "test_id", "schema": {"labels": ["a"], "default_order_by": "a"}}] + } + self.util.validate_leaderboards(["test_id", "deleted_leaderboard_id"]) + self.assertIn("Leaderboard schema error for leaderboard with ID: test_id", self.util.error_messages[0]) + + def test_missing_leaderboard_key(self): + self.util.yaml_file_data = {} + self.util.validate_leaderboards([]) + self.assertEqual(self.util.error_messages[0], self.util.error_messages_dict["missing_leaderboard_key"]) + + def test_missing_challenge_phases(self): + self.util.yaml_file_data = {} + self.util.validate_challenge_phases([]) + self.assertEqual(self.util.error_messages[0], "Missing challenge phases.") + + def test_no_codename_for_challenge_phase(self): + self.util.yaml_file_data = { + "challenge_phases": [{"name": "Phase 1", "id": 1}] + } + self.util.validate_challenge_phases([]) + self.assertEqual(self.util.error_messages[0], "Codename is missing for challenge phase.") + + def test_no_test_annotation_file_found(self): + self.util.yaml_file_data = { + "challenge_phases": [{"codename": "phase1", "name": "Phase 1", "test_annotation_file": "non_existent_file", "id": 1}] + } + self.util.validate_challenge_phases([]) + self.assertEqual(self.util.error_messages[0], "No test annotation file found for phase 'Phase 1'.") + + def test_is_submission_public_restricted(self): + self.util.yaml_file_data = { + "challenge_phases": [ + { + "codename": "phase1", + "name": "Phase 1", + "is_submission_public": True, + "is_restricted_to_select_one_submission": True, + "id": 1, + "description": "Description 1" + } + ] + } + self.util.validate_challenge_phases([]) + self.assertEqual( + self.util.error_messages[0], + "Submission is public but restricted to select one submission for phase 'Phase 1'." 
+ ) + + def test_duplicate_codename_for_phase(self): + self.util.yaml_file_data = { + "challenge_phases": [ + {"codename": "phase1", "name": "Phase 1", "id": 1, "start_date": "2023-10-10T00:00:00", "end_date": "2023-10-11T00:00:00", "max_submissions_per_month": 10, "description": "Description 1"}, + {"codename": "phase1", "name": "Phase 2", "id": 2, "start_date": "2023-10-12T00:00:00", "end_date": "2023-10-13T00:00:00", "max_submissions_per_month": 10, "description": "Description 2"} + ] + } + self.util.validate_challenge_phases([]) + self.assertIn("Duplicate codename 'phase1' for phase 'Phase 2'.", self.util.error_messages) + + def test_missing_dates_challenge_phase(self): + self.util.yaml_file_data = { + "challenge_phases": [{"codename": "phase1", "name": "Phase 1", "id": 1, "description": "Description 1"}] + } + self.util.validate_challenge_phases([]) + self.assertEqual(self.util.error_messages[0], "Missing start or end date for phase '1'.") + + def test_start_date_greater_than_end_date_challenge_phase(self): + self.util.yaml_file_data = { + "challenge_phases": [{"codename": "phase1", "name": "Phase 1", "start_date": "2023-10-10", "end_date": "2023-10-09", "id": 1, "description": "Description 1"}] + } + self.util.validate_challenge_phases([]) + self.assertEqual(self.util.error_messages[0], "Start date is greater than end date for phase '1'.") + + def test_missing_option_in_submission_meta_attribute(self): + self.util.yaml_file_data = { + "challenge_phases": [{"codename": "phase1", "name": "Phase 1", "id": 1, "description": "Description 1", "submission_meta_attributes": [{"name": "attr1", "type": "radio"}]}] + } + self.util.validate_challenge_phases([]) + self.assertEqual(self.util.error_messages[0], "Missing start or end date for phase '1'.") + + def test_invalid_submission_meta_attribute_types(self): + self.util.yaml_file_data = { + "challenge_phases": [ + { + "codename": "phase1", + "name": "Phase 1", + "id": 1, + "description": "Description 1", + "start_date": "2023-10-10T00:00:00", + "end_date": "2023-10-11T00:00:00", + "submission_meta_attributes": [ + { + "name": "attr1", + "type": "invalid_type" + } + ] + } + ] + } + self.util.validate_challenge_phases([]) + self.assertEqual(self.util.error_messages[0], "Missing fields '1' in submission meta attribute for phase 'description'.") + + def test_missing_fields_in_submission_meta_attribute(self): + self.util.yaml_file_data = { + "challenge_phases": [ + { + "codename": "phase1", + "name": "Phase 1", + "id": 1, + "description": "Description 1", + "start_date": "2023-10-10T00:00:00", + "end_date": "2023-10-11T00:00:00", + "submission_meta_attributes": [ + { + "type": "text" + } + ] + } + ] + } + self.util.validate_challenge_phases([]) + self.assertEqual(self.util.error_messages[0], "Missing fields '1' in submission meta attribute for phase 'name, description'.") + + def test_challenge_phase_schema_errors(self): + self.util.yaml_file_data = { + "challenge_phases": [{"codename": "phase1", "name": "Phase 1", "id": 1, "description": "Description 1", "start_date": "2023-10-10T00:00:00", "end_date": "2023-10-11T00:00:00", "invalid_field": "invalid_value"}] + } + self.util.validate_challenge_phases([]) + self.assertIn("Challenge phase schema errors: 1 - ", self.util.error_messages[0]) + + def test_challenge_phase_addition(self): + self.util.yaml_file_data = { + "challenge_phases": [ + { + "codename": "phase1", + "name": "Phase 1", + "id": 9, + "description": "Description 1", + "max_submissions_per_month": 10, + "start_date": 
"2023-10-10T00:00:00", + "end_date": "2023-10-11T00:00:00", + "submission_meta_attributes": [ + { + "name": "attr1", + "description": "Description 1", + "type": "text" + } + ], + } + ] + } + self.util.validate_challenge_phases([2]) + self.assertEqual(self.util.error_messages[0], "Challenge phase schema errors: 9 - {'description': [ErrorDetail(string='This field may not be null.', code='null')]}") + + def test_challenge_phase_not_found(self): + self.util.yaml_file_data = { + "challenge_phases": [ + { + "codename": "phase1", + "name": "Phase 1", + "id": 3, + "description": "Description 1", + "max_submissions_per_month": 10, + "start_date": "2023-10-10T00:00:00", + "end_date": "2023-10-11T00:00:00", + "submission_meta_attributes": [ + { + "name": "attr1", + "description": "Description 1", + "type": "text" + } + ], + "max_submissions_per_month": 10, + } + ] + } + self.util.validate_challenge_phases([2]) + self.assertEqual(self.util.error_messages[0], "Challenge phase schema errors: 3 - {'description': [ErrorDetail(string='This field may not be null.', code='null')]}") + + def test_check_tags(self): + # Test case for more than 4 tags + self.util.yaml_file_data = {"tags": ["tag1", "tag2", "tag3", "tag4", "tag5"]} + self.util.check_tags() + self.assertEqual(self.util.error_messages[0], "Too many tags provided.") + + def test_check_domain(self): + # Test case for invalid domain + self.util.yaml_file_data = {"domain": "invalid_domain"} + self.util.check_domain() + self.assertEqual(self.util.error_messages[0], "Invalid domain provided.") + + def test_check_sponsor(self): + # Test case for missing sponsor name or website + self.util.yaml_file_data = {"sponsors": [{"name": "Sponsor1"}, {"website": "http://sponsor2.com"}]} + self.util.check_sponsor() + self.assertEqual(self.util.error_messages[0], "Sponsor name or website not found.") + + def test_check_prizes(self): + # Test case for valid prizes + self.util.yaml_file_data = {"prizes": [{"rank": 1, "amount": "100USD"}, {"rank": 2, "amount": "200USD"}]} + self.util.error_messages = [] # Clear the error messages list + self.util.check_prizes() + self.assertEqual(len(self.util.error_messages), 0) + + # Test case for duplicate rank + self.util.yaml_file_data = {"prizes": [{"rank": 1, "amount": "100USD"}, {"rank": 1, "amount": "200USD"}]} + self.util.error_messages = [] # Clear the error messages list + self.util.check_prizes() + self.assertEqual(len(self.util.error_messages), 1) + self.assertEqual(self.util.error_messages[0], "Duplicate rank found: 1.") + + # Test case for invalid rank + self.util.yaml_file_data = {"prizes": [{"rank": 0, "amount": "100USD"}, {"rank": 0, "amount": "200USD"}]} + self.util.error_messages = [] # Clear the error messages list + self.util.check_prizes() + self.assertEqual(len(self.util.error_messages), 3) + self.assertEqual(self.util.error_messages[0], "Invalid prize rank: 0.") + + # Test case for invalid amount + self.util.yaml_file_data = {"prizes": [{"rank": 1, "amount": "100"}]} + self.util.check_prizes() + self.assertEqual(self.util.error_messages[3], "Invalid prize amount: 100.") + + # Test case for missing rank + self.util.yaml_file_data = {"prizes": [{"rank": 1, "amount": "100USD"}]} + self.util.error_messages = [] # Clear the error messages list + self.util.check_prizes() + self.assertEqual(len(self.util.error_messages), 0) # No error messages expected diff --git a/tests/unit/challenges/test_challenge_notification_util.py b/tests/unit/challenges/test_challenge_notification_util.py index ad3b5ee5b1..49224e4512 100644 
--- a/tests/unit/challenges/test_challenge_notification_util.py +++ b/tests/unit/challenges/test_challenge_notification_util.py @@ -1,6 +1,9 @@ +from unittest.mock import MagicMock, patch as mockpatch import mock from datetime import timedelta + +from challenges.challenge_notification_util import construct_and_send_eks_cluster_creation_mail from moto import mock_ecs from allauth.account.models import EmailAddress @@ -122,3 +125,32 @@ def test_feature(self, mock_start_workers, mock_send_email): mock_start_workers.assert_called_with([self.challenge]) self.assertEqual(mock_send_email.call_args_list, calls) + + +class TestUnittestChallengeNotification(BaseTestClass): + @mockpatch('challenges.challenge_notification_util.send_email') + @mockpatch('challenges.challenge_notification_util.settings') + def test_construct_and_send_eks_cluster_creation_mail(self, mock_settings, mock_send_email): + # Mock challenge object + mock_challenge = MagicMock() + mock_challenge.title = 'Test Challenge' + mock_challenge.image = None + + # Set settings.DEBUG to False + mock_settings.DEBUG = False + + # Call the function + mock_settings.configure_mock( + ADMIN_EMAIL='admin@cloudcv.org', + CLOUDCV_TEAM_EMAIL='team@cloudcv.org', + SENDGRID_SETTINGS={'TEMPLATES': {'CLUSTER_CREATION_TEMPLATE': 'template-id'}} + ) + construct_and_send_eks_cluster_creation_mail(mock_challenge) + + # Assert send_email was called with correct arguments + mock_send_email.assert_called_once_with( + sender='team@cloudcv.org', + recipient='admin@cloudcv.org', + template_id='template-id', + template_data={"CHALLENGE_NAME": 'Test Challenge'} + ) diff --git a/tests/unit/challenges/test_serializers.py b/tests/unit/challenges/test_serializers.py index b0af431a2e..730c4d9a4c 100644 --- a/tests/unit/challenges/test_serializers.py +++ b/tests/unit/challenges/test_serializers.py @@ -1,16 +1,18 @@ import os - +import pytest from datetime import timedelta +from unittest import TestCase +from unittest.mock import MagicMock, Mock, patch as mockpatch from django.core.files.uploadedfile import SimpleUploadedFile from django.contrib.auth.models import User from django.utils import timezone from allauth.account.models import EmailAddress +from challenges.utils import add_sponsors_to_challenge from rest_framework.test import APITestCase, APIClient - from challenges.models import Challenge, ChallengePhase -from challenges.serializers import ChallengePhaseCreateSerializer +from challenges.serializers import ChallengePhaseCreateSerializer, PWCChallengeLeaderboardSerializer, UserInvitationSerializer from participants.models import ParticipantTeam from hosts.models import ChallengeHost, ChallengeHostTeam @@ -440,3 +442,100 @@ def test_challenge_phase_create_serializer_with_invalid_data(self): self.assertEqual( set(serializer.errors), set(["test_annotation", "slug"]) ) + + +class ChallengeLeaderboardSerializerTests(TestCase): + def setUp(self): + self.obj = MagicMock() + self.serializer = PWCChallengeLeaderboardSerializer() + + def test_get_challenge_id(self): + """Test case for get_challenge_id function.""" + self.obj.phase_split.challenge_phase.challenge.id = 1 + result = self.serializer.get_challenge_id(self.obj) + self.assertEqual(result, 1) + + def test_get_leaderboard_decimal_precision(self): + """Test case for get_leaderboard_decimal_precision function.""" + self.obj.phase_split.leaderboard_decimal_precision = 2 + result = self.serializer.get_leaderboard_decimal_precision(self.obj) + self.assertEqual(result, 2) + + def 
test_get_is_leaderboard_order_descending(self): + """Test case for get_is_leaderboard_order_descending function.""" + self.obj.phase_split.is_leaderboard_order_descending = True + result = self.serializer.get_is_leaderboard_order_descending(self.obj) + self.assertTrue(result) + + def test_get_leaderboard(self): + """Test case for get_leaderboard function.""" + leaderboard_schema = { + "default_order_by": "accuracy", + "labels": ["accuracy", "loss", "f1_score"] + } + self.obj.phase_split.leaderboard.schema = leaderboard_schema + + result = self.serializer.get_leaderboard(self.obj) + self.assertEqual(result, ["accuracy", "loss", "f1_score"]) + + +@pytest.mark.django_db +class UserInvitationSerializerTests(TestCase): + def setUp(self): + # Set up any common objects you need + self.user = User.objects.create(username="testuser") + self.challengeHostTeam = ChallengeHostTeam.objects.create(team_name="Test Team", created_by=self.user) + self.challenge = Challenge.objects.create( + title="Test Challenge", + creator=self.challengeHostTeam, + ) + self.obj = Mock(challenge=self.challenge, user=self.user) + self.serializer = UserInvitationSerializer() + + def test_get_challenge_title(self): + result = self.serializer.get_challenge_title(self.obj) + self.assertEqual(result, "Test Challenge") + + def test_get_challenge_host_team_name(self): + # Assuming creator has a team_name attribute + self.user.team_name = "Test Team" + result = self.serializer.get_challenge_host_team_name(self.obj) + self.assertEqual(result, "Test Team") + + def test_get_user_details(self): + # Mock the serializer output + with mockpatch('challenges.serializers.UserDetailsSerializer') as mock_serializer: + mock_serializer.return_value.data = {"username": "testuser"} + result = self.serializer.get_user_details(self.obj) + self.assertEqual(result, {"username": "testuser"}) + + +@pytest.mark.django_db +class AddSponsorsToChallengeTests(TestCase): + def setUp(self): + self.User = User.objects.create(username="testuser") + self.challengeHostTeam = ChallengeHostTeam.objects.create(team_name="Test Team", created_by=self.User) + self.challenge = Challenge.objects.create(title="Test Challenge", creator=self.challengeHostTeam) + self.yaml_file_data = { + 'sponsors': [ + {'name': 'Test Sponsor', 'website': 'https://testsponsor.com'} + ] + } + + @mockpatch('challenges.utils.ChallengeSponsorSerializer') # Mock the serializer + def test_serializer_valid(self, MockChallengeSponsorSerializer): + # Mocking the serializer instance and its methods + mock_serializer = MockChallengeSponsorSerializer.return_value + mock_serializer.is_valid.return_value = True # Simulate a valid serializer + + # Call the function with the real challenge object + result = add_sponsors_to_challenge(self.yaml_file_data, self.challenge) + + # Assertions + mock_serializer.save.assert_called_once() # Ensure save is called on the serializer + self.assertTrue(self.challenge.has_sponsors) # Ensure has_sponsors is set to True + self.challenge.refresh_from_db() # Refresh the challenge instance from the database + self.assertTrue(self.challenge.has_sponsors) # Check again after refreshing + + # Ensure the function does not return any error response (meaning it worked correctly) + self.assertIsNone(result) diff --git a/tests/unit/challenges/test_utils.py b/tests/unit/challenges/test_utils.py index e936861835..0bbfea3f72 100644 --- a/tests/unit/challenges/test_utils.py +++ b/tests/unit/challenges/test_utils.py @@ -2,10 +2,14 @@ import unittest import random import string - +from 
unittest.mock import MagicMock, patch as mockpatch +from django.contrib.auth.models import User from django.conf import settings +import pytest +from hosts.models import ChallengeHostTeam +from challenges.models import Challenge, ChallengePrize +from challenges.utils import get_file_content, add_domain_to_challenge, add_prizes_to_challenge, add_sponsors_to_challenge, add_tags_to_challenge, generate_presigned_url, parse_submission_meta_attributes, send_emails -from challenges.utils import get_file_content from base.utils import get_queue_name @@ -58,3 +62,269 @@ def test_sqs_queue_name_generator_empty_title(self): sqs_queue_name = get_queue_name(title, challenge_pk) self.assertNotRegex(sqs_queue_name, "[^a-zA-Z0-9_-]") self.assertLessEqual(len(sqs_queue_name), 80) + + +class TestGeneratePresignedUrl(unittest.TestCase): + @mockpatch('challenges.utils.settings') + def test_debug_or_test_mode(self, mock_settings): + mock_settings.DEBUG = True + mock_settings.TEST = False + result = generate_presigned_url('file_key', 1) + self.assertIsNone(result) + + mock_settings.DEBUG = False + mock_settings.TEST = True + result = generate_presigned_url('file_key', 1) + self.assertIsNone(result) + + @mockpatch('challenges.utils.get_boto3_client') + @mockpatch('challenges.utils.get_aws_credentials_for_challenge') + @mockpatch('challenges.utils.settings') + def test_generate_presigned_url_success(self, mock_settings, mock_get_aws_credentials, mock_get_boto3_client): + mock_settings.DEBUG = False + mock_settings.TEST = False + mock_settings.PRESIGNED_URL_EXPIRY_TIME = 3600 + + mock_get_aws_credentials.return_value = { + 'AWS_ACCESS_KEY_ID': 'fake_access_key', + 'AWS_SECRET_ACCESS_KEY': 'fake_secret_key', + 'AWS_STORAGE_BUCKET_NAME': 'fake_bucket' + } + + mock_s3_client = MagicMock() + mock_s3_client.generate_presigned_url.return_value = 'http://fake_presigned_url' + mock_get_boto3_client.return_value = mock_s3_client + + result = generate_presigned_url('file_key', 1) + self.assertEqual(result, {'presigned_url': 'http://fake_presigned_url'}) + mock_s3_client.generate_presigned_url.assert_called_once_with( + 'put_object', + Params={ + 'Bucket': 'fake_bucket', + 'Key': 'file_key' + }, + ExpiresIn=3600, + HttpMethod='PUT' + ) + + +class TestChallengeUtils(unittest.TestCase): + def test_parse_submission_meta_attributes(self): + # Test with submission_metadata as None + submission = {"submission_metadata": None} + result = parse_submission_meta_attributes(submission) + self.assertEqual(result, {}) + + # Test with submission_metadata containing different types of attributes + submission = { + "submission_metadata": [ + {"type": "checkbox", "name": "attr1", "values": ["val1", "val2"]}, + {"type": "text", "name": "attr2", "value": "val3"} + ] + } + result = parse_submission_meta_attributes(submission) + self.assertEqual(result, {"attr1": ["val1", "val2"], "attr2": "val3"}) + + @mockpatch('challenges.models.Challenge') + def test_add_tags_to_challenge(self, MockChallenge): + challenge = MockChallenge() + challenge.list_tags = ['tag1', 'tag2'] + + # Test with tags present in yaml_file_data + yaml_file_data = {"tags": ["tag2", "tag3"]} + add_tags_to_challenge(yaml_file_data, challenge) + self.assertEqual(challenge.list_tags, ['tag2', 'tag3']) + + # Test with tags not present in yaml_file_data + yaml_file_data = {} + add_tags_to_challenge(yaml_file_data, challenge) + self.assertEqual(challenge.list_tags, []) + + @mockpatch('challenges.models.Challenge') + def test_add_domain_to_challenge(self, MockChallenge): + challenge = 
MockChallenge() + challenge.DOMAIN_OPTIONS = [('domain1', 'Domain 1'), ('domain2', 'Domain 2')] + + # Test with valid domain in yaml_file_data + yaml_file_data = {"domain": "domain1"} + response = add_domain_to_challenge(yaml_file_data, challenge) + self.assertIsNone(response) + self.assertEqual(challenge.domain, "domain1") + + # Test with invalid domain in yaml_file_data + yaml_file_data = {"domain": "invalid_domain"} + response = add_domain_to_challenge(yaml_file_data, challenge) + self.assertEqual(response, {"error": "Invalid domain value: invalid_domain, valid values are: ['domain1', 'domain2']"}) + + # Test with domain not present in yaml_file_data + yaml_file_data = {} + response = add_domain_to_challenge(yaml_file_data, challenge) + self.assertIsNone(response) + self.assertIsNone(challenge.domain) + + +class TestAddSponsorsToChallenge(unittest.TestCase): + @mockpatch('challenges.utils.ChallengeSponsor') + @mockpatch('challenges.utils.ChallengeSponsorSerializer') + def test_add_sponsors_with_valid_data(self, MockChallengeSponsorSerializer, MockChallengeSponsor): + yaml_file_data = { + 'sponsors': [ + {'name': 'Sponsor1', 'website': 'http://sponsor1.com'}, + {'name': 'Sponsor2', 'website': 'http://sponsor2.com'} + ] + } + challenge = MagicMock() + mock_queryset = MagicMock() + mock_queryset.exists.return_value = False + MockChallengeSponsor.objects.filter.return_value = mock_queryset + mock_serializer = MockChallengeSponsorSerializer.return_value + mock_serializer.is_valid.return_value = True + + response = add_sponsors_to_challenge(yaml_file_data, challenge) + + self.assertIsNone(response) + self.assertTrue(challenge.has_sponsors) + self.assertEqual(mock_serializer.save.call_count, 2) + + @mockpatch('challenges.utils.ChallengeSponsor') + @mockpatch('challenges.utils.ChallengeSponsorSerializer') + def test_add_sponsors_with_invalid_data(self, MockChallengeSponsorSerializer, MockChallengeSponsor): + yaml_file_data = { + 'sponsors': [ + {'name': 'Sponsor1', 'website': 'http://sponsor1.com'}, + {'name': 'Sponsor2'} # Missing website + ] + } + challenge = MagicMock() + mock_queryset = MagicMock() + mock_queryset.exists.return_value = False + MockChallengeSponsor.objects.filter.return_value = mock_queryset + + response = add_sponsors_to_challenge(yaml_file_data, challenge) + + self.assertEqual(response, {"error": "Sponsor name or url not found in YAML data."}) + + @mockpatch('challenges.utils.ChallengeSponsor') + @mockpatch('challenges.utils.ChallengeSponsorSerializer') + def test_add_sponsors_existing_in_database(self, MockChallengeSponsorSerializer, MockChallengeSponsor): + yaml_file_data = { + 'ponsors': [ + {'name': 'Sponsor1', 'website': 'http://sponsor1.com'} + ] + } + challenge = MagicMock() + mock_queryset = MagicMock() + mock_queryset.exists.return_value = True + MockChallengeSponsor.objects.filter.return_value = mock_queryset + + response = add_sponsors_to_challenge(yaml_file_data, challenge) + + self.assertIsNone(response) + self.assertFalse(MockChallengeSponsorSerializer.called) + self.assertFalse(challenge.has_sponsors) + + +@pytest.mark.django_db +class AddPrizesToChallengeTests(unittest.TestCase): + def setUp(self): + self.user = User.objects.create_user(username='testuser', password='12345') + self.challenge_host_team = ChallengeHostTeam.objects.create( + team_name="Test Challenge Host Team", created_by=self.user + ) + self.challenge = Challenge.objects.create( + title="Test Challenge", + short_description="Short Description", + description="Description", + 
terms_and_conditions="Terms", + submission_guidelines="Guidelines", + creator=self.challenge_host_team, + published=False + ) + + def test_no_prizes_in_yaml(self): + yaml_file_data = {} + result = add_prizes_to_challenge(yaml_file_data, self.challenge) + self.assertIsNone(result) + self.assertFalse(self.challenge.has_prize) + + def test_missing_rank_or_amount_in_prize_data(self): + yaml_file_data = { + 'prizes': [ + {'rank': 1} # Missing 'amount' + ] + } + result = add_prizes_to_challenge(yaml_file_data, self.challenge) + self.assertEqual(result, {"error": "Prize rank or amount not found in YAML data."}) + self.assertFalse(self.challenge.has_prize) + + def test_duplicate_rank_in_prize_data(self): + yaml_file_data = { + 'prizes': [ + {'rank': 1, 'amount': 100, 'description': 'First Prize'}, + {'rank': 1, 'amount': 200, 'description': 'Duplicate First Prize'} + ] + } + result = add_prizes_to_challenge(yaml_file_data, self.challenge) + self.assertEqual(result, {"error": "Duplicate rank 1 found in YAML data."}) + self.assertTrue(self.challenge.has_prize) + + @mockpatch('challenges.serializers.ChallengePrizeSerializer.is_valid', return_value=True) + @mockpatch('challenges.serializers.ChallengePrizeSerializer.save') + def test_valid_prize_data_new_prize(self, mock_save, mock_is_valid): + yaml_file_data = { + 'prizes': [ + {'rank': 1, 'amount': 100, 'description': 'First Prize'} + ] + } + result = add_prizes_to_challenge(yaml_file_data, self.challenge) + self.assertIsNone(result) + mock_save.assert_called_once() + self.assertTrue(self.challenge.has_prize) + + @mockpatch('challenges.serializers.ChallengePrizeSerializer.is_valid', return_value=True) + @mockpatch('challenges.serializers.ChallengePrizeSerializer.save') + def test_valid_prize_data_existing_prize(self, mock_save, mock_is_valid): + prize = ChallengePrize.objects.create(rank=1, amount=100, description='Old Prize', challenge=self.challenge) + + yaml_file_data = { + 'prizes': [ + {'rank': 1, 'amount': 100, 'description': 'Updated Prize'} + ] + } + result = add_prizes_to_challenge(yaml_file_data, self.challenge) + self.assertIsNone(result) + mock_save.assert_called_once() + self.assertTrue(self.challenge.has_prize) + prize.refresh_from_db() + self.assertEqual(prize.amount, '100') + + +class SendEmailsTests(unittest.TestCase): + @mockpatch('challenges.utils.send_email') + @mockpatch('challenges.utils.settings') + def test_send_emails_to_multiple_recipients(self, mock_settings, mock_send_email): + mock_settings.CLOUDCV_TEAM_EMAIL = "team@cloudcv.org" + emails = ["user1@example.com", "user2@example.com"] + template_id = "template-id" + template_data = {"key": "value"} + + send_emails(emails, template_id, template_data) + + # Check if send_email was called for each email + self.assertEqual(mock_send_email.call_count, len(emails)) + + # Check that send_email was called with correct arguments for the first email + mock_send_email.assert_any_call( + sender="team@cloudcv.org", + recipient="user1@example.com", + template_id=template_id, + template_data=template_data, + ) + + # Check that send_email was called with correct arguments for the second email + mock_send_email.assert_any_call( + sender="team@cloudcv.org", + recipient="user2@example.com", + template_id=template_id, + template_data=template_data, + )
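
A note on the patching convention used in the new tests above: send_email, settings, and the serializers are always patched at the path where the code under test looks them up (for example challenges.utils.send_email or challenges.challenge_notification_util.send_email), not where they are originally defined. The short, self-contained sketch below illustrates why that target matters; the notify helper and module layout are hypothetical and not part of the EvalAI codebase.

from unittest import mock


def send_email(sender, recipient):
    # Stand-in for a real email backend; must never run during tests.
    raise RuntimeError("real email backend should not be called in tests")


def notify(recipient):
    # Looks up send_email in this module's namespace, so the patch target
    # must point at this module, mirroring how the tests above patch
    # "challenges.utils.send_email" rather than the backend's own module.
    send_email(sender="team@cloudcv.org", recipient=recipient)


if __name__ == "__main__":
    with mock.patch(f"{__name__}.send_email") as mock_send_email:
        notify("user1@example.com")
    # The mock keeps its call record after the context exits.
    mock_send_email.assert_called_once_with(
        sender="team@cloudcv.org", recipient="user1@example.com"
    )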