From 567f1ec633985caedd420bd669d80a4a29ec7774 Mon Sep 17 00:00:00 2001 From: Keshav Priyadarshi Date: Mon, 26 Aug 2024 21:24:51 +0530 Subject: [PATCH] Fix failing test Signed-off-by: Keshav Priyadarshi --- vulnerabilities/pipelines/__init__.py | 7 ++++--- vulnerabilities/tests/__init__.py | 19 ++++++++++--------- .../tests/pipelines/test_base_pipeline.py | 4 ++-- .../{test_importer.py => test_advisory.py} | 6 ++++-- 4 files changed, 20 insertions(+), 16 deletions(-) rename vulnerabilities/tests/pipes/{test_importer.py => test_advisory.py} (84%) diff --git a/vulnerabilities/pipelines/__init__.py b/vulnerabilities/pipelines/__init__.py index 3dd1b8e73..50ce05432 100644 --- a/vulnerabilities/pipelines/__init__.py +++ b/vulnerabilities/pipelines/__init__.py @@ -18,7 +18,8 @@ from vulnerabilities.importer import AdvisoryData from vulnerabilities.improver import MAX_CONFIDENCE from vulnerabilities.models import Advisory -from vulnerabilities.pipes import advisory +from vulnerabilities.pipes.advisory import import_advisory +from vulnerabilities.pipes.advisory import insert_advisory from vulnerabilities.utils import classproperty module_logger = logging.getLogger(__name__) @@ -85,7 +86,7 @@ def collect_and_store_advisories(self): collected_advisory_count = 0 progress = LoopProgress(total_iterations=self.advisories_count(), logger=self.log) for advisory in progress.iter(self.collect_advisories()): - if _obj := advisory.insert_advisory( + if _obj := insert_advisory( advisory=advisory, pipeline_name=self.qualified_name, logger=self.log, @@ -115,7 +116,7 @@ def import_new_advisories(self): def import_advisory(self, advisory: Advisory) -> int: try: - advisory.import_advisory( + import_advisory( advisory=advisory, pipeline_name=self.qualified_name, confidence=self.advisory_confidence, diff --git a/vulnerabilities/tests/__init__.py b/vulnerabilities/tests/__init__.py index ee106cc74..2e6da3cea 100644 --- a/vulnerabilities/tests/__init__.py +++ b/vulnerabilities/tests/__init__.py @@ -31,15 +31,16 @@ ) -advisory1 = models.Advisory( - aliases=advisory_data1.aliases, - summary=advisory_data1.summary, - affected_packages=[pkg.to_dict() for pkg in advisory_data1.affected_packages], - references=[ref.to_dict() for ref in advisory_data1.references], - url=advisory_data1.url, - created_by="tests", - date_collected=timezone.now(), -) +def get_advisory1(created_by="test_pipeline"): + return models.Advisory.objects.create( + aliases=advisory_data1.aliases, + summary=advisory_data1.summary, + affected_packages=[pkg.to_dict() for pkg in advisory_data1.affected_packages], + references=[ref.to_dict() for ref in advisory_data1.references], + url=advisory_data1.url, + created_by=created_by, + date_collected=timezone.now(), + ) def get_all_vulnerability_relationships_objects(): diff --git a/vulnerabilities/tests/pipelines/test_base_pipeline.py b/vulnerabilities/tests/pipelines/test_base_pipeline.py index bda0479c0..3d747b421 100644 --- a/vulnerabilities/tests/pipelines/test_base_pipeline.py +++ b/vulnerabilities/tests/pipelines/test_base_pipeline.py @@ -13,8 +13,8 @@ from vulnerabilities import models from vulnerabilities.pipelines import VulnerableCodeBaseImporterPipeline -from vulnerabilities.tests import advisory1 from vulnerabilities.tests import advisory_data1 +from vulnerabilities.tests import get_advisory1 class TestVulnerableCodeBaseImporterPipeline(TestCase): @@ -50,7 +50,7 @@ def test_import_new_advisories(self): self.assertEqual(0, models.Vulnerability.objects.count()) base_pipeline = VulnerableCodeBaseImporterPipeline() - base_pipeline.new_advisories = [advisory1] + advisory1 = get_advisory1(created_by=base_pipeline.qualified_name) base_pipeline.import_new_advisories() self.assertEqual(1, models.Vulnerability.objects.count()) diff --git a/vulnerabilities/tests/pipes/test_importer.py b/vulnerabilities/tests/pipes/test_advisory.py similarity index 84% rename from vulnerabilities/tests/pipes/test_importer.py rename to vulnerabilities/tests/pipes/test_advisory.py index 4163009a7..8377a0b81 100644 --- a/vulnerabilities/tests/pipes/test_importer.py +++ b/vulnerabilities/tests/pipes/test_advisory.py @@ -9,13 +9,14 @@ import pytest -from vulnerabilities.pipes.importer import import_advisory -from vulnerabilities.tests import advisory1 +from vulnerabilities.pipes.advisory import import_advisory +from vulnerabilities.tests import get_advisory1 from vulnerabilities.tests import get_all_vulnerability_relationships_objects @pytest.mark.django_db def test_vulnerability_pipes_importer_import_advisory(): + advisory1 = get_advisory1(created_by="test_importer_pipeline") import_advisory(advisory=advisory1, pipeline_name="test_importer_pipeline") all_vulnerability_relation_objects = get_all_vulnerability_relationships_objects() import_advisory(advisory=advisory1, pipeline_name="test_importer_pipeline") @@ -24,6 +25,7 @@ def test_vulnerability_pipes_importer_import_advisory(): @pytest.mark.django_db def test_vulnerability_pipes_importer_import_advisory_different_pipelines(): + advisory1 = get_advisory1(created_by="test_importer_pipeline") import_advisory(advisory=advisory1, pipeline_name="test_importer1_pipeline") all_vulnerability_relation_objects = get_all_vulnerability_relationships_objects() import_advisory(advisory=advisory1, pipeline_name="test_importer2_pipeline")