Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 15 additions & 34 deletions gcp/workers/alias/alias_computation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest

from google.cloud import ndb
from google.protobuf import json_format, timestamp_pb2
from google.protobuf import timestamp_pb2

import osv
import alias_computation
Expand All @@ -38,13 +38,7 @@ def _get_aliases_from_bucket(self, vuln_id):
pb_blob.download_as_bytes())
pb_aliases = list(pb.aliases)

json_blob = bucket.blob(
os.path.join(osv.gcs.VULN_JSON_PATH, vuln_id + '.json'))
pb = json_format.Parse(json_blob.download_as_bytes(),
osv.vulnerability_pb2.Vulnerability())
json_aliases = list(pb.aliases)

return pb_aliases, json_aliases
return pb_aliases

def test_basic(self):
"""Tests basic case."""
Expand Down Expand Up @@ -74,9 +68,8 @@ def test_basic(self):
osv.AliasGroup.bug_ids == 'aaa-123').get().bug_ids
self.assertEqual(['aaa-123', 'aaa-124'], bug_ids)

pb_aliases, json_aliases = self._get_aliases_from_bucket('aaa-123')
pb_aliases = self._get_aliases_from_bucket('aaa-123')
self.assertEqual(['aaa-124'], pb_aliases)
self.assertEqual(['aaa-124'], json_aliases)

def test_bug_reaches_limit(self):
"""Tests bug reaches limit."""
Expand All @@ -96,9 +89,8 @@ def test_bug_reaches_limit(self):
osv.AliasGroup.bug_ids == 'aaa-111').get()
self.assertIsNone(alias_group)

pb_aliases, json_aliases = self._get_aliases_from_bucket('aaa-111')
pb_aliases = self._get_aliases_from_bucket('aaa-111')
self.assertEqual([], pb_aliases)
self.assertEqual([], json_aliases)

def test_update_alias_group(self):
"""Tests updating an existing alias group."""
Expand Down Expand Up @@ -142,9 +134,8 @@ def test_update_alias_group(self):
datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC),
alias_group.last_modified)

pb_aliases, json_aliases = self._get_aliases_from_bucket('bbb-123')
pb_aliases = self._get_aliases_from_bucket('bbb-123')
self.assertEqual(expected_aliases, pb_aliases)
self.assertEqual(expected_aliases, json_aliases)

def test_create_alias_group(self):
"""Tests adding a new alias group."""
Expand Down Expand Up @@ -172,9 +163,8 @@ def test_create_alias_group(self):
self.assertIsNotNone(alias_group)
self.assertEqual(['test-123', 'test-124', 'test-222'], alias_group.bug_ids)

pb_aliases, json_aliases = self._get_aliases_from_bucket('test-123')
pb_aliases = self._get_aliases_from_bucket('test-123')
self.assertEqual(['test-124', 'test-222'], pb_aliases)
self.assertEqual(['test-124', 'test-222'], json_aliases)

def test_delete_alias_group(self):
"""Tests deleting alias groups that only has one vulnerability."""
Expand All @@ -195,9 +185,8 @@ def test_delete_alias_group(self):
osv.AliasGroup.bug_ids == 'ccc-123').get()
self.assertIsNone(alias_group)

pb_aliases, json_aliases = self._get_aliases_from_bucket('ccc-123')
pb_aliases = self._get_aliases_from_bucket('ccc-123')
self.assertEqual([], pb_aliases)
self.assertEqual([], json_aliases)

def test_split_alias_group(self):
"""Tests split an existing alias group into two.
Expand Down Expand Up @@ -251,12 +240,10 @@ def test_split_alias_group(self):
self.assertIsNotNone(alias_group)
self.assertEqual(['ddd-125', 'ddd-126'], alias_group.bug_ids)

pb_aliases, json_aliases = self._get_aliases_from_bucket('ddd-123')
pb_aliases = self._get_aliases_from_bucket('ddd-123')
self.assertEqual(['ddd-124'], pb_aliases)
self.assertEqual(['ddd-124'], json_aliases)
pb_aliases, json_aliases = self._get_aliases_from_bucket('ddd-125')
pb_aliases = self._get_aliases_from_bucket('ddd-125')
self.assertEqual(['ddd-126'], pb_aliases)
self.assertEqual(['ddd-126'], json_aliases)

def test_allow_list(self):
"""Test allow list."""
Expand All @@ -278,9 +265,8 @@ def test_allow_list(self):
osv.AliasGroup.bug_ids == 'eee-111').get()
self.assertEqual(7, len(alias_group.bug_ids))

pb_aliases, json_aliases = self._get_aliases_from_bucket('eee-111')
pb_aliases = self._get_aliases_from_bucket('eee-111')
self.assertEqual(raw_aliases, pb_aliases)
self.assertEqual(raw_aliases, json_aliases)

def test_deny_list(self):
"""Tests deny list."""
Expand Down Expand Up @@ -312,9 +298,8 @@ def test_deny_list(self):
osv.AliasGroup.bug_ids == 'fff-124').get().bug_ids
self.assertEqual(['fff-124', 'fff-125'], bug_ids)

pb_aliases, json_aliases = self._get_aliases_from_bucket('fff-124')
pb_aliases = self._get_aliases_from_bucket('fff-124')
self.assertEqual(['fff-125'], pb_aliases)
self.assertEqual(['fff-125'], json_aliases)

def test_merge_alias_group(self):
"""Tests all bugs of one alias group have been
Expand Down Expand Up @@ -352,9 +337,8 @@ def test_merge_alias_group(self):
self.assertEqual(['ggg-123', 'ggg-124', 'ggg-125', 'ggg-126'],
alias_group[0].bug_ids)

pb_aliases, json_aliases = self._get_aliases_from_bucket('ggg-125')
pb_aliases = self._get_aliases_from_bucket('ggg-125')
self.assertEqual(['ggg-123', 'ggg-124', 'ggg-126'], pb_aliases)
self.assertEqual(['ggg-123', 'ggg-124', 'ggg-126'], json_aliases)

def test_partial_merge_alias_group(self):
"""Tests merging some bugs of one alias group to another alias group."""
Expand Down Expand Up @@ -412,12 +396,10 @@ def test_partial_merge_alias_group(self):
self.assertEqual(1, len(alias_group))
self.assertEqual(['hhh-126', 'hhh-127'], alias_group[0].bug_ids)

pb_aliases, json_aliases = self._get_aliases_from_bucket('hhh-125')
pb_aliases = self._get_aliases_from_bucket('hhh-125')
self.assertEqual(['hhh-123', 'hhh-124'], pb_aliases)
self.assertEqual(['hhh-123', 'hhh-124'], json_aliases)
pb_aliases, json_aliases = self._get_aliases_from_bucket('hhh-127')
pb_aliases = self._get_aliases_from_bucket('hhh-127')
self.assertEqual(['hhh-126'], pb_aliases)
self.assertEqual(['hhh-126'], json_aliases)

def test_alias_group_reaches_limit(self):
"""Tests a alias group reaches limit."""
Expand All @@ -439,9 +421,8 @@ def test_alias_group_reaches_limit(self):
alias_group = osv.AliasGroup.query(osv.AliasGroup.bug_ids == 'iii-1').get()
self.assertIsNone(alias_group)

pb_aliases, json_aliases = self._get_aliases_from_bucket('iii-1')
pb_aliases = self._get_aliases_from_bucket('iii-1')
self.assertEqual([], pb_aliases)
self.assertEqual([], json_aliases)

def test_to_vulnerability(self):
"""Tests OSV bug to vulnerability function."""
Expand Down
15 changes: 3 additions & 12 deletions gcp/workers/alias/upstream_computation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import unittest

from google.cloud import ndb
from google.protobuf import json_format

import osv
import upstream_computation
Expand Down Expand Up @@ -397,13 +396,7 @@ def _get_upstreams_from_bucket(self, vuln_id):
pb_blob.download_as_bytes())
pb_upstream = list(pb.upstream)

json_blob = bucket.blob(
os.path.join(osv.gcs.VULN_JSON_PATH, vuln_id + '.json'))
pb = json_format.Parse(json_blob.download_as_bytes(),
osv.vulnerability_pb2.Vulnerability())
json_upstream = list(pb.upstream)

return pb_upstream, json_upstream
return pb_upstream

def test_upstream_group_basic(self):
"""Test the upstream group get by db_id"""
Expand All @@ -417,9 +410,8 @@ def test_upstream_group_basic(self):
osv.UpstreamGroup.db_id == 'CVE-3').get().upstream_ids
self.assertEqual(['CVE-1', 'CVE-2'], bug_ids)

pb_upstream, json_upstream = self._get_upstreams_from_bucket('CVE-3')
pb_upstream = self._get_upstreams_from_bucket('CVE-3')
self.assertEqual(['CVE-1', 'CVE-2'], pb_upstream)
self.assertEqual(['CVE-1', 'CVE-2'], json_upstream)

def test_upstream_group_complex(self):
"""Testing more complex, realworld case"""
Expand All @@ -436,9 +428,8 @@ def test_upstream_group_complex(self):

self.assertEqual(upstream_ids, bug_ids)

pb_upstream, json_upstream = self._get_upstreams_from_bucket('USN-7234-3')
pb_upstream = self._get_upstreams_from_bucket('USN-7234-3')
self.assertEqual(upstream_ids, pb_upstream)
self.assertEqual(upstream_ids, json_upstream)

def test_upstream_hierarchy_computation(self):
upstream_computation.main()
Expand Down
18 changes: 3 additions & 15 deletions osv/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
import os

from google.cloud import storage
from google.protobuf import json_format

from .vulnerability_pb2 import Vulnerability

VULN_JSON_PATH = 'all/json/'
VULN_PB_PATH = 'all/pb/'

_storage_client = None
Expand Down Expand Up @@ -63,9 +61,9 @@ def get_by_id_with_generation(vuln_id: str) -> tuple[Vulnerability, int] | None:


def upload_vulnerability(vulnerability: Vulnerability,
pb_generation: int | None = None):
generation: int | None = None):
"""Uploads the OSV record to the GCS bucket.
If set, checks if the existing proto blob's generation matches pb_generation
If set, checks if the existing blob's generation matches `generation`
before uploading."""
bucket = get_osv_bucket()
vuln_id = vulnerability.id
Expand All @@ -76,17 +74,7 @@ def upload_vulnerability(vulnerability: Vulnerability,
pb_blob.upload_from_string(
vulnerability.SerializeToString(deterministic=True),
content_type='application/octet-stream',
if_generation_match=pb_generation)
except Exception:
logging.exception('failed to upload %s protobuf to GCS', vuln_id)
# TODO(michaelkedar): send pub/sub message to retry

try:
json_blob = bucket.blob(os.path.join(VULN_JSON_PATH, vuln_id + '.json'))
json_blob.custom_time = modified
json_data = json_format.MessageToJson(
vulnerability, preserving_proto_field_name=True, indent=None)
json_blob.upload_from_string(json_data, content_type='application/json')
if_generation_match=generation)
except Exception:
logging.exception('failed to upload %s protobuf to GCS', vuln_id)
# TODO(michaelkedar): send pub/sub message to retry
1 change: 0 additions & 1 deletion osv/gcs_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def gcs_mock(directory: str | None = None):
"""
with (tempfile.TemporaryDirectory()
if directory is None else contextlib.nullcontext(directory)) as db_dir:
pathlib.Path(db_dir, gcs.VULN_JSON_PATH).mkdir(parents=True, exist_ok=True)
pathlib.Path(db_dir, gcs.VULN_PB_PATH).mkdir(parents=True, exist_ok=True)
bucket = _MockBucket(db_dir)
with mock.patch('osv.gcs.get_osv_bucket', return_value=bucket):
Expand Down
19 changes: 0 additions & 19 deletions osv/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,13 @@
"""Models tests."""

import datetime
import json
import os
import unittest

from . import models

from . import bug
from . import gcs
from . import sources
from . import tests
from . import vulnerability_pb2

Expand Down Expand Up @@ -267,13 +265,6 @@ def test_bug_post_put(self):
blob.download_as_bytes())
self.assertEqual(got_pb, vuln_pb)

blob = bucket.get_blob(os.path.join(gcs.VULN_JSON_PATH, f'{vuln_id}.json'))
self.assertIsNotNone(blob)
self.assertEqual(blob.custom_time,
datetime.datetime(2025, 3, 5, tzinfo=datetime.UTC))
got_json = json.loads(blob.download_as_bytes())
self.assertDictEqual(got_json, sources.vulnerability_to_dict(vuln_pb))

def test_bug_withdraw(self):
"""Test if withdrawing a Bug correctly removes unneeded indices."""
# First put the bug un-withdrawn
Expand Down Expand Up @@ -319,10 +310,6 @@ def test_bug_withdraw(self):
bucket = gcs.get_osv_bucket()
blob = bucket.get_blob(os.path.join(gcs.VULN_PB_PATH, f'{vuln_id}.pb'))
self.assertIsNotNone(blob)
self.assertEqual(blob.custom_time,
datetime.datetime(2025, 3, 3, tzinfo=datetime.UTC))
blob = bucket.get_blob(os.path.join(gcs.VULN_JSON_PATH, f'{vuln_id}.json'))
self.assertIsNotNone(blob)
self.assertEqual(blob.custom_time,
datetime.datetime(2025, 3, 3, tzinfo=datetime.UTC))

Expand All @@ -348,10 +335,6 @@ def test_bug_withdraw(self):
self.assertIsNotNone(blob)
self.assertEqual(blob.custom_time,
datetime.datetime(2025, 4, 4, tzinfo=datetime.UTC))
blob = bucket.get_blob(os.path.join(gcs.VULN_JSON_PATH, f'{vuln_id}.json'))
self.assertIsNotNone(blob)
self.assertEqual(blob.custom_time,
datetime.datetime(2025, 4, 4, tzinfo=datetime.UTC))

def test_oss_fuzz_private(self):
"""Test that non-public Bugs from OSS-Fuzz are not indexed."""
Expand Down Expand Up @@ -391,8 +374,6 @@ def test_oss_fuzz_private(self):
bucket = gcs.get_osv_bucket()
blob = bucket.get_blob(os.path.join(gcs.VULN_PB_PATH, f'{vuln_id}.pb'))
self.assertIsNone(blob)
blob = bucket.get_blob(os.path.join(gcs.VULN_JSON_PATH, f'{vuln_id}.json'))
self.assertIsNone(blob)


def setUpModule():
Expand Down
Loading