Skip to content

Commit

Permalink
Add anon client (#52)
Browse files Browse the repository at this point in the history
Add an anonymous client to download ASTs from the GCS bucket.
Also make some path changes (more use of os.path).

Fixes #21
  • Loading branch information
trashvisor authored Feb 1, 2024
1 parent d06fbc9 commit 51a636b
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions data_prep/project_context/context_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from collections import defaultdict
from typing import List, Tuple

from google.cloud import storage


class ContextRetriever:
"""ContextRetriever attempts to retrieve context for a certain project/function from ASTs."""
Expand All @@ -19,6 +21,8 @@ class ContextRetriever:

DOWNLOAD_TO_PATH = 'oss-fuzz-data/asts'

OSS_FUZZ_EXP_BUCKET = 'oss-fuzz-llm-public'

def __init__(self, project_name: str, function_signature: str):
self._record_decl_nodes = defaultdict(list)
self._typedef_decl_nodes = defaultdict(list)
Expand All @@ -27,7 +31,8 @@ def __init__(self, project_name: str, function_signature: str):
self._function_signature = function_signature
self._download_from_path = f'{self.AST_BASE_PATH}/{self._project_name}/*'
self._uuid = uuid.uuid4()
self._ast_path = f'{self.DOWNLOAD_TO_PATH}/{self._project_name}-{self._uuid}'
self._ast_path = os.path.join(self.DOWNLOAD_TO_PATH,
f'{self._project_name}-{self._uuid}')

def _get_function_name(self, target_function_signature: str) -> str:
"""Retrieves the function name from the target function signature."""
Expand Down Expand Up @@ -253,17 +258,20 @@ def _get_header_from_file(self, fully_qualified_path: str) -> str:

def retrieve_asts(self):
  """Downloads ASTs for the given project.

  Lists the objects stored under 'project_asts/<project name>/' in the
  OSS_FUZZ_EXP_BUCKET bucket through an anonymous (credential-free) storage
  client and writes each one below the local AST directory, keeping any
  sub-directory layout found under the project prefix.

  The local AST directory is self._ast_path resolved against the directory
  three levels above this file.
  """
  storage_client = storage.Client.create_anonymous_client()
  bucket = storage_client.bucket(self.OSS_FUZZ_EXP_BUCKET)

  # Object names always use '/' regardless of the host OS, so the prefix is
  # built by hand rather than with os.path.join. The trailing '/' keeps a
  # project named 'foo' from also matching 'foobar/...'.
  project_prefix = f'project_asts/{self._project_name}/'
  blobs = bucket.list_blobs(prefix=project_prefix)

  ast_dir = os.path.abspath(
      os.path.join(
          os.path.dirname(os.path.dirname(os.path.dirname(__file__))),
          self._ast_path))
  os.makedirs(ast_dir, exist_ok=True)

  for blob in blobs:
    # Strip only the leading prefix; str.replace would also rewrite any later
    # occurrence of the same text inside the object name.
    file_relpath = blob.name[len(project_prefix):]

    # Skip the prefix object itself and "directory" placeholder objects;
    # neither corresponds to a file that can be written locally.
    if not file_relpath or file_relpath.endswith('/'):
      continue

    destination = os.path.join(ast_dir, *file_relpath.split('/'))
    # download_to_filename does not create missing parent directories.
    os.makedirs(os.path.dirname(destination), exist_ok=True)
    blob.download_to_filename(destination)

def cleanup_asts(self):
"""Removes ASTs for the given project."""
Expand Down

0 comments on commit 51a636b

Please sign in to comment.