diff --git a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/alias-computation.yaml b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/alias-computation.yaml index ae44e4b0d32..ff9eee4f106 100644 --- a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/alias-computation.yaml +++ b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/alias-computation.yaml @@ -12,3 +12,5 @@ spec: env: - name: GOOGLE_CLOUD_PROJECT value: oss-vdb-test + - name: OSV_VULNERABILITIES_BUCKET + value: osv-test-vulnerabilities diff --git a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/importer-deleter.yaml b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/importer-deleter.yaml index 3e9b2cb8fa3..6f54e61a922 100644 --- a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/importer-deleter.yaml +++ b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/importer-deleter.yaml @@ -12,6 +12,8 @@ spec: env: - name: GOOGLE_CLOUD_PROJECT value: oss-vdb-test + - name: OSV_VULNERABILITIES_BUCKET + value: osv-test-vulnerabilities image: importer args: - --delete diff --git a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/importer.yaml b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/importer.yaml index 62e30052536..334f5b9e41a 100644 --- a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/importer.yaml +++ b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/importer.yaml @@ -12,6 +12,8 @@ spec: env: - name: GOOGLE_CLOUD_PROJECT value: oss-vdb-test + - name: OSV_VULNERABILITIES_BUCKET + value: osv-test-vulnerabilities args: # TODO(michaelkedar): ssh secrets # TODO(michaelkedar): single source of truth w/ terraform config diff --git a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/workers.yaml b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/workers.yaml index 8419c2f2f05..62d39981997 100644 --- a/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/workers.yaml +++ b/deployment/clouddeploy/gke-workers/environments/oss-vdb-test/workers.yaml @@ -10,6 +10,8 @@ spec: env: - name: GOOGLE_CLOUD_PROJECT value: oss-vdb-test + - name: OSV_VULNERABILITIES_BUCKET + value: osv-test-vulnerabilities args: # TODO(michaelkedar): ssh secrets # TODO(michaelkedar): Somehow grab or enforce redis endpoint from terraform diff --git a/gcp/website/frontend_emulator.py b/gcp/website/frontend_emulator.py index 9a913be08d1..bd0ef8aeec2 100644 --- a/gcp/website/frontend_emulator.py +++ b/gcp/website/frontend_emulator.py @@ -34,7 +34,10 @@ def setUp(): import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), timestamp=datetime.datetime(2023, 8, 14, tzinfo=datetime.UTC), ).put() - osv.AliasGroup(bug_ids=['ALIAS-CVE-1', 'CVE-1', 'ALIAS'],).put() + osv.AliasGroup( + bug_ids=['ALIAS-CVE-1', 'CVE-1', 'ALIAS'], + last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + ).put() osv.Bug( id='CVE-1', diff --git a/gcp/workers/alias/alias_computation.py b/gcp/workers/alias/alias_computation.py index 805c2220ba9..34653c9da0b 100755 --- a/gcp/workers/alias/alias_computation.py +++ b/gcp/workers/alias/alias_computation.py @@ -19,20 +19,26 @@ from google.cloud import ndb import osv +from osv import gcs import osv.logs ALIAS_GROUP_VULN_LIMIT = 32 VULN_ALIASES_LIMIT = 5 -def _update_group(bug_ids, alias_group): +def _update_group(bug_ids: list[str], alias_group: osv.AliasGroup, + changed_vulns: dict[str, osv.AliasGroup | None]): """Updates the alias group in the datastore.""" 
if len(bug_ids) <= 1: logging.info('Deleting alias group due to too few bugs: %s', bug_ids) + for vuln_id in bug_ids: + changed_vulns[vuln_id] = None alias_group.key.delete() return if len(bug_ids) > ALIAS_GROUP_VULN_LIMIT: logging.info('Deleting alias group due to too many bugs: %s', bug_ids) + for vuln_id in bug_ids: + changed_vulns[vuln_id] = None alias_group.key.delete() return @@ -42,9 +48,12 @@ def _update_group(bug_ids, alias_group): alias_group.bug_ids = bug_ids alias_group.last_modified = datetime.datetime.now(datetime.UTC) alias_group.put() + for vuln_id in bug_ids: + changed_vulns[vuln_id] = alias_group -def _create_alias_group(bug_ids): +def _create_alias_group(bug_ids: list[str], + changed_vulns: dict[str, osv.AliasGroup | None]): """Creates a new alias group in the datastore.""" if len(bug_ids) <= 1: logging.info('Skipping alias group creation due to too few bugs: %s', @@ -58,9 +67,12 @@ def _create_alias_group(bug_ids): new_group = osv.AliasGroup(bug_ids=bug_ids) new_group.last_modified = datetime.datetime.now(datetime.UTC) new_group.put() + for vuln_id in bug_ids: + changed_vulns[vuln_id] = new_group -def _compute_aliases(bug_id, visited, bug_aliases): +def _compute_aliases(bug_id: str, visited: set[str], + bug_aliases: dict[str, set[str]]) -> list[str]: """Computes all aliases for the given bug ID. The returned list contains the bug ID itself, all the IDs from the bug's raw aliases, all the IDs of bugs that have the current bug as an alias, @@ -82,6 +94,49 @@ def _compute_aliases(bug_id, visited, bug_aliases): return sorted(bug_ids) +def _update_vuln_with_group(vuln_id: str, alias_group: osv.AliasGroup | None): + """Updates the Vulnerability in Datastore & GCS with the new alias group. + If `alias_group` is None, assumes a preexisting AliasGroup was just deleted. + """ + # TODO(michaelkedar): Currently, only want to run this on the test instance + # (or when running tests). Remove this check when we're ready for prod. + project = getattr(ndb.get_context().client, 'project') + if not project: + logging.error('failed to get GCP project from ndb.Client') + if project not in ('oss-vdb-test', 'test-osv'): + return + # Get the existing vulnerability first, so we can recalculate search_indices + result = gcs.get_by_id_with_generation(vuln_id) + if result is None: + if osv.Vulnerability.get_by_id(vuln_id) is not None: + logging.error('vulnerability not in GCS - %s', vuln_id) + # TODO(michaelkedar): send pub/sub message to reimport + return + vuln_proto, generation = result + + def transaction(): + vuln: osv.Vulnerability = osv.Vulnerability.get_by_id(vuln_id) + if vuln is None: + logging.error('vulnerability not in Datastore - %s', vuln_id) + # TODO: Raise exception + return + if alias_group is None: + modified = datetime.datetime.now(datetime.UTC) + aliases = [] + else: + modified = alias_group.last_modified + aliases = alias_group.bug_ids + aliases = sorted(set(aliases) - {vuln_id}) + vuln_proto.aliases[:] = aliases + vuln_proto.modified.FromDatetime(modified) + osv.ListedVulnerability.from_vulnerability(vuln_proto).put() + vuln.modified = modified + vuln.put() + + ndb.transaction(transaction) + gcs.upload_vulnerability(vuln_proto, generation) + + def main(): """Updates all alias groups in the datastore by re-computing existing AliasGroups and creating new AliasGroups for un-computed bugs.""" @@ -118,6 +173,10 @@ def main(): visited = set() + # Keep track of vulnerabilities that have been modified, to update GCS later. + # `None` means the AliasGroup has been removed. 
+ changed_vulns: dict[str, osv.AliasGroup | None] = {} + # For each alias group, re-compute the bug IDs in the group and update the # group with the computed bug IDs. for alias_group in all_alias_group: @@ -125,16 +184,23 @@ def main(): # If the bug has already been counted in a different alias group, # we delete the original one to merge two alias groups. if bug_id in visited: + for vuln_id in alias_group.bug_ids: + if vuln_id not in changed_vulns: + changed_vulns[vuln_id] = None alias_group.key.delete() continue bug_ids = _compute_aliases(bug_id, visited, bug_aliases) - _update_group(bug_ids, alias_group) + _update_group(bug_ids, alias_group, changed_vulns) # For each bug ID that has not been visited, create new alias groups. for bug_id in bug_aliases: if bug_id not in visited: bug_ids = _compute_aliases(bug_id, visited, bug_aliases) - _create_alias_group(bug_ids) + _create_alias_group(bug_ids, changed_vulns) + + # For each updated vulnerability, update them in Datastore & GCS + for vuln_id, alias_group in changed_vulns.items(): + _update_vuln_with_group(vuln_id, alias_group) if __name__ == '__main__': diff --git a/gcp/workers/alias/alias_computation_test.py b/gcp/workers/alias/alias_computation_test.py index c12177cd9b4..2654e2211e1 100644 --- a/gcp/workers/alias/alias_computation_test.py +++ b/gcp/workers/alias/alias_computation_test.py @@ -17,7 +17,7 @@ import unittest from google.cloud import ndb -from google.protobuf import timestamp_pb2 +from google.protobuf import json_format, timestamp_pb2 import osv import alias_computation @@ -30,8 +30,28 @@ class AliasTest(unittest.TestCase, tests.ExpectationTest(TEST_DATA_DIR)): """Alias tests.""" + def _get_aliases_from_bucket(self, vuln_id): + """Get the aliases from the vulnerabilities written to the GCS bucket.""" + bucket = osv.gcs.get_osv_bucket() + pb_blob = bucket.blob(os.path.join(osv.gcs.VULN_PB_PATH, vuln_id + '.pb')) + pb = osv.vulnerability_pb2.Vulnerability.FromString( + pb_blob.download_as_bytes()) + pb_aliases = list(pb.aliases) + + json_blob = bucket.blob( + os.path.join(osv.gcs.VULN_JSON_PATH, vuln_id + '.json')) + pb = json_format.Parse(json_blob.download_as_bytes(), + osv.vulnerability_pb2.Vulnerability()) + json_aliases = list(pb.aliases) + + return pb_aliases, json_aliases + def test_basic(self): """Tests basic case.""" + osv.AliasGroup( + bug_ids=['aaa-123', 'aaa-124'], + last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + ).put() osv.Bug( id='aaa-123', db_id='aaa-123', @@ -49,15 +69,15 @@ def test_basic(self): public=True, import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), ).put() - osv.AliasGroup( - bug_ids=['aaa-123', 'aaa-124'], - last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), - ).put() alias_computation.main() bug_ids = osv.AliasGroup.query( osv.AliasGroup.bug_ids == 'aaa-123').get().bug_ids self.assertEqual(['aaa-123', 'aaa-124'], bug_ids) + pb_aliases, json_aliases = self._get_aliases_from_bucket('aaa-123') + self.assertEqual(['aaa-124'], pb_aliases) + self.assertEqual(['aaa-124'], json_aliases) + def test_bug_reaches_limit(self): """Tests bug reaches limit.""" osv.Bug( @@ -76,6 +96,10 @@ def test_bug_reaches_limit(self): osv.AliasGroup.bug_ids == 'aaa-111').get() self.assertIsNone(alias_group) + pb_aliases, json_aliases = self._get_aliases_from_bucket('aaa-111') + self.assertEqual([], pb_aliases) + self.assertEqual([], json_aliases) + def test_update_alias_group(self): """Tests updating an existing alias group.""" osv.AliasGroup( @@ -112,12 +136,16 @@ def 
test_update_alias_group(self): alias_computation.main() alias_group = osv.AliasGroup.query( osv.AliasGroup.bug_ids == 'bbb-123').get() - self.assertEqual(['bbb-123', 'bbb-234', 'bbb-345', 'bbb-456', 'bbb-789'], - alias_group.bug_ids) + expected_aliases = ['bbb-234', 'bbb-345', 'bbb-456', 'bbb-789'] + self.assertEqual(['bbb-123'] + expected_aliases, alias_group.bug_ids) self.assertNotEqual( datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), alias_group.last_modified) + pb_aliases, json_aliases = self._get_aliases_from_bucket('bbb-123') + self.assertEqual(expected_aliases, pb_aliases) + self.assertEqual(expected_aliases, json_aliases) + def test_create_alias_group(self): """Tests adding a new alias group.""" osv.Bug( @@ -144,8 +172,16 @@ def test_create_alias_group(self): self.assertIsNotNone(alias_group) self.assertEqual(['test-123', 'test-124', 'test-222'], alias_group.bug_ids) + pb_aliases, json_aliases = self._get_aliases_from_bucket('test-123') + self.assertEqual(['test-124', 'test-222'], pb_aliases) + self.assertEqual(['test-124', 'test-222'], json_aliases) + def test_delete_alias_group(self): """Tests deleting alias groups that only has one vulnerability.""" + osv.AliasGroup( + bug_ids=['ccc-123'], + last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + ).put() osv.Bug( id='ccc-123', db_id='ccc-123', @@ -154,19 +190,23 @@ def test_delete_alias_group(self): public=True, import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), ).put() - osv.AliasGroup( - bug_ids=['ccc-123'], - last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), - ).put() alias_computation.main() alias_group = osv.AliasGroup.query( osv.AliasGroup.bug_ids == 'ccc-123').get() self.assertIsNone(alias_group) + pb_aliases, json_aliases = self._get_aliases_from_bucket('ccc-123') + self.assertEqual([], pb_aliases) + self.assertEqual([], json_aliases) + def test_split_alias_group(self): """Tests split an existing alias group into two. 
AliasGroup A -> B -> C -> D, remove the B -> C alias to get AliasGroups A -> B and C -> D.""" + osv.AliasGroup( + bug_ids=['ddd-123', 'ddd-124', 'ddd-125', 'ddd-126'], + last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + ).put() osv.Bug( id='ddd-123', db_id='ddd-123', @@ -201,10 +241,6 @@ def test_split_alias_group(self): public=True, import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), ).put() - osv.AliasGroup( - bug_ids=['ddd-123', 'ddd-124', 'ddd-125', 'ddd-126'], - last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), - ).put() alias_computation.main() alias_group = osv.AliasGroup.query( osv.AliasGroup.bug_ids == 'ddd-123').get() @@ -215,27 +251,44 @@ def test_split_alias_group(self): self.assertIsNotNone(alias_group) self.assertEqual(['ddd-125', 'ddd-126'], alias_group.bug_ids) + pb_aliases, json_aliases = self._get_aliases_from_bucket('ddd-123') + self.assertEqual(['ddd-124'], pb_aliases) + self.assertEqual(['ddd-124'], json_aliases) + pb_aliases, json_aliases = self._get_aliases_from_bucket('ddd-125') + self.assertEqual(['ddd-126'], pb_aliases) + self.assertEqual(['ddd-126'], json_aliases) + def test_allow_list(self): """Test allow list.""" + osv.AliasAllowListEntry(bug_id='eee-111',).put() + raw_aliases = [ + 'eee-222', 'eee-333', 'eee-444', 'eee-555', 'eee-666', 'eee-777' + ] osv.Bug( id='eee-111', db_id='eee-111', - aliases=[ - 'eee-222', 'eee-333', 'eee-444', 'eee-555', 'eee-666', 'eee-777' - ], + aliases=raw_aliases, status=1, source='test', public=True, import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), ).put() - osv.AliasAllowListEntry(bug_id='eee-111',).put() alias_computation.main() alias_group = osv.AliasGroup.query( osv.AliasGroup.bug_ids == 'eee-111').get() self.assertEqual(7, len(alias_group.bug_ids)) + pb_aliases, json_aliases = self._get_aliases_from_bucket('eee-111') + self.assertEqual(raw_aliases, pb_aliases) + self.assertEqual(raw_aliases, json_aliases) + def test_deny_list(self): """Tests deny list.""" + osv.AliasGroup( + bug_ids=['fff-124', 'fff-125'], + last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + ).put() + osv.AliasDenyListEntry(bug_id='fff-123',).put() osv.Bug( id='fff-123', db_id='fff-123', @@ -254,19 +307,26 @@ def test_deny_list(self): public=True, import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), ).put() - osv.AliasGroup( - bug_ids=['fff-124', 'fff-125'], - last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), - ).put() - osv.AliasDenyListEntry(bug_id='fff-123',).put() alias_computation.main() bug_ids = osv.AliasGroup.query( osv.AliasGroup.bug_ids == 'fff-124').get().bug_ids self.assertEqual(['fff-124', 'fff-125'], bug_ids) + pb_aliases, json_aliases = self._get_aliases_from_bucket('fff-124') + self.assertEqual(['fff-125'], pb_aliases) + self.assertEqual(['fff-125'], json_aliases) + def test_merge_alias_group(self): """Tests all bugs of one alias group have been merged to other alias group.""" + osv.AliasGroup( + bug_ids=['ggg-123', 'ggg-124'], + last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + ).put() + osv.AliasGroup( + bug_ids=['ggg-125', 'ggg-126'], + last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + ).put() osv.Bug( id='ggg-123', db_id='ggg-123', @@ -276,13 +336,14 @@ def test_merge_alias_group(self): public=True, import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), ).put() - osv.AliasGroup( - bug_ids=['ggg-123', 'ggg-124'], - last_modified=datetime.datetime(2023, 1, 
1, tzinfo=datetime.UTC), - ).put() - osv.AliasGroup( - bug_ids=['ggg-125', 'ggg-126'], - last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + osv.Bug( + id='ggg-125', + db_id='ggg-125', + aliases=[], + status=1, + source='test', + public=True, + import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), ).put() alias_computation.main() alias_group = osv.AliasGroup.query( @@ -291,8 +352,20 @@ def test_merge_alias_group(self): self.assertEqual(['ggg-123', 'ggg-124', 'ggg-125', 'ggg-126'], alias_group[0].bug_ids) + pb_aliases, json_aliases = self._get_aliases_from_bucket('ggg-125') + self.assertEqual(['ggg-123', 'ggg-124', 'ggg-126'], pb_aliases) + self.assertEqual(['ggg-123', 'ggg-124', 'ggg-126'], json_aliases) + def test_partial_merge_alias_group(self): """Tests merging some bugs of one alias group to another alias group.""" + osv.AliasGroup( + bug_ids=['hhh-123', 'hhh-124'], + last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + ).put() + osv.AliasGroup( + bug_ids=['hhh-125', 'hhh-126', 'hhh-127'], + last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + ).put() osv.Bug( id='hhh-123', db_id='hhh-123', @@ -311,13 +384,23 @@ def test_partial_merge_alias_group(self): public=True, import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), ).put() - osv.AliasGroup( - bug_ids=['hhh-123', 'hhh-124'], - last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + osv.Bug( + id='hhh-125', + db_id='hhh-125', + aliases=[], + status=1, + source='test', + public=True, + import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), ).put() - osv.AliasGroup( - bug_ids=['hhh-125', 'hhh-126', 'hhh-127'], - last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), + osv.Bug( + id='hhh-127', + db_id='hhh-127', + aliases=[], + status=1, + source='test', + public=True, + import_last_modified=datetime.datetime(2023, 1, 1, tzinfo=datetime.UTC), ).put() alias_computation.main() alias_group = osv.AliasGroup.query( @@ -329,6 +412,13 @@ def test_partial_merge_alias_group(self): self.assertEqual(1, len(alias_group)) self.assertEqual(['hhh-126', 'hhh-127'], alias_group[0].bug_ids) + pb_aliases, json_aliases = self._get_aliases_from_bucket('hhh-125') + self.assertEqual(['hhh-123', 'hhh-124'], pb_aliases) + self.assertEqual(['hhh-123', 'hhh-124'], json_aliases) + pb_aliases, json_aliases = self._get_aliases_from_bucket('hhh-127') + self.assertEqual(['hhh-126'], pb_aliases) + self.assertEqual(['hhh-126'], json_aliases) + def test_alias_group_reaches_limit(self): """Tests a alias group reaches limit.""" aliases = [] @@ -349,6 +439,10 @@ def test_alias_group_reaches_limit(self): alias_group = osv.AliasGroup.query(osv.AliasGroup.bug_ids == 'iii-1').get() self.assertIsNone(alias_group) + pb_aliases, json_aliases = self._get_aliases_from_bucket('iii-1') + self.assertEqual([], pb_aliases) + self.assertEqual([], json_aliases) + def test_to_vulnerability(self): """Tests OSV bug to vulnerability function.""" bug = osv.Bug( diff --git a/gcp/workers/alias/upstream_computation.py b/gcp/workers/alias/upstream_computation.py index b7709710449..60a58ad261e 100644 --- a/gcp/workers/alias/upstream_computation.py +++ b/gcp/workers/alias/upstream_computation.py @@ -14,14 +14,16 @@ # limitations under the License. 
"""OSV Upstream relation computation.""" +from collections import defaultdict import datetime +import json +import logging + from google.cloud import ndb import osv import osv.logs -import json -import logging -from collections import defaultdict +from osv import gcs def compute_upstream(target_bug, bugs: dict[str, set[str]]) -> list[str]: @@ -51,7 +53,7 @@ def compute_upstream(target_bug, bugs: dict[str, set[str]]) -> list[str]: return sorted(visited) -def _create_group(bug_id, upstream_ids) -> osv.UpstreamGroup: +def _create_group(bug_id: str, upstream_ids: list[str]) -> osv.UpstreamGroup: """Creates a new upstream group in the datastore.""" new_group = osv.UpstreamGroup( @@ -60,17 +62,19 @@ def _create_group(bug_id, upstream_ids) -> osv.UpstreamGroup: upstream_ids=upstream_ids, last_modified=datetime.datetime.now(datetime.UTC)) new_group.put() + _update_vuln_with_group(bug_id, new_group) return new_group def _update_group(upstream_group: osv.UpstreamGroup, - upstream_ids: list) -> osv.UpstreamGroup | None: + upstream_ids: list[str]) -> osv.UpstreamGroup | None: """Updates the upstream group in the datastore.""" if len(upstream_ids) == 0: logging.info('Deleting upstream group due to too few bugs: %s', upstream_ids) upstream_group.key.delete() + _update_vuln_with_group(upstream_group.db_id, None) return None if upstream_ids == upstream_group.upstream_ids: @@ -79,9 +83,51 @@ def _update_group(upstream_group: osv.UpstreamGroup, upstream_group.upstream_ids = upstream_ids upstream_group.last_modified = datetime.datetime.now(datetime.UTC) upstream_group.put() + _update_vuln_with_group(upstream_group.db_id, upstream_group) return upstream_group +def _update_vuln_with_group(vuln_id: str, upstream: osv.UpstreamGroup | None): + """Updates the Vulnerability in Datastore & GCS with the new upstream group. + If `upstream` is None, assumes a preexisting UpstreamGroup was just deleted. + """ + # TODO(michaelkedar): Currently, only want to run this on the test instance + # (or when running tests). Remove this check when we're ready for prod. 
+ project = getattr(ndb.get_context().client, 'project') + if not project: + logging.error('failed to get GCP project from ndb.Client') + if project not in ('oss-vdb-test', 'test-osv'): + return + # Get the existing vulnerability first, so we can recalculate search_indices + result = gcs.get_by_id_with_generation(vuln_id) + if result is None: + logging.error('vulnerability not in GCS - %s', vuln_id) + # TODO(michaelkedar): send pub/sub message to reimport + return + vuln_proto, generation = result + + def transaction(): + vuln: osv.Vulnerability = osv.Vulnerability.get_by_id(vuln_id) + if vuln is None: + logging.error('vulnerability not in Datastore - %s', vuln_id) + # TODO: Raise exception + return + if upstream is None: + modified = datetime.datetime.now(datetime.UTC) + upstream_group = [] + else: + modified = upstream.last_modified + upstream_group = upstream.upstream_ids + vuln_proto.upstream[:] = upstream_group + vuln_proto.modified.FromDatetime(modified) + osv.ListedVulnerability.from_vulnerability(vuln_proto).put() + vuln.modified = modified + vuln.put() + + ndb.transaction(transaction) + gcs.upload_vulnerability(vuln_proto, generation) + + def compute_upstream_hierarchy( target_upstream_group: osv.UpstreamGroup, all_upstream_groups: dict[str, osv.UpstreamGroup]) -> None: diff --git a/gcp/workers/alias/upstream_computation_test.py b/gcp/workers/alias/upstream_computation_test.py index 0c87d49cadc..c7e2de6eb66 100644 --- a/gcp/workers/alias/upstream_computation_test.py +++ b/gcp/workers/alias/upstream_computation_test.py @@ -13,12 +13,14 @@ # limitations under the License. """Upstream computation tests.""" import datetime +import json +import logging import os import unittest -import logging + from google.cloud import ndb +from google.protobuf import json_format -import json import osv import upstream_computation from osv import tests @@ -387,6 +389,22 @@ def test_incomplete_compute_upstream(self): bug_ids = upstream_computation.compute_upstream(bugs.get('VULN-4'), bugs) self.assertEqual(['VULN-1', 'VULN-3'], bug_ids) + def _get_upstreams_from_bucket(self, vuln_id): + """Get the upstreams from the vulnerabilities written to the GCS bucket.""" + bucket = osv.gcs.get_osv_bucket() + pb_blob = bucket.blob(os.path.join(osv.gcs.VULN_PB_PATH, vuln_id + '.pb')) + pb = osv.vulnerability_pb2.Vulnerability.FromString( + pb_blob.download_as_bytes()) + pb_upstream = list(pb.upstream) + + json_blob = bucket.blob( + os.path.join(osv.gcs.VULN_JSON_PATH, vuln_id + '.json')) + pb = json_format.Parse(json_blob.download_as_bytes(), + osv.vulnerability_pb2.Vulnerability()) + json_upstream = list(pb.upstream) + + return pb_upstream, json_upstream + def test_upstream_group_basic(self): """Test the upstream group get by db_id""" upstream_computation.main() @@ -399,6 +417,10 @@ def test_upstream_group_basic(self): osv.UpstreamGroup.db_id == 'CVE-3').get().upstream_ids self.assertEqual(['CVE-1', 'CVE-2'], bug_ids) + pb_upstream, json_upstream = self._get_upstreams_from_bucket('CVE-3') + self.assertEqual(['CVE-1', 'CVE-2'], pb_upstream) + self.assertEqual(['CVE-1', 'CVE-2'], json_upstream) + def test_upstream_group_complex(self): """Testing more complex, realworld case""" upstream_ids = [ @@ -414,6 +436,10 @@ def test_upstream_group_complex(self): self.assertEqual(upstream_ids, bug_ids) + pb_upstream, json_upstream = self._get_upstreams_from_bucket('USN-7234-3') + self.assertEqual(upstream_ids, pb_upstream) + self.assertEqual(upstream_ids, json_upstream) + def test_upstream_hierarchy_computation(self): 
     upstream_computation.main()
     bug_ids = osv.UpstreamGroup.query(
diff --git a/gcp/workers/worker/worker_test.py b/gcp/workers/worker/worker_test.py
index 5d968c8b04d..de9a5da23d5 100644
--- a/gcp/workers/worker/worker_test.py
+++ b/gcp/workers/worker/worker_test.py
@@ -702,8 +702,11 @@ def test_update_redhat_toobig(self):
       task_runner._source_update(message)

     self.assertIn(
-        "ERROR:root:Unexpected exception while writing RHSA-2018:3140 to Datastore",
+        'ERROR:root:Not writing new entities for RHSA-2018:3140 since Bug.put() failed',
         logs.output[0])
+    self.assertIn(
+        'ERROR:root:Unexpected exception while writing RHSA-2018:3140 to Datastore',
+        logs.output[1])

     self.mock_publish.assert_not_called()
diff --git a/osv/gcs.py b/osv/gcs.py
new file mode 100644
index 00000000000..a86fa4c7ff5
--- /dev/null
+++ b/osv/gcs.py
@@ -0,0 +1,92 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Helpers for interacting with the OSV vulnerabilities in the GCS bucket."""
+import datetime
+import logging
+import os
+
+from google.cloud import storage
+from google.protobuf import json_format
+
+from .vulnerability_pb2 import Vulnerability
+
+VULN_JSON_PATH = 'all/json/'
+VULN_PB_PATH = 'all/pb/'
+
+_storage_client = None
+
+
+def _get_storage_client() -> storage.Client:
+  global _storage_client
+  if _storage_client is None:
+    _storage_client = storage.Client()
+  return _storage_client
+
+
+def get_osv_bucket() -> storage.Bucket:
+  """Gets the OSV vulnerabilities bucket named by OSV_VULNERABILITIES_BUCKET."""
+  try:
+    bucket = os.environ['OSV_VULNERABILITIES_BUCKET']
+    bucket = bucket.removeprefix('gs://')
+  except KeyError:
+    logging.error('OSV_VULNERABILITIES_BUCKET environment variable not set')
+    raise
+  client = _get_storage_client()
+  return client.bucket(bucket)
+
+
+def get_by_id_with_generation(vuln_id: str) -> tuple[Vulnerability, int] | None:
+  """Gets the OSV record, and the object's generation, from the GCS bucket.
+  Returns None if the record is not found.
+  """
+  bucket = get_osv_bucket()
+  pb_blob = bucket.get_blob(os.path.join(VULN_PB_PATH, vuln_id + '.pb'))
+  if pb_blob is None:
+    return None
+  try:
+    vuln = Vulnerability.FromString(pb_blob.download_as_bytes())
+    return vuln, pb_blob.generation
+  except Exception:
+    logging.exception('failed to download %s protobuf from GCS', vuln_id)
+    raise
+
+
+def upload_vulnerability(vulnerability: Vulnerability,
+                         pb_generation: int | None = None):
+  """Uploads the OSV record to the GCS bucket.
+  If set, checks if the existing proto blob's generation matches pb_generation
+  before uploading."""
+  bucket = get_osv_bucket()
+  vuln_id = vulnerability.id
+  modified = vulnerability.modified.ToDatetime(datetime.UTC)
+  try:
+    pb_blob = bucket.blob(os.path.join(VULN_PB_PATH, vuln_id + '.pb'))
+    pb_blob.custom_time = modified
+    pb_blob.upload_from_string(
+        vulnerability.SerializeToString(deterministic=True),
+        content_type='application/octet-stream',
+        if_generation_match=pb_generation)
+  except Exception:
+    logging.exception('failed to upload %s protobuf to GCS', vuln_id)
+    # TODO(michaelkedar): send pub/sub message to retry
+
+  try:
+    json_blob = bucket.blob(os.path.join(VULN_JSON_PATH, vuln_id + '.json'))
+    json_blob.custom_time = modified
+    json_data = json_format.MessageToJson(
+        vulnerability, preserving_proto_field_name=True, indent=None)
+    json_blob.upload_from_string(json_data, content_type='application/json')
+  except Exception:
+    logging.exception('failed to upload %s JSON to GCS', vuln_id)
+    # TODO(michaelkedar): send pub/sub message to retry
diff --git a/osv/gcs_mock.py b/osv/gcs_mock.py
new file mode 100644
index 00000000000..68068511379
--- /dev/null
+++ b/osv/gcs_mock.py
@@ -0,0 +1,96 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Functions for mocking the GCS bucket for testing."""
+import contextlib
+import datetime
+import os
+import pathlib
+import tempfile
+from typing import Any
+from unittest import mock
+
+from google.cloud import exceptions
+
+from . import gcs
+
+
+@contextlib.contextmanager
+def gcs_mock(directory: str | None = None):
+  """A context for mocking reads/writes to the vulnerabilities GCS bucket.
+
+  If `directory` is set, blobs will be read from/written to files in the
+  directory, which will remain after the context exits.
+  Otherwise, blobs will be written to a temporary directory, which is deleted
+  when the context exits.
+  """
+  with (tempfile.TemporaryDirectory()
+        if directory is None else contextlib.nullcontext(directory)) as db_dir:
+    pathlib.Path(db_dir, gcs.VULN_JSON_PATH).mkdir(parents=True, exist_ok=True)
+    pathlib.Path(db_dir, gcs.VULN_PB_PATH).mkdir(parents=True, exist_ok=True)
+    bucket = _MockBucket(db_dir)
+    with mock.patch('osv.gcs.get_osv_bucket', return_value=bucket):
+      yield db_dir
+
+
+class _MockBlob:
+  """Mock google.cloud.storage.Blob with only necessary methods for tests."""
+
+  def __init__(self, path: str):
+    self._path = path
+    self.custom_time: datetime.datetime | None = None
+
+  def upload_from_string(self,
+                         data: str | bytes,
+                         content_type: str | None = None,
+                         if_generation_match: Any | None = None):
+    """Implements google.cloud.storage.Blob.upload_from_string."""
+    del content_type  # Can't do anything with this.
+ + if if_generation_match not in (None, 1): + raise exceptions.PreconditionFailed('Generation mismatch') + + if isinstance(data, str): + data = data.encode() + with open(self._path, 'wb') as f: + f.write(data) + + # Use the file's modified time to store the CustomTime metadata. + if self.custom_time is not None: + ts = self.custom_time.timestamp() + os.utime(self._path, (ts, ts)) + + def download_as_bytes(self) -> bytes: + """Implements google.cloud.storage.Blob.download_as_bytes.""" + with open(self._path, 'rb') as f: + return f.read() + + +class _MockBucket: + """Mock google.cloud.storage.Bucket with only necessary methods for tests.""" + + def __init__(self, db_dir: str): + self._db_dir = db_dir + + def blob(self, blob_name: str) -> _MockBlob: + return _MockBlob(os.path.join(self._db_dir, blob_name)) + + def get_blob(self, blob_name: str) -> _MockBlob | None: + path = os.path.join(self._db_dir, blob_name) + if not os.path.exists(path): + return None + blob = _MockBlob(path) + ts = os.path.getmtime(path) + blob.custom_time = datetime.datetime.fromtimestamp(ts, datetime.UTC) + blob.generation = 1 + return blob diff --git a/osv/models.py b/osv/models.py index 9449a9fd1cd..d8bb5827461 100644 --- a/osv/models.py +++ b/osv/models.py @@ -23,13 +23,13 @@ from typing import Self from google.cloud import ndb -from google.protobuf import json_format -from google.protobuf import timestamp_pb2 +from google.protobuf import json_format, timestamp_pb2 from osv import importfinding_pb2 # pylint: disable=relative-beyond-top-level from . import bug from . import ecosystems +from . import gcs from . import purl_helpers from . import semver_index from . import sources @@ -39,6 +39,13 @@ _MAX_GIT_VERSIONS_TO_INDEX = 5000 +_EVENT_ORDER = { + 'introduced': 0, + 'last_affected': 1, + 'fixed': 2, + 'limit': 3, +} + def _check_valid_severity(prop, value): """Check valid severity.""" @@ -108,26 +115,15 @@ def _maybe_strip_repo_prefixes(versions: list[str], return repo_stripped_versions +# --- OSS-Fuzz-related Entities --- + + class IDCounter(ndb.Model): """Counter for ID allocations.""" # Next ID to allocate. next_id: int = ndb.IntegerProperty() -class AffectedCommits(ndb.Model): - """AffectedCommits entry.""" - MAX_COMMITS_PER_ENTITY = 10000 - - # The main bug ID. - bug_id: str = ndb.StringProperty() - # The commit hash. - commits: list[bytes] = ndb.BlobProperty(repeated=True, indexed=True) - # Whether or not the bug is public. - public: bool = ndb.BooleanProperty() - # The page for this batch of commits. - page: int = ndb.IntegerProperty(indexed=False) - - class RegressResult(ndb.Model): """Regression results.""" # The commit hash. @@ -180,6 +176,9 @@ class FixResult(ndb.Model): timestamp: datetime.datetime = ndb.DateTimeProperty(tzinfo=datetime.UTC) +# --- OSV Bug entities --- + + class AffectedEvent(ndb.Model): """Affected event.""" type: str = ndb.StringProperty(validator=_check_valid_event_type) @@ -359,38 +358,11 @@ def get_by_id(cls, vuln_id, *args, **kwargs) -> Self | None: return super().get_by_id(vuln_id, *args, **kwargs) - def _tokenize(self, value): - """Tokenize value for indexing.""" - if not value: - return [] - - value_lower = value.lower() - - # Deconstructs the id given into parts by retrieving parts that are - # alphanumeric. 
- # This addresses special cases like SUSE that include ':' in their id suffix - tokens = {token for token in re.split(r'\W+', value_lower) if token} - tokens.add(value_lower) - - # Add subsection combinations from id (split at '-') in the search indices - # Specifically addresses situation in which UBUNTU-CVE-XXXs don't show up - # when searching for the CVE-XXX. - # e.g. `a-b-c-d' becomes ['a-b', 'b-c', 'c-d', 'a-b-c', 'b-c-d', 'a-b-c-d'] - # Does not account for combinations with the suffix sections ':' like SUSE - parts = value_lower.split('-') - num_parts = len(parts) - for length in range(2, num_parts + 1): - for i in range(num_parts - length + 1): - sub_parts = parts[i:i + length] - combo = '-'.join(sub_parts) - tokens.add(combo) - return tokens - def _pre_put_hook(self): # pylint: disable=arguments-differ """Pre-put hook for populating search indices.""" search_indices = set() - search_indices.update(self._tokenize(self.id())) + search_indices.update(_tokenize(self.id())) for pkg in self.affected_packages: # Set PURL if it wasn't provided. @@ -441,18 +413,18 @@ def _pre_put_hook(self): # pylint: disable=arguments-differ self.purl.sort() for project in self.project: - search_indices.update(self._tokenize(project)) + search_indices.update(_tokenize(project)) for ecosystem in self.ecosystem: - search_indices.update(self._tokenize(ecosystem)) + search_indices.update(_tokenize(ecosystem)) for alias in self.aliases: - search_indices.update(self._tokenize(alias)) + search_indices.update(_tokenize(alias)) # Please note this will not include exhaustive transitive upstream # so may not appear for all cases. for upstream in self.upstream_raw: - search_indices.update(self._tokenize(upstream)) + search_indices.update(_tokenize(upstream)) for affected_package in self.affected_packages: for affected_range in affected_package.ranges: @@ -673,6 +645,8 @@ def to_vulnerability_minimal_async(self, modified_times.append(self.last_modified) # Fetch the last_modified dates from the upstream/alias groups. + # TODO(michaelkedar): modified time needs to update if related changes. + # Probably requires a RelatedGroup entity/cron alias_future = get_aliases_async(self.id()) if include_alias else None upstream_future = ( get_upstream_async(self.id()) if include_upstream else None) @@ -769,8 +743,11 @@ def to_vulnerability(self, else: withdrawn = None - published = timestamp_pb2.Timestamp() - published.FromDatetime(self.timestamp) + if self.timestamp: + published = timestamp_pb2.Timestamp() + published.FromDatetime(self.timestamp) + else: + published = None references = [] if self.reference_url_types: @@ -801,6 +778,8 @@ def to_vulnerability(self, Bug.related == self.db_id, projection=[Bug.db_id]).fetch() related_bug_ids = [bug.db_id for bug in related_bugs] related = sorted(list(set(related_bug_ids + self.related))) + # TODO(michaelkedar): modified time needs to update if related changes. + # Probably requires a RelatedGroup entity/cron alias_group = AliasGroup.query(AliasGroup.bug_ids == self.db_id).get() if alias_group: @@ -875,6 +854,8 @@ def to_vulnerability_async(self, related_bug_ids = yield related_future vulnerability.related[:] = sorted( list(set(related_bug_ids + list(vulnerability.related)))) + # TODO(michaelkedar): modified time needs to update if related changes. 
+    # Probably requires a RelatedGroup entity/cron
     alias_group = yield alias_future
     if alias_group:
       alias_ids = sorted(list(set(alias_group.bug_ids) - {vulnerability.id}))
@@ -892,6 +873,380 @@ def to_vulnerability_async(self,
     vulnerability.modified.FromDatetime(modified_time)
     return vulnerability

+  def _post_put_hook(self: Self, future: ndb.Future):  # pylint: disable=arguments-differ
+    """Post-put hook for writing new entities for database migration."""
+    # TODO(michaelkedar): Currently, only want to run this on the test instance
+    # (or when running tests). Remove this check when we're ready for prod.
+    # To get the current GCP project without relying on environment variables
+    # that may not be set, grab the project name from the undocumented(?) field
+    # on the ndb.Client, which we find from the current context.
+    project = getattr(ndb.get_context().client, 'project')
+    if not project:
+      logging.error('failed to get GCP project from ndb.Client')
+    if project not in ('oss-vdb-test', 'test-osv'):
+      return
+    if future.exception():
+      logging.error("Not writing new entities for %s since Bug.put() failed",
+                    self.db_id)
+      return
+    populate_entities_from_bug(self)
+
+
+def _tokenize(value):
+  """Tokenize value for indexing."""
+  if not value:
+    return []
+
+  value_lower = value.lower()
+
+  # Deconstructs the id given into parts by retrieving parts that are
+  # alphanumeric.
+  # This addresses special cases like SUSE that include ':' in their id suffix
+  tokens = {token for token in re.split(r'\W+', value_lower) if token}
+  tokens.add(value_lower)
+
+  # Add subsection combinations from id (split at '-') in the search indices
+  # Specifically addresses situation in which UBUNTU-CVE-XXXs don't show up
+  # when searching for the CVE-XXX.
+  # e.g. `a-b-c-d' becomes ['a-b', 'b-c', 'c-d', 'a-b-c', 'b-c-d', 'a-b-c-d']
+  # Does not account for combinations with the suffix sections ':' like SUSE
+  parts = value_lower.split('-')
+  num_parts = len(parts)
+  for length in range(2, num_parts + 1):
+    for i in range(num_parts - length + 1):
+      sub_parts = parts[i:i + length]
+      combo = '-'.join(sub_parts)
+      tokens.add(combo)
+  return tokens
+
+
+# --- Vulnerability Entity ---
+
+
+class Vulnerability(ndb.Model):
+  """A Vulnerability entry.
+
+  Contains a minimal amount of information about an OSV record, including the
+  overall modified date, and some raw fields that are overwritten by our
+  enrichment.
+
+  The entity's key/ID is the OSV record ID.
+  """
+
+  # The source identifier.
+  # For OSS-Fuzz, this is oss-fuzz:<testcase ID>.
+  # For others this is <source name>:<path/to/source>.
+  source_id: str = ndb.StringProperty()
+  # When this record was truly last modified (including e.g. aliases/upstream).
+  modified: datetime.datetime = ndb.DateTimeProperty(tzinfo=datetime.UTC)
+  # Whether this record has been withdrawn.
+  # TODO(michaelkedar): I don't think this is necessary
+  is_withdrawn: bool = ndb.BooleanProperty()
+
+  # Raw fields from the original source.
+  # The reported modified date in the record.
+  modified_raw: datetime.datetime = ndb.DateTimeProperty(tzinfo=datetime.UTC)
+  # The reported aliased IDs.
+  alias_raw: list[str] = ndb.StringProperty(repeated=True)
+  # The reported related IDs.
+  related_raw: list[str] = ndb.StringProperty(repeated=True)
+  # The reported upstream IDs.
+  upstream_raw: list[str] = ndb.StringProperty(repeated=True)
+
+
+# --- Affected versions for matching ---
+
+
+class AffectedCommits(ndb.Model):
+  """AffectedCommits entry."""
+  MAX_COMMITS_PER_ENTITY = 10000
+
+  # The main bug ID.
+ bug_id: str = ndb.StringProperty() + # The commit hash. + commits: list[bytes] = ndb.BlobProperty(repeated=True, indexed=True) + # Whether or not the bug is public. + public: bool = ndb.BooleanProperty() + # The page for this batch of commits. + page: int = ndb.IntegerProperty(indexed=False) + + +class AffectedVersions(ndb.Model): + """AffectedVersions entry, used for finding matching vulnerabilities within + the OSV API.""" + # The main vulnerability ID. + vuln_id: str = ndb.StringProperty() + # The ecosystem of the affected package. + ecosystem: str = ndb.StringProperty() + # The name of the affected package. + name: str = ndb.StringProperty() + + # Only one of the following should be set: + # The enumerated affected versions. + versions: list[str] = ndb.TextProperty(repeated=True) + # The sorted affected events. + events: list[AffectedEvent] = ndb.LocalStructuredProperty( + AffectedEvent, repeated=True) + + def sort_key(self): + """Key function for comparison and deduplication.""" + return (self.vuln_id, self.ecosystem, self.name, tuple(self.versions), + tuple((e.type, e.value) for e in self.events)) + + +# --- Website search / list entity --- + + +class ListedVulnerability(ndb.Model): + """ListedVulnerability entry, used for the website's /list page.""" + # The entity's key/id is ID in OSV + + # The date the vulnerability was published (for sorting & display). + published: datetime.datetime = ndb.DateTimeProperty(tzinfo=datetime.UTC) + # The ecosystems the vulnerability belongs to (for filtering). + ecosystems: list[str] = ndb.StringProperty(repeated=True) + # The list of rendered affected packages (for display). + # e.g. 'PyPI/urllib3', 'github.com/torvalds/linux' + packages: list[str] = ndb.TextProperty(repeated=True) + # The summary line (for display). + summary: str = ndb.TextProperty() + # Whether there is a fix available (for display). + is_fixed: bool = ndb.BooleanProperty(indexed=False) + # The severities of the vulnerability (for display). + severities: list[Severity] = ndb.LocalStructuredProperty( + Severity, repeated=True) + + # Strings that the search bar may suggest while typing. + autocomplete_tags: list[str] = ndb.StringProperty(repeated=True) + # Strings this matches when searching. + search_indices: list[str] = ndb.StringProperty(repeated=True) + + @classmethod + def from_vulnerability( + cls: Self, vulnerability: vulnerability_pb2.Vulnerability) -> Self: + """Construct a ListedVulnerability from a complete vulnerability proto""" + published = vulnerability.published.ToDatetime(datetime.UTC) + summary = vulnerability.summary + # TODO(michaelkedar): Take the first line of details if summary is missing. 
+ + all_ecosystems = set() + all_packages = set() + + is_fixed = False + severities = set() + for sev in vulnerability.severity: + severities.add( + (vulnerability_pb2.Severity.Type.Name(sev.type), sev.score)) + + search_indices = set() + search_indices.update(_tokenize(vulnerability.id)) + autocomplete_tags = {vulnerability.id.lower()} + + for alias in vulnerability.aliases: + search_indices.update(_tokenize(alias)) + for upstream in vulnerability.upstream: + search_indices.update(_tokenize(upstream)) + # related intentionally omitted + + for affected in vulnerability.affected: + if affected.package.name: + search_indices.update(_tokenize(affected.package.name)) + autocomplete_tags.add(affected.package.name.lower()) + all_packages.add(affected.package.ecosystem + '/' + + affected.package.name) + if affected.package.ecosystem: + all_ecosystems.add(affected.package.ecosystem) + for sev in affected.severity: + severities.add( + (vulnerability_pb2.Severity.Type.Name(sev.type), sev.score)) + for r in affected.ranges: + if r.type == vulnerability_pb2.Range.Type.GIT: + all_ecosystems.add('GIT') + search_indices.add(r.repo) + autocomplete_tags.add(r.repo.lower()) + split = r.repo.split('//') + if len(split) >= 2: + no_http = split[1] + all_packages.add(no_http) + search_indices.add(no_http) + # Add the path components exluding the domain name + search_indices.update(no_http.split('/')[1:]) + else: + all_packages.add(r.repo) + + if any(e.fixed or e.limit for e in r.events): + is_fixed = True + + for eco in all_ecosystems: + # TODO(michaelkedar): Seems like a noisy/useless search index? + search_indices.update(_tokenize(eco)) + if (e := ecosystems.remove_variants(eco)) is not None: + search_indices.update(_tokenize(e)) + + ecos = sorted({ecosystems.normalize(e) for e in all_ecosystems}) + pkgs = sorted(all_packages) + sevs = [Severity(type=t, score=s) for t, s in sorted(severities)] + search_indices = sorted(search_indices) + autocomplete_tags = sorted(autocomplete_tags) + + return cls( + id=vulnerability.id, + published=published, + ecosystems=ecos, + packages=pkgs, + summary=summary, + is_fixed=is_fixed, + severities=sevs, + autocomplete_tags=autocomplete_tags, + search_indices=search_indices, + ) + + +def populate_entities_from_bug(entity: Bug): + """Puts entities (Vulnerability, ListedVulnerability, AffectedVersions) from + a given Bug entity, and writes completed OSV records to GCS bucket.""" + if not entity.public or entity.status == bug.BugStatus.UNPROCESSED: + # OSS-Fuzz private Bugs + return + + vuln_pb = entity.to_vulnerability( + include_source=True, include_alias=True, include_upstream=True) + + def transaction(): + to_put = [] + to_delete = [] + vuln = Vulnerability.get_by_id(entity.db_id) + if vuln is None: + vuln = Vulnerability(id=entity.db_id) + if vuln.modified != vuln_pb.modified.ToDatetime(datetime.UTC): + vuln.source_id = entity.source_id + vuln.modified = vuln_pb.modified.ToDatetime(datetime.UTC) + vuln.is_withdrawn = entity.withdrawn is not None + vuln.modified_raw = entity.import_last_modified + vuln.alias_raw = entity.aliases + vuln.related_raw = entity.related + vuln.upstream_raw = entity.upstream_raw + to_put.append(vuln) + + old_affected = AffectedVersions.query( + AffectedVersions.vuln_id == entity.db_id).fetch() + if vuln.is_withdrawn: + # We do not want the vuln to be searchable if it's been withdrawn. 
+ to_delete.append(ndb.Key(ListedVulnerability, vuln_pb.id)) + to_delete.extend(av.key for av in old_affected) + else: + to_put.append(ListedVulnerability.from_vulnerability(vuln_pb)) + new_affected = affected_from_bug(entity) + added, removed = diff_affected_versions(old_affected, new_affected) + to_put.extend(added) + to_delete.extend(r.key for r in removed) + + ndb.put_multi(to_put) + ndb.delete_multi(to_delete) + + ndb.transaction(transaction) + gcs.upload_vulnerability(vuln_pb) + + +def affected_from_bug(entity: Bug) -> list[AffectedVersions]: + """Compute the AffectedVersions from a Bug entity.""" + affected_versions = [] + for affected in entity.affected_packages: + pkg_ecosystem = affected.package.ecosystem + # Make sure we capture all possible ecosystem variants for matching. + # e.g. {'Ubuntu:22.04:LTS', 'Ubuntu:22.04', 'Ubuntu'} + all_pkg_ecosystems = {pkg_ecosystem, ecosystems.normalize(pkg_ecosystem)} + if (e := ecosystems.remove_variants(pkg_ecosystem)) is not None: + all_pkg_ecosystems.add(e) + + pkg_name = ecosystems.maybe_normalize_package_names(affected.package.name, + pkg_ecosystem) + + # Ecosystem helper for sorting the events. + e_helper = ecosystems.get(pkg_ecosystem) + if e_helper is not None and not (e_helper.supports_comparing or + e_helper.is_semver): + e_helper = None + + # TODO(michaelkedar): I am matching the current behaviour of the API, + # where GIT tags match to the first git repo in the ranges list, even if + # there are non-git ranges or multiple git repos in a range. + repo_url = '' + for r in affected.ranges: + if r.type == 'GIT': + if not repo_url: + repo_url = r.repo_url + continue + if r.type not in ('SEMVER', 'ECOSYSTEM'): + logging.warning('Unknown range type "%s" in %s', r.type, entity.db_id) + continue + + events = r.events + if e_helper is not None: + # If we have an ecosystem helper sort the events to help with querying. + events.sort(key=lambda e, sort_key=e_helper.sort_key: + (sort_key(e.value), _EVENT_ORDER.get(e.type, -1))) + # If we don't have an ecosystem helper, assume the events are in order. + for e in all_pkg_ecosystems: + affected_versions.append( + AffectedVersions( + vuln_id=entity.db_id, + ecosystem=e, + name=pkg_name, + events=events, + )) + + # Add the enumerated versions + if affected.versions: + if pkg_name: # We need at least a package name to perform matching. + for e in all_pkg_ecosystems: + affected_versions.append( + AffectedVersions( + vuln_id=entity.db_id, + ecosystem=e, + name=pkg_name, + versions=affected.versions, + )) + if repo_url: + affected_versions.append( + AffectedVersions( + vuln_id=entity.db_id, + ecosystem='GIT', + name=repo_url, + versions=affected.versions, + )) + + # Deduplicate and sort the affected_versions + unique_affected_dict = {av.sort_key(): av for av in affected_versions} + affected_versions = sorted( + unique_affected_dict.values(), key=AffectedVersions.sort_key) + + return affected_versions + + +def diff_affected_versions( + old: list[AffectedVersions], new: list[AffectedVersions] +) -> tuple[list[AffectedVersions], list[AffectedVersions]]: + """Find all the AffectedVersion entities that were added/removed from `old` to + get `new` (ignoring the entity IDs). 
+ + returns (added, removed) + """ + all_dict = {av.sort_key(): av for av in old + new} + old_set = {av.sort_key() for av in old} + new_set = {av.sort_key() for av in new} + + added_keys = new_set - old_set + removed_keys = old_set - new_set + + added = [all_dict[k] for k in added_keys] + removed = [all_dict[k] for k in removed_keys] + + return added, removed + + +# --- Indexer entities --- + class RepoIndex(ndb.Model): """RepoIndex entry""" @@ -933,6 +1288,9 @@ class RepoIndexBucket(ndb.Model): files_contained: int = ndb.IntegerProperty() +# --- SourceRepository --- + + class SourceRepositoryType(enum.IntEnum): """SourceRepository type.""" GIT = 0 @@ -1009,6 +1367,9 @@ def _pre_put_hook(self): # pylint: disable=arguments-differ raise ValueError('BUCKET SourceRepository cannot be editable.') +# --- Alias & Upstream --- + + class AliasGroup(ndb.Model): """Alias group.""" bug_ids: list[str] = ndb.StringProperty(repeated=True) @@ -1040,6 +1401,7 @@ class UpstreamGroup(ndb.Model): last_modified: datetime.datetime = ndb.DateTimeProperty(tzinfo=datetime.UTC) +# --- ImportFinding --- # TODO(gongh@): redesign this to make it easy to scale. class ImportFindings(enum.IntEnum): """The possible quality findings about an individual record.""" diff --git a/osv/models_test.py b/osv/models_test.py new file mode 100644 index 00000000000..2ba33b285fb --- /dev/null +++ b/osv/models_test.py @@ -0,0 +1,411 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Models tests.""" + +import datetime +import json +import os +import unittest + +from . import models + +from . import bug +from . import gcs +from . import sources +from . import tests +from . 
import vulnerability_pb2
+
+from google.cloud import ndb
+
+
+class ModelsTest(unittest.TestCase):
+  """Tests for ndb Model migrations."""
+
+  def setUp(self):
+    models.SourceRepository(
+        id='test',
+        name='test',
+        db_prefix=['TEST-'],
+    ).put()
+    return super().setUp()
+
+  def test_bug_post_put(self):
+    """Test _post_put_hook for Bug to populate new datastore/gcs entities."""
+    vuln_id = 'TEST-123'
+    # Create a handmade populated Bug
+    models.AliasGroup(
+        bug_ids=sorted([vuln_id, 'CVE-123', 'OSV-123']),
+        last_modified=datetime.datetime(2025, 3, 4, tzinfo=datetime.UTC)).put()
+    models.UpstreamGroup(
+        db_id=vuln_id,
+        upstream_ids=['TEST-1', 'TEST-12'],
+        last_modified=datetime.datetime(2025, 3, 5, tzinfo=datetime.UTC)).put()
+    models.Bug(
+        db_id=vuln_id,
+        aliases=['CVE-123'],
+        related=['TEST-234'],
+        upstream_raw=['TEST-12'],
+        summary='This is a vuln',
+        severities=[
+            models.Severity(type='CVSS_V2', score='AV:N/AC:L/Au:S/C:P/I:P/A:N')
+        ],
+        status=bug.BugStatus.PROCESSED,
+        timestamp=datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
+        last_modified=datetime.datetime(2025, 3, 3, tzinfo=datetime.UTC),
+        import_last_modified=datetime.datetime(2025, 2, 2, tzinfo=datetime.UTC),
+        source_id=f'test:{vuln_id}.json',
+        source_of_truth=models.SourceOfTruth.SOURCE_REPO,
+        public=True,
+        affected_packages=[
+            models.AffectedPackage(
+                package=models.Package(ecosystem='npm', name='testjs'),
+                ranges=[
+                    models.AffectedRange2(
+                        type='SEMVER',
+                        events=[
+                            models.AffectedEvent(type='fixed', value='1.0.0'),
+                            models.AffectedEvent(type='introduced', value='0'),
+                        ]),
+                    models.AffectedRange2(
+                        type='SEMVER',
+                        events=[
+                            models.AffectedEvent(
+                                type='last_affected', value='2.2.0'),
+                            models.AffectedEvent(
+                                type='introduced', value='2.0.0'),
+                        ])
+                ],
+                versions=['0.1.0', '0.2.0', '0.3.0', '2.0.0', '2.1.0',
+                          '2.2.0']),
+            models.AffectedPackage(
+                package=models.Package(
+                    ecosystem='Ubuntu:24.04:LTS', name='test'),
+                ranges=[
+                    models.AffectedRange2(
+                        type='ECOSYSTEM',
+                        events=[
+                            models.AffectedEvent(type='introduced', value='0'),
+                            models.AffectedEvent(type='fixed', value='1.0.0-3'),
+                        ])
+                ],
+                versions=['1.0.0-1', '1.0.0-2'],
+                severities=[models.Severity(type='Ubuntu', score='Low')]),
+            models.AffectedPackage(
+                package=models.Package(ecosystem='Ubuntu:25.04', name='test'),
+                ranges=[
+                    models.AffectedRange2(
+                        type='ECOSYSTEM',
+                        events=[
+                            models.AffectedEvent(type='introduced', value='0'),
+                            models.AffectedEvent(type='fixed', value='1.0.0-3'),
+                        ])
+                ],
+                versions=['1.0.0-1', '1.0.0-2'],
+                severities=[models.Severity(type='Ubuntu', score='High')]),
+            models.AffectedPackage(
+                package=models.Package(ecosystem='', name=''),
+                ranges=[
+                    models.AffectedRange2(
+                        type='GIT', repo_url='https://github.com/test/test')
+                ],
+                versions=['v1', 'v2']),
+        ],
+    ).put()
+    put_bug = models.Bug.get_by_id(vuln_id)
+    self.assertIsNotNone(put_bug)
+    put_bug: models.Bug
+
+    # Check if new db entities were created.
+    vulnerability = models.Vulnerability.get_by_id(vuln_id)
+    self.assertIsNotNone(vulnerability)
+    vulnerability: models.Vulnerability
+    self.assertEqual('test:TEST-123.json', vulnerability.source_id)
+    self.assertEqual(
+        datetime.datetime(2025, 3, 5, tzinfo=datetime.UTC),
+        vulnerability.modified)
+    self.assertFalse(vulnerability.is_withdrawn)
+    self.assertEqual(
+        datetime.datetime(2025, 2, 2, tzinfo=datetime.UTC),
+        vulnerability.modified_raw)
+    self.assertListEqual(['CVE-123'], vulnerability.alias_raw)
+    self.assertListEqual(['TEST-234'], vulnerability.related_raw)
+    self.assertListEqual(['TEST-12'], vulnerability.upstream_raw)
+
+    listed_vuln = models.ListedVulnerability.get_by_id(vuln_id)
+    self.assertIsNotNone(listed_vuln)
+    listed_vuln: models.ListedVulnerability
+    self.assertEqual(
+        datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
+        listed_vuln.published)
+    self.assertListEqual(['GIT', 'Ubuntu', 'npm'], listed_vuln.ecosystems)
+    self.assertListEqual([
+        'Ubuntu:24.04:LTS/test', 'Ubuntu:25.04/test', 'github.com/test/test',
+        'npm/testjs'
+    ], listed_vuln.packages)
+    self.assertEqual('This is a vuln', listed_vuln.summary)
+    self.assertTrue(listed_vuln.is_fixed)
+    self.assertListEqual([
+        models.Severity(type='CVSS_V2', score='AV:N/AC:L/Au:S/C:P/I:P/A:N'),
+        models.Severity(type='Ubuntu', score='High'),
+        models.Severity(type='Ubuntu', score='Low')
+    ], listed_vuln.severities)
+    self.assertListEqual(
+        ['https://github.com/test/test', 'test', 'test-123', 'testjs'],
+        listed_vuln.autocomplete_tags)
+    # search_indices should include all the original search indices,
+    # plus the transitive alias & upstream ids
+    search_indices = sorted(put_bug.search_indices +
+                            ['osv-123', 'osv', 'test-1', '1'])
+    self.assertListEqual(search_indices, listed_vuln.search_indices)
+
+    affected: list[models.AffectedVersions] = models.AffectedVersions.query(
+        models.AffectedVersions.vuln_id == vuln_id).fetch()
+    affected.sort(key=lambda x: x.sort_key())
+    want = [
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='GIT',
+            name='https://github.com/test/test',
+            versions=['v1', 'v2']),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='Ubuntu',
+            name='test',
+            events=[
+                models.AffectedEvent(type='introduced', value='0'),
+                models.AffectedEvent(type='fixed', value='1.0.0-3')
+            ]),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='Ubuntu',
+            name='test',
+            versions=['1.0.0-1', '1.0.0-2']),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='Ubuntu:24.04',
+            name='test',
+            events=[
+                models.AffectedEvent(type='introduced', value='0'),
+                models.AffectedEvent(type='fixed', value='1.0.0-3')
+            ]),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='Ubuntu:24.04',
+            name='test',
+            versions=['1.0.0-1', '1.0.0-2']),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='Ubuntu:24.04:LTS',
+            name='test',
+            events=[
+                models.AffectedEvent(type='introduced', value='0'),
+                models.AffectedEvent(type='fixed', value='1.0.0-3')
+            ]),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='Ubuntu:24.04:LTS',
+            name='test',
+            versions=['1.0.0-1', '1.0.0-2']),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='Ubuntu:25.04',
+            name='test',
+            events=[
+                models.AffectedEvent(type='introduced', value='0'),
+                models.AffectedEvent(type='fixed', value='1.0.0-3')
+            ]),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='Ubuntu:25.04',
+            name='test',
+            versions=['1.0.0-1', '1.0.0-2']),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='npm',
+            name='testjs',
+            events=[
+                models.AffectedEvent(type='introduced', value='0'),
+                models.AffectedEvent(type='fixed', value='1.0.0')
+            ]),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='npm',
+            name='testjs',
+            events=[
+                models.AffectedEvent(type='introduced', value='2.0.0'),
+                models.AffectedEvent(type='last_affected', value='2.2.0')
+            ]),
+        models.AffectedVersions(
+            vuln_id=vuln_id,
+            ecosystem='npm',
+            name='testjs',
+            versions=['0.1.0', '0.2.0', '0.3.0', '2.0.0', '2.1.0', '2.2.0']),
+    ]
+    self.assertListEqual([a.to_dict() for a in want],
+                         [a.to_dict() for a in affected])
+
+    # Check the records written to the 'bucket' (which is mocked) are expected.
+    vuln_pb = put_bug.to_vulnerability(True, True, True)
+
+    bucket = gcs.get_osv_bucket()
+    blob = bucket.get_blob(os.path.join(gcs.VULN_PB_PATH, f'{vuln_id}.pb'))
+    self.assertIsNotNone(blob)
+    self.assertEqual(blob.custom_time,
+                     datetime.datetime(2025, 3, 5, tzinfo=datetime.UTC))
+    got_pb = vulnerability_pb2.Vulnerability().FromString(
+        blob.download_as_bytes())
+    self.assertEqual(got_pb, vuln_pb)
+
+    blob = bucket.get_blob(os.path.join(gcs.VULN_JSON_PATH, f'{vuln_id}.json'))
+    self.assertIsNotNone(blob)
+    self.assertEqual(blob.custom_time,
+                     datetime.datetime(2025, 3, 5, tzinfo=datetime.UTC))
+    got_json = json.loads(blob.download_as_bytes())
+    self.assertDictEqual(got_json, sources.vulnerability_to_dict(vuln_pb))
+
+  def test_bug_withdraw(self):
+    """Test if withdrawing a Bug correctly removes unneeded indices."""
+    # First put the bug un-withdrawn
+    vuln_id = 'TEST-999'
+    models.Bug(
+        db_id=vuln_id,
+        status=bug.BugStatus.PROCESSED,
+        timestamp=datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
+        last_modified=datetime.datetime(2025, 3, 3, tzinfo=datetime.UTC),
+        import_last_modified=datetime.datetime(2025, 2, 2, tzinfo=datetime.UTC),
+        source_id=f'test:{vuln_id}.json',
+        source_of_truth=models.SourceOfTruth.SOURCE_REPO,
+        public=True,
+        affected_packages=[
+            models.AffectedPackage(
+                package=models.Package(ecosystem='PyPI', name='testpy'),
+                ranges=[
+                    models.AffectedRange2(
+                        type='ECOSYSTEM',
+                        events=[
+                            models.AffectedEvent(type='introduced', value='0'),
+                            models.AffectedEvent(type='fixed', value='1.0'),
+                        ])
+                ],
+                versions=['0.1', '0.2'],
+            ),
+        ],
+    ).put()
+    put_bug = models.Bug.get_by_id(vuln_id)
+    self.assertIsNotNone(put_bug)
+    put_bug: models.Bug
+
+    vulnerability = models.Vulnerability.get_by_id(vuln_id)
+    self.assertIsNotNone(vulnerability)
+    vulnerability: models.Vulnerability
+    self.assertFalse(vulnerability.is_withdrawn)
+    listed_vuln = models.ListedVulnerability.get_by_id(vuln_id)
+    self.assertIsNotNone(listed_vuln)
+    affected = models.AffectedVersions.query(
+        models.AffectedVersions.vuln_id == vuln_id).fetch()
+    self.assertEqual(2, len(affected))
+
+    bucket = gcs.get_osv_bucket()
+    blob = bucket.get_blob(os.path.join(gcs.VULN_PB_PATH, f'{vuln_id}.pb'))
+    self.assertIsNotNone(blob)
+    self.assertEqual(blob.custom_time,
+                     datetime.datetime(2025, 3, 3, tzinfo=datetime.UTC))
+    blob = bucket.get_blob(os.path.join(gcs.VULN_JSON_PATH, f'{vuln_id}.json'))
+    self.assertIsNotNone(blob)
+    self.assertEqual(blob.custom_time,
+                     datetime.datetime(2025, 3, 3, tzinfo=datetime.UTC))
+
+    # Now withdraw the Bug
+    put_bug.withdrawn = datetime.datetime(2025, 4, 4, tzinfo=datetime.UTC)
+    put_bug.last_modified = datetime.datetime(2025, 4, 4, tzinfo=datetime.UTC)
+    put_bug.put()
+
+    # Vulnerability exists, but is withdrawn
+    vulnerability = models.Vulnerability.get_by_id(vuln_id)
+    self.assertIsNotNone(vulnerability)
+    vulnerability: models.Vulnerability
+    self.assertTrue(vulnerability.is_withdrawn)
+    # ListedVulnerability and AffectedVersions have been removed
+    listed_vuln = models.ListedVulnerability.get_by_id(vuln_id)
+    self.assertIsNone(listed_vuln)
+    affected = models.AffectedVersions.query(
+        models.AffectedVersions.vuln_id == vuln_id).fetch()
+    self.assertEqual(0, len(affected))
+    # Blobs still exist, and were re-written
+    bucket = gcs.get_osv_bucket()
+    blob = bucket.get_blob(os.path.join(gcs.VULN_PB_PATH, f'{vuln_id}.pb'))
+    self.assertIsNotNone(blob)
+    self.assertEqual(blob.custom_time,
+                     datetime.datetime(2025, 4, 4, tzinfo=datetime.UTC))
+    blob = bucket.get_blob(os.path.join(gcs.VULN_JSON_PATH, f'{vuln_id}.json'))
+    self.assertIsNotNone(blob)
+    self.assertEqual(blob.custom_time,
+                     datetime.datetime(2025, 4, 4, tzinfo=datetime.UTC))
+
+  def test_oss_fuzz_private(self):
+    """Test that non-public Bugs from OSS-Fuzz are not indexed."""
+    vuln_id = 'TEST-OSSFUZZ'
+    models.Bug(
+        db_id=vuln_id,
+        status=bug.BugStatus.UNPROCESSED,
+        public=False,
+        timestamp=datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
+        last_modified=datetime.datetime(2025, 3, 3, tzinfo=datetime.UTC),
+        import_last_modified=datetime.datetime(2025, 2, 2, tzinfo=datetime.UTC),
+        source_id=f'test:{vuln_id}.json',
+        source_of_truth=models.SourceOfTruth.SOURCE_REPO,
+        affected_packages=[
+            models.AffectedPackage(
+                package=models.Package(ecosystem='PyPI', name='testpy'),
+                ranges=[
+                    models.AffectedRange2(
+                        type='ECOSYSTEM',
+                        events=[
+                            models.AffectedEvent(type='introduced', value='0'),
+                            models.AffectedEvent(type='fixed', value='1.0'),
+                        ])
+                ],
+                versions=['0.1', '0.2'],
+            ),
+        ],
+    ).put()
+
+    vulnerability = models.Vulnerability.get_by_id(vuln_id)
+    self.assertIsNone(vulnerability)
+    listed_vuln = models.ListedVulnerability.get_by_id(vuln_id)
+    self.assertIsNone(listed_vuln)
+    affected = models.AffectedVersions.query(
+        models.AffectedVersions.vuln_id == vuln_id).fetch()
+    self.assertEqual(0, len(affected))
+    bucket = gcs.get_osv_bucket()
+    blob = bucket.get_blob(os.path.join(gcs.VULN_PB_PATH, f'{vuln_id}.pb'))
+    self.assertIsNone(blob)
+    blob = bucket.get_blob(os.path.join(gcs.VULN_JSON_PATH, f'{vuln_id}.json'))
+    self.assertIsNone(blob)
+
+
+def setUpModule():
+  """Set up the test module."""
+  tests.start_datastore_emulator()
+  ndb_client = ndb.Client()
+  unittest.enterModuleContext(ndb_client.context(cache_policy=False))
+
+
+def tearDownModule():
+  """Tear down the test module."""
+  tests.stop_emulator()
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/osv/tests.py b/osv/tests.py
index 2c979453f57..7ee6e28628c 100644
--- a/osv/tests.py
+++ b/osv/tests.py
@@ -28,6 +28,8 @@
 import pygit2
 import pygit2.enums
 
+from . import gcs_mock
+
 _EMULATOR_TIMEOUT = 30
 _DATASTORE_EMULATOR_PORT = '8002'
 _DATASTORE_READY_INDICATOR = b'is now running'
@@ -108,9 +110,12 @@ def commit(self, author_name, author_email, message='Changes'):
 _ds_data_dir = None
 
+_mock_gcs_ctx = None
+
 
 def start_datastore_emulator():
   """Starts Datastore emulator."""
+  # TODO(michaelkedar): turn this into a context (`with datastore_emulator()`)
   _kill_existing_datastore_emulator()
 
   port = os.environ.get('DATASTORE_EMULATOR_PORT', _DATASTORE_EMULATOR_PORT)
@@ -119,23 +124,30 @@ def start_datastore_emulator():
   os.environ['GOOGLE_CLOUD_PROJECT'] = TEST_PROJECT_ID
   global _ds_data_dir
   _ds_data_dir = tempfile.TemporaryDirectory()
-
+  # TODO(michaelkedar): use `gcloud emulators firestore` with
+  # `--database-mode=datastore-mode` instead.
   proc = subprocess.Popen([
       'gcloud',
       'beta',
       'emulators',
       'datastore',
       'start',
-      '--consistency=1.0',
       '--host-port=localhost:' + port,
       '--project=' + TEST_PROJECT_ID,
       '--no-store-on-disk',
       f'--data-dir={_ds_data_dir.name}',
+      '--use-firestore-in-datastore-mode',
   ],
                           stdout=subprocess.PIPE,
                           stderr=subprocess.STDOUT)
   _wait_for_emulator_ready(proc, 'datastore', _DATASTORE_READY_INDICATOR)
+
+  # Also mock the GCS bucket.
+  global _mock_gcs_ctx
+  _mock_gcs_ctx = gcs_mock.gcs_mock()
+  _mock_gcs_ctx.__enter__()  # pylint: disable=unnecessary-dunder-call
+
   return proc
@@ -242,6 +254,11 @@ def reset_emulator():
 
 def stop_emulator():
   """Stops emulator."""
+  global _mock_gcs_ctx
+  if _mock_gcs_ctx is not None:
+    _mock_gcs_ctx.__exit__(None, None, None)
+    _mock_gcs_ctx = None
+
   try:
     port = os.environ.get('DATASTORE_EMULATOR_PORT', _DATASTORE_EMULATOR_PORT)
     resp = requests.post(
diff --git a/run_tests.sh b/run_tests.sh
index 7cc0a9654eb..de0afb09ac3 100755
--- a/run_tests.sh
+++ b/run_tests.sh
@@ -6,6 +6,7 @@ poetry run python -m unittest osv.purl_helpers_test
 poetry run python -m unittest osv.request_helper_test
 poetry run python -m unittest osv.semver_index_test
 poetry run python -m unittest osv.impact_test
+poetry run python -m unittest osv.models_test
 
 # Run all osv.ecosystems tests
 poetry run python -m unittest discover osv/ecosystems/ "*_test.py" .
diff --git a/tools/datafix/reput_bugs.py b/tools/datafix/reput_bugs.py
index 470bf175a96..e3527e918e0 100755
--- a/tools/datafix/reput_bugs.py
+++ b/tools/datafix/reput_bugs.py
@@ -30,32 +30,24 @@ def reput_bugs(dryrun: bool, source: str, ids: list) -> None:
   num_reputted = 0
   time_start = time.perf_counter()
 
-  # This handles the actual transaction of reputting the bugs with ndb
-  def _reput_ndb():
-    # Reputting the bug runs the Bug _pre_put_hook() in models.py
-    if dryrun:
-      print("Dry run mode. Preventing transaction from commiting")
-      raise Exception("Dry run mode")  # pylint: disable=broad-exception-raised
-    ndb.put_multi_async([
-        osv.Bug.get_by_id(r.id()) for r in result[batch:batch + MAX_BATCH_SIZE]
-    ])
-    print(f"Time elapsed: {(time.perf_counter() - time_start):.2f} seconds.")
-
   # Chunk the results to reput in acceptibly sized batches for the API.
   for batch in range(0, len(result), MAX_BATCH_SIZE):
     try:
       num_reputted += len(result[batch:batch + MAX_BATCH_SIZE])
       print(
          f"Reput {num_reputted} bugs... - {num_reputted/len(result)*100:.2f}%")
-      ndb.transaction(_reput_ndb)
-    except Exception as e:
-      # Don't have the first batch's transaction-aborting exception stop
-      # subsequent batches from being attempted.
-      if dryrun and e.args[0].startswith("Dry run mode"):
-        print("Dry run mode. Preventing transaction from commiting")
+      if dryrun:
+        print("Dry run mode. Preventing put")
       else:
-        print([r.id() for r in result[batch:batch + MAX_BATCH_SIZE]])
-        print(f"Exception {e} occurred. Continuing to next batch.")
+        # Reputting the bug runs the Bug _pre/post_put_hook() in models.py
+        ndb.put_multi([
+            osv.Bug.get_by_id(r.id())
+            for r in result[batch:batch + MAX_BATCH_SIZE]
+        ])
+      print(f"Time elapsed: {(time.perf_counter()-time_start):.2f} seconds.")
+    except Exception as e:
+      print([r.id() for r in result[batch:batch + MAX_BATCH_SIZE]])
+      print(f"Exception {e} occurred. Continuing to next batch.")
 
   print("Reputted!")
@@ -90,7 +82,7 @@ def main() -> None:
   args = parser.parse_args()
 
   client = ndb.Client(project=args.project)
-  with client.context():
+  with client.context(cache_policy=False):
     reput_bugs(args.dryrun, args.source, args.bugs)
diff --git a/tools/datafix/reput_helper.py b/tools/datafix/reput_helper.py
index 0a9a8f44a40..5c950649d20 100644
--- a/tools/datafix/reput_helper.py
+++ b/tools/datafix/reput_helper.py
@@ -11,7 +11,6 @@
 
 import argparse
 import json
-import functools
 import time
 import typing
 
@@ -19,12 +18,11 @@
 
 # Global flags
 verbose = False
-fullrefresh = False
 transform = True
 
 
 class DryRunException(Exception):
-  """This exception is raised to cancel a transaction during dry runs"""
+  """This exception is raised to cancel a put during dry runs"""
 
 
 def get_relevant_ids() -> list[str]:
@@ -89,30 +87,31 @@ def refresh_ids(dryrun: bool, loadcache: str) -> None:
   num_reputted = 0
   time_start = time.perf_counter()
 
-  # This handles the actual transaction of reputting
-  # the bugs with ndb
+  # This handles the actual reput of the bugs with ndb
   def _refresh_ids(batch: int):
     buf: list[osv.Bug] = [
        osv.Bug.get_by_id(r) for r in relevant_ids[batch:batch + MAX_BATCH_SIZE]
     ]
-    if fullrefresh:
-      # Delete the existing entries. This must be done in a transaction
-      # to avoid losing data if interrupted
-      ndb.delete_multi([r.key for r in buf])
+    old_keys = {r.key for r in buf}
 
     if transform:
      # Clear the key so the key name will be regenerated to the new key format
      for elem in buf:
        transform_bug(elem)
 
-    # Reput the bug back in
-    ndb.put_multi_async(buf)
-
     if dryrun:
-      print("Dry run mode. Preventing transaction from committing")
+      print("Dry run mode. Preventing put")
       raise DryRunException
 
+    # Reput the bug back in
+    new_keys = set(ndb.put_multi(buf))
+
+    # If the keys have changed, delete the old keys
+    # There's a potential for old keys to not get cleaned up if something fails
+    if deleted := old_keys - new_keys:
+      ndb.delete_multi(deleted)
+
     print(f"Time elapsed: {(time.perf_counter() - time_start):.2f} seconds.")
 
   # Chunk the results to reput in acceptibly sized batches for the API.
@@ -121,11 +120,11 @@ def _refresh_ids(batch: int):
       num_reputted += len(relevant_ids[batch:batch + MAX_BATCH_SIZE])
       print(f"Reput {num_reputted} bugs... - "
             f"{num_reputted/len(relevant_ids)*100:.2f}%")
-      ndb.transaction(functools.partial(_refresh_ids, batch))
+      _refresh_ids(batch)
     except DryRunException:
-      # Don't have the first batch's transaction-aborting exception stop
+      # Don't have the first batch's put-aborting exception stop
       # subsequent batches from being attempted.
-      print("Dry run mode. Preventing transaction from committing")
+      print("Dry run mode. Preventing put")
     except Exception as e:
       print(relevant_ids[batch:batch + MAX_BATCH_SIZE])
       print(f"Exception {e} occurred. Continuing to next batch.")
@@ -148,12 +147,6 @@ def main() -> None:
       dest="verbose",
       default=False,
       help="Print each ID that needs to be processed")
-  parser.add_argument(
-      "--full-refresh",
-      action=argparse.BooleanOptionalAction,
-      dest="fullrefresh",
-      default=False,
-      help="Deletes the bug before reputting, necessary for key changes")
   parser.add_argument(
       "--transform",
       action=argparse.BooleanOptionalAction,
@@ -174,16 +167,14 @@ def main() -> None:
   args = parser.parse_args()
 
   global verbose
-  global fullrefresh
   global transform
   verbose = args.verbose
-  fullrefresh = args.fullrefresh
   transform = args.transform
 
   client = ndb.Client(project=args.project)
   print(f"Running on project {args.project}.")
 
-  with client.context():
+  with client.context(cache_policy=False):
     refresh_ids(args.dryrun, args.loadcache)