From e47e89609afcd8dbd3dda706f6ffd14fdaaa53b0 Mon Sep 17 00:00:00 2001 From: bmc-msft <41130664+bmc-msft@users.noreply.github.com> Date: Wed, 18 Nov 2020 09:06:14 -0500 Subject: [PATCH] Use Storage Account types, rather than account_id (#320) We need to move to supporting data sharding. One of the steps towards that is stop passing around `account_id`, rather we need to specify the type of storage we need. --- .../__app__/agent_registration/__init__.py | 5 +- .../__app__/containers/__init__.py | 10 ++- src/api-service/__app__/download/__init__.py | 12 ++- .../__app__/onefuzzlib/azure/containers.py | 85 +++++++++++++------ .../__app__/onefuzzlib/azure/creds.py | 8 +- .../__app__/onefuzzlib/azure/queue.py | 42 ++++----- .../__app__/onefuzzlib/extension.py | 66 ++++++++------ .../__app__/onefuzzlib/notifications/main.py | 12 +-- src/api-service/__app__/onefuzzlib/pools.py | 22 ++--- src/api-service/__app__/onefuzzlib/proxy.py | 9 +- src/api-service/__app__/onefuzzlib/reports.py | 4 +- src/api-service/__app__/onefuzzlib/repro.py | 6 +- .../__app__/onefuzzlib/tasks/config.py | 23 ++--- .../__app__/onefuzzlib/tasks/main.py | 6 +- .../__app__/onefuzzlib/tasks/scheduler.py | 22 +++-- src/api-service/__app__/onefuzzlib/updates.py | 4 +- .../__app__/onefuzzlib/webhooks.py | 4 +- src/api-service/__app__/pool/__init__.py | 4 +- 18 files changed, 205 insertions(+), 139 deletions(-) diff --git a/src/api-service/__app__/agent_registration/__init__.py b/src/api-service/__app__/agent_registration/__init__.py index 9ed3657cff..82131b38db 100644 --- a/src/api-service/__app__/agent_registration/__init__.py +++ b/src/api-service/__app__/agent_registration/__init__.py @@ -13,7 +13,8 @@ from onefuzztypes.responses import AgentRegistration from ..onefuzzlib.agent_authorization import call_if_agent -from ..onefuzzlib.azure.creds import get_fuzz_storage, get_instance_url +from ..onefuzzlib.azure.containers import StorageType +from ..onefuzzlib.azure.creds import get_instance_url from ..onefuzzlib.azure.queue import get_queue_sas from ..onefuzzlib.pools import Node, NodeMessage, NodeTasks, Pool from ..onefuzzlib.request import not_ok, ok, parse_uri @@ -25,7 +26,7 @@ def create_registration_response(machine_id: UUID, pool: Pool) -> func.HttpRespo commands_url = "%s/api/agents/commands" % base_address work_queue = get_queue_sas( pool.get_pool_queue(), - account_id=get_fuzz_storage(), + StorageType.corpus, read=True, update=True, process=True, diff --git a/src/api-service/__app__/containers/__init__.py b/src/api-service/__app__/containers/__init__.py index 8713a83530..5e5f4f8248 100644 --- a/src/api-service/__app__/containers/__init__.py +++ b/src/api-service/__app__/containers/__init__.py @@ -13,6 +13,7 @@ from onefuzztypes.responses import BoolResult, ContainerInfo, ContainerInfoBase from ..onefuzzlib.azure.containers import ( + StorageType, create_container, delete_container, get_container_metadata, @@ -30,7 +31,7 @@ def get(req: func.HttpRequest) -> func.HttpResponse: if isinstance(request, Error): return not_ok(request, context="container get") if request is not None: - metadata = get_container_metadata(request.name) + metadata = get_container_metadata(request.name, StorageType.corpus) if metadata is None: return not_ok( Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]), @@ -41,6 +42,7 @@ def get(req: func.HttpRequest) -> func.HttpResponse: name=request.name, sas_url=get_container_sas_url( request.name, + StorageType.corpus, read=True, write=True, create=True, @@ -51,7 +53,7 @@ def get(req: func.HttpRequest) -> func.HttpResponse: ) return ok(info) - containers = get_containers() + containers = get_containers(StorageType.corpus) container_info = [] for name in containers: @@ -66,7 +68,7 @@ def post(req: func.HttpRequest) -> func.HttpResponse: return not_ok(request, context="container create") logging.info("container - creating %s", request.name) - sas = create_container(request.name, metadata=request.metadata) + sas = create_container(request.name, StorageType.corpus, metadata=request.metadata) if sas: return ok( ContainerInfo(name=request.name, sas_url=sas, metadata=request.metadata) @@ -83,7 +85,7 @@ def delete(req: func.HttpRequest) -> func.HttpResponse: return not_ok(request, context="container delete") logging.info("container - deleting %s", request.name) - return ok(BoolResult(result=delete_container(request.name))) + return ok(BoolResult(result=delete_container(request.name, StorageType.corpus))) def main(req: func.HttpRequest) -> func.HttpResponse: diff --git a/src/api-service/__app__/download/__init__.py b/src/api-service/__app__/download/__init__.py index 9903dd1926..65b6f7b853 100644 --- a/src/api-service/__app__/download/__init__.py +++ b/src/api-service/__app__/download/__init__.py @@ -8,6 +8,7 @@ from onefuzztypes.models import Error, FileEntry from ..onefuzzlib.azure.containers import ( + StorageType, blob_exists, container_exists, get_file_sas_url, @@ -20,13 +21,13 @@ def get(req: func.HttpRequest) -> func.HttpResponse: if isinstance(request, Error): return not_ok(request, context="download") - if not container_exists(request.container): + if not container_exists(request.container, StorageType.corpus): return not_ok( Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]), context=request.container, ) - if not blob_exists(request.container, request.filename): + if not blob_exists(request.container, request.filename, StorageType.corpus): return not_ok( Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid filename"]), context=request.filename, @@ -34,7 +35,12 @@ def get(req: func.HttpRequest) -> func.HttpResponse: return redirect( get_file_sas_url( - request.container, request.filename, read=True, days=0, minutes=5 + request.container, + request.filename, + StorageType.corpus, + read=True, + days=0, + minutes=5, ) ) diff --git a/src/api-service/__app__/onefuzzlib/azure/containers.py b/src/api-service/__app__/onefuzzlib/azure/containers.py index 8dcdc906cf..4e0077b701 100644 --- a/src/api-service/__app__/onefuzzlib/azure/containers.py +++ b/src/api-service/__app__/onefuzzlib/azure/containers.py @@ -6,37 +6,61 @@ import datetime import os import urllib.parse -from typing import Dict, Optional, Union, cast +from enum import Enum +from typing import Any, Dict, Optional, Union, cast from azure.common import AzureHttpError, AzureMissingResourceHttpError from azure.storage.blob import BlobPermissions, ContainerPermissions from memoization import cached -from .creds import get_blob_service +from .creds import get_blob_service, get_func_storage, get_fuzz_storage + + +class StorageType(Enum): + corpus = "corpus" + config = "config" + + +def get_account_id_by_type(storage_type: StorageType) -> str: + if storage_type == StorageType.corpus: + account_id = get_fuzz_storage() + elif storage_type == StorageType.config: + account_id = get_func_storage() + else: + raise NotImplementedError + return account_id + + +@cached(ttl=5) +def get_blob_service_by_type(storage_type: StorageType) -> Any: + account_id = get_account_id_by_type(storage_type) + return get_blob_service(account_id) @cached(ttl=5) -def container_exists(name: str, account_id: Optional[str] = None) -> bool: +def container_exists(name: str, storage_type: StorageType) -> bool: try: - get_blob_service(account_id).get_container_properties(name) + get_blob_service_by_type(storage_type).get_container_properties(name) return True except AzureHttpError: return False -def get_containers(account_id: Optional[str] = None) -> Dict[str, Dict[str, str]]: +def get_containers(storage_type: StorageType) -> Dict[str, Dict[str, str]]: return { x.name: x.metadata - for x in get_blob_service(account_id).list_containers(include_metadata=True) + for x in get_blob_service_by_type(storage_type).list_containers( + include_metadata=True + ) if not x.name.startswith("$") } def get_container_metadata( - name: str, account_id: Optional[str] = None + name: str, storage_type: StorageType ) -> Optional[Dict[str, str]]: try: - result = get_blob_service(account_id).get_container_metadata(name) + result = get_blob_service_by_type(storage_type).get_container_metadata(name) return cast(Dict[str, str], result) except AzureHttpError: pass @@ -44,22 +68,29 @@ def get_container_metadata( def create_container( - name: str, metadata: Optional[Dict[str, str]], account_id: Optional[str] = None + name: str, storage_type: StorageType, metadata: Optional[Dict[str, str]] ) -> Optional[str]: try: - get_blob_service(account_id).create_container(name, metadata=metadata) + get_blob_service_by_type(storage_type).create_container(name, metadata=metadata) except AzureHttpError: # azure storage already logs errors return None return get_container_sas_url( - name, read=True, add=True, create=True, write=True, delete=True, list=True + name, + storage_type, + read=True, + add=True, + create=True, + write=True, + delete=True, + list=True, ) -def delete_container(name: str, account_id: Optional[str] = None) -> bool: +def delete_container(name: str, storage_type: StorageType) -> bool: try: - return bool(get_blob_service(account_id).delete_container(name)) + return bool(get_blob_service_by_type(storage_type).delete_container(name)) except AzureHttpError: # azure storage already logs errors return False @@ -67,7 +98,8 @@ def delete_container(name: str, account_id: Optional[str] = None) -> bool: def get_container_sas_url( container: str, - account_id: Optional[str] = None, + storage_type: StorageType, + *, read: bool = False, add: bool = False, create: bool = False, @@ -75,7 +107,7 @@ def get_container_sas_url( delete: bool = False, list: bool = False, ) -> str: - service = get_blob_service(account_id) + service = get_blob_service_by_type(storage_type) expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30) permission = ContainerPermissions(read, add, create, write, delete, list) @@ -91,7 +123,8 @@ def get_container_sas_url( def get_file_sas_url( container: str, name: str, - account_id: Optional[str] = None, + storage_type: StorageType, + *, read: bool = False, add: bool = False, create: bool = False, @@ -102,7 +135,7 @@ def get_file_sas_url( hours: int = 0, minutes: int = 0, ) -> str: - service = get_blob_service(account_id) + service = get_blob_service_by_type(storage_type) expiry = datetime.datetime.utcnow() + datetime.timedelta( days=days, hours=hours, minutes=minutes ) @@ -117,9 +150,9 @@ def get_file_sas_url( def save_blob( - container: str, name: str, data: Union[str, bytes], account_id: Optional[str] = None + container: str, name: str, data: Union[str, bytes], storage_type: StorageType ) -> None: - service = get_blob_service(account_id) + service = get_blob_service_by_type(storage_type) service.create_container(container) if isinstance(data, str): service.create_blob_from_text(container, name, data) @@ -127,10 +160,8 @@ def save_blob( service.create_blob_from_bytes(container, name, data) -def get_blob( - container: str, name: str, account_id: Optional[str] = None -) -> Optional[bytes]: - service = get_blob_service(account_id) +def get_blob(container: str, name: str, storage_type: StorageType) -> Optional[bytes]: + service = get_blob_service_by_type(storage_type) try: blob = service.get_blob_to_bytes(container, name).content return cast(bytes, blob) @@ -138,8 +169,8 @@ def get_blob( return None -def blob_exists(container: str, name: str, account_id: Optional[str] = None) -> bool: - service = get_blob_service(account_id) +def blob_exists(container: str, name: str, storage_type: StorageType) -> bool: + service = get_blob_service_by_type(storage_type) try: service.get_blob_properties(container, name) return True @@ -147,8 +178,8 @@ def blob_exists(container: str, name: str, account_id: Optional[str] = None) -> return False -def delete_blob(container: str, name: str, account_id: Optional[str] = None) -> bool: - service = get_blob_service(account_id) +def delete_blob(container: str, name: str, storage_type: StorageType) -> bool: + service = get_blob_service_by_type(storage_type) try: service.delete_blob(container, name) return True diff --git a/src/api-service/__app__/onefuzzlib/azure/creds.py b/src/api-service/__app__/onefuzzlib/azure/creds.py index f72a247793..32d7cc15c3 100644 --- a/src/api-service/__app__/onefuzzlib/azure/creds.py +++ b/src/api-service/__app__/onefuzzlib/azure/creds.py @@ -87,12 +87,12 @@ def get_insights_appid() -> str: return os.environ["APPINSIGHTS_APPID"] -@cached +# @cached def get_fuzz_storage() -> str: return os.environ["ONEFUZZ_DATA_STORAGE"] -@cached +# @cached def get_func_storage() -> str: return os.environ["ONEFUZZ_FUNC_STORAGE"] @@ -109,9 +109,9 @@ def get_instance_url() -> str: @cached def get_instance_id() -> UUID: - from .containers import get_blob + from .containers import StorageType, get_blob - blob = get_blob("base-config", "instance_id", account_id=get_func_storage()) + blob = get_blob("base-config", "instance_id", StorageType.config) if blob is None: raise Exception("missing instance_id") return UUID(blob.decode()) diff --git a/src/api-service/__app__/onefuzzlib/azure/queue.py b/src/api-service/__app__/onefuzzlib/azure/queue.py index 95557b7021..e73edf35b1 100644 --- a/src/api-service/__app__/onefuzzlib/azure/queue.py +++ b/src/api-service/__app__/onefuzzlib/azure/queue.py @@ -19,6 +19,7 @@ from memoization import cached from pydantic import BaseModel +from .containers import StorageType, get_account_id_by_type from .creds import get_storage_account_name_key QueueNameType = Union[str, UUID] @@ -27,7 +28,8 @@ @cached(ttl=60) -def get_queue_client(account_id: str) -> QueueServiceClient: +def get_queue_client(storage_type: StorageType) -> QueueServiceClient: + account_id = get_account_id_by_type(storage_type) logging.debug("getting blob container (account_id: %s)", account_id) name, key = get_storage_account_name_key(account_id) account_url = "https://%s.queue.core.windows.net" % name @@ -41,13 +43,14 @@ def get_queue_client(account_id: str) -> QueueServiceClient: @cached(ttl=60) def get_queue_sas( queue: QueueNameType, + storage_type: StorageType, *, - account_id: str, read: bool = False, add: bool = False, update: bool = False, process: bool = False, ) -> str: + account_id = get_account_id_by_type(storage_type) logging.debug("getting queue sas %s (account_id: %s)", queue, account_id) name, key = get_storage_account_name_key(account_id) expiry = datetime.datetime.utcnow() + datetime.timedelta(days=30) @@ -67,31 +70,33 @@ def get_queue_sas( @cached(ttl=60) -def create_queue(name: QueueNameType, *, account_id: str) -> None: - client = get_queue_client(account_id) +def create_queue(name: QueueNameType, storage_type: StorageType) -> None: + client = get_queue_client(storage_type) try: client.create_queue(str(name)) except ResourceExistsError: pass -def delete_queue(name: QueueNameType, *, account_id: str) -> None: - client = get_queue_client(account_id) +def delete_queue(name: QueueNameType, storage_type: StorageType) -> None: + client = get_queue_client(storage_type) queues = client.list_queues() if str(name) in [x["name"] for x in queues]: client.delete_queue(name) -def get_queue(name: QueueNameType, *, account_id: str) -> Optional[QueueServiceClient]: - client = get_queue_client(account_id) +def get_queue( + name: QueueNameType, storage_type: StorageType +) -> Optional[QueueServiceClient]: + client = get_queue_client(storage_type) try: return client.get_queue_client(str(name)) except ResourceNotFoundError: return None -def clear_queue(name: QueueNameType, *, account_id: str) -> None: - queue = get_queue(name, account_id=account_id) +def clear_queue(name: QueueNameType, storage_type: StorageType) -> None: + queue = get_queue(name, storage_type) if queue: try: queue.clear_messages() @@ -102,12 +107,12 @@ def clear_queue(name: QueueNameType, *, account_id: str) -> None: def send_message( name: QueueNameType, message: bytes, + storage_type: StorageType, *, - account_id: str, visibility_timeout: Optional[int] = None, time_to_live: int = DEFAULT_TTL, ) -> None: - queue = get_queue(name, account_id=account_id) + queue = get_queue(name, storage_type) if queue: try: queue.send_message( @@ -119,9 +124,8 @@ def send_message( pass -def remove_first_message(name: QueueNameType, *, account_id: str) -> bool: - create_queue(name, account_id=account_id) - queue = get_queue(name, account_id=account_id) +def remove_first_message(name: QueueNameType, storage_type: StorageType) -> bool: + queue = get_queue(name, storage_type) if queue: try: for message in queue.receive_messages(): @@ -143,8 +147,8 @@ def remove_first_message(name: QueueNameType, *, account_id: str) -> bool: # https://docs.microsoft.com/en-us/python/api/azure-storage-queue/azure.storage.queue.queueclient def peek_queue( name: QueueNameType, + storage_type: StorageType, *, - account_id: str, object_type: Type[A], max_messages: int = MAX_PEEK_SIZE, ) -> List[A]: @@ -154,7 +158,7 @@ def peek_queue( if max_messages < MIN_PEEK_SIZE or max_messages > MAX_PEEK_SIZE: raise ValueError("invalid max messages: %s" % max_messages) - queue = get_queue(name, account_id=account_id) + queue = get_queue(name, storage_type) if not queue: return result @@ -168,12 +172,12 @@ def peek_queue( def queue_object( name: QueueNameType, message: BaseModel, + storage_type: StorageType, *, - account_id: str, visibility_timeout: Optional[int] = None, time_to_live: int = DEFAULT_TTL, ) -> bool: - queue = get_queue(name, account_id=account_id) + queue = get_queue(name, storage_type) if not queue: raise Exception("unable to queue object, no such queue: %s" % queue) diff --git a/src/api-service/__app__/onefuzzlib/extension.py b/src/api-service/__app__/onefuzzlib/extension.py index 3640c5d54a..b75eeb4bc1 100644 --- a/src/api-service/__app__/onefuzzlib/extension.py +++ b/src/api-service/__app__/onefuzzlib/extension.py @@ -11,8 +11,13 @@ from onefuzztypes.models import AgentConfig, ReproConfig from onefuzztypes.primitives import Extension, Region -from .azure.containers import get_container_sas_url, get_file_sas_url, save_blob -from .azure.creds import get_func_storage, get_instance_id, get_instance_url +from .azure.containers import ( + StorageType, + get_container_sas_url, + get_file_sas_url, + save_blob, +) +from .azure.creds import get_instance_id, get_instance_url from .azure.monitor import get_monitor_settings from .azure.queue import get_queue_sas from .reports import get_report @@ -96,7 +101,7 @@ def build_pool_config(pool_name: str) -> str: instrumentation_key=os.environ.get("APPINSIGHTS_INSTRUMENTATIONKEY"), heartbeat_queue=get_queue_sas( "node-heartbeat", - account_id=get_func_storage(), + StorageType.config, add=True, ), telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"), @@ -107,13 +112,13 @@ def build_pool_config(pool_name: str) -> str: "vm-scripts", "%s/config.json" % pool_name, config.json(), - account_id=get_func_storage(), + StorageType.config, ) return get_file_sas_url( "vm-scripts", "%s/config.json" % pool_name, - account_id=get_func_storage(), + StorageType.config, read=True, ) @@ -124,30 +129,26 @@ def update_managed_scripts() -> None: % ( get_container_sas_url( "instance-specific-setup", + StorageType.config, read=True, list=True, - account_id=get_func_storage(), ) ), "azcopy sync '%s' tools" - % ( - get_container_sas_url( - "tools", read=True, list=True, account_id=get_func_storage() - ) - ), + % (get_container_sas_url("tools", StorageType.config, read=True, list=True)), ] save_blob( "vm-scripts", "managed.ps1", "\r\n".join(commands) + "\r\n", - account_id=get_func_storage(), + StorageType.config, ) save_blob( "vm-scripts", "managed.sh", "\n".join(commands) + "\n", - account_id=get_func_storage(), + StorageType.config, ) @@ -164,25 +165,25 @@ def agent_config( get_file_sas_url( "vm-scripts", "managed.ps1", - account_id=get_func_storage(), + StorageType.config, read=True, ), get_file_sas_url( "tools", "win64/azcopy.exe", - account_id=get_func_storage(), + StorageType.config, read=True, ), get_file_sas_url( "tools", "win64/setup.ps1", - account_id=get_func_storage(), + StorageType.config, read=True, ), get_file_sas_url( "tools", "win64/onefuzz.ps1", - account_id=get_func_storage(), + StorageType.config, read=True, ), ] @@ -206,19 +207,19 @@ def agent_config( get_file_sas_url( "vm-scripts", "managed.sh", - account_id=get_func_storage(), + StorageType.config, read=True, ), get_file_sas_url( "tools", "linux/azcopy", - account_id=get_func_storage(), + StorageType.config, read=True, ), get_file_sas_url( "tools", "linux/setup.sh", - account_id=get_func_storage(), + StorageType.config, read=True, ), ] @@ -263,13 +264,22 @@ def repro_extensions( if setup_container: commands += [ "azcopy sync '%s' ./setup" - % (get_container_sas_url(setup_container, read=True, list=True)), + % ( + get_container_sas_url( + setup_container, StorageType.corpus, read=True, list=True + ) + ), ] urls = [ - get_file_sas_url(repro_config.container, repro_config.path, read=True), get_file_sas_url( - report.input_blob.container, report.input_blob.name, read=True + repro_config.container, repro_config.path, StorageType.corpus, read=True + ), + get_file_sas_url( + report.input_blob.container, + report.input_blob.name, + StorageType.corpus, + read=True, ), ] @@ -288,7 +298,7 @@ def repro_extensions( "task-configs", "%s/%s" % (repro_id, script_name), task_script, - account_id=get_func_storage(), + StorageType.config, ) for repro_file in repro_files: @@ -296,13 +306,13 @@ def repro_extensions( get_file_sas_url( "repro-scripts", repro_file, - account_id=get_func_storage(), + StorageType.config, read=True, ), get_file_sas_url( "task-configs", "%s/%s" % (repro_id, script_name), - account_id=get_func_storage(), + StorageType.config, read=True, ), ] @@ -318,13 +328,13 @@ def proxy_manager_extensions(region: Region) -> List[Extension]: get_file_sas_url( "proxy-configs", "%s/config.json" % region, - account_id=get_func_storage(), + StorageType.config, read=True, ), get_file_sas_url( "tools", "linux/onefuzz-proxy-manager", - account_id=get_func_storage(), + StorageType.config, read=True, ), ] diff --git a/src/api-service/__app__/onefuzzlib/notifications/main.py b/src/api-service/__app__/onefuzzlib/notifications/main.py index 8564012dfe..fdcbc0c2b3 100644 --- a/src/api-service/__app__/onefuzzlib/notifications/main.py +++ b/src/api-service/__app__/onefuzzlib/notifications/main.py @@ -21,11 +21,11 @@ from onefuzztypes.primitives import Container, Event from ..azure.containers import ( + StorageType, container_exists, get_container_metadata, get_file_sas_url, ) -from ..azure.creds import get_fuzz_storage from ..azure.queue import send_message from ..dashboard import add_event from ..orm import ORMMixin @@ -72,7 +72,7 @@ def key_fields(cls) -> Tuple[str, str]: def create( cls, container: Container, config: NotificationTemplate ) -> Result["Notification"]: - if not container_exists(container): + if not container_exists(container, StorageType.corpus): return Error(code=ErrorCode.INVALID_REQUEST, errors=["invalid container"]) existing = cls.get_existing(container, config) @@ -106,7 +106,7 @@ def get_queue_tasks() -> Sequence[Tuple[Task, Sequence[str]]]: @cached(ttl=60) def container_metadata(container: Container) -> Optional[Dict[str, str]]: - return get_container_metadata(container) + return get_container_metadata(container, StorageType.corpus) def new_files(container: Container, filename: str) -> None: @@ -149,9 +149,9 @@ def new_files(container: Container, filename: str) -> None: for (task, containers) in get_queue_tasks(): if container in containers: logging.info("queuing input %s %s %s", container, filename, task.task_id) - url = get_file_sas_url(container, filename, read=True, delete=True) - send_message( - task.task_id, bytes(url, "utf-8"), account_id=get_fuzz_storage() + url = get_file_sas_url( + container, filename, StorageType.corpus, read=True, delete=True ) + send_message(task.task_id, bytes(url, "utf-8"), StorageType.corpus) add_event("new_file", results) diff --git a/src/api-service/__app__/onefuzzlib/pools.py b/src/api-service/__app__/onefuzzlib/pools.py index c0b382c3c1..cc36636f4b 100644 --- a/src/api-service/__app__/onefuzzlib/pools.py +++ b/src/api-service/__app__/onefuzzlib/pools.py @@ -35,7 +35,7 @@ from .__version__ import __version__ from .azure.auth import build_auth -from .azure.creds import get_func_storage, get_fuzz_storage +from .azure.containers import StorageType from .azure.image import get_os from .azure.network import Network from .azure.queue import ( @@ -442,7 +442,7 @@ def populate_work_queue(self) -> None: return worksets = peek_queue( - self.get_pool_queue(), account_id=get_fuzz_storage(), object_type=WorkSet + self.get_pool_queue(), StorageType.corpus, object_type=WorkSet ) for workset in worksets: @@ -460,7 +460,7 @@ def get_pool_queue(self) -> str: return "pool-%s" % self.pool_id.hex def init(self) -> None: - create_queue(self.get_pool_queue(), account_id=get_fuzz_storage()) + create_queue(self.get_pool_queue(), StorageType.corpus) self.state = PoolState.running self.save() @@ -470,7 +470,9 @@ def schedule_workset(self, work_set: WorkSet) -> bool: return False return queue_object( - self.get_pool_queue(), work_set, account_id=get_fuzz_storage() + self.get_pool_queue(), + work_set, + StorageType.corpus, ) @classmethod @@ -531,7 +533,7 @@ def halt(self) -> None: scalesets = Scaleset.search_by_pool(self.name) nodes = Node.search(query={"pool_name": [self.name]}) if not scalesets and not nodes: - delete_queue(self.get_pool_queue(), account_id=get_fuzz_storage()) + delete_queue(self.get_pool_queue(), StorageType.corpus) logging.info("pool stopped, deleting: %s", self.name) self.state = PoolState.halt self.delete() @@ -1053,16 +1055,16 @@ def queue_name(self) -> str: return "to-shrink-%s" % self.scaleset_id.hex def clear(self) -> None: - clear_queue(self.queue_name(), account_id=get_func_storage()) + clear_queue(self.queue_name(), StorageType.config) def create(self) -> None: - create_queue(self.queue_name(), account_id=get_func_storage()) + create_queue(self.queue_name(), StorageType.config) def delete(self) -> None: - delete_queue(self.queue_name(), account_id=get_func_storage()) + delete_queue(self.queue_name(), StorageType.config) def add_entry(self) -> None: - queue_object(self.queue_name(), ShrinkEntry(), account_id=get_func_storage()) + queue_object(self.queue_name(), ShrinkEntry(), StorageType.config) def should_shrink(self) -> bool: - return remove_first_message(self.queue_name(), account_id=get_func_storage()) + return remove_first_message(self.queue_name(), StorageType.config) diff --git a/src/api-service/__app__/onefuzzlib/proxy.py b/src/api-service/__app__/onefuzzlib/proxy.py index 1816297cbb..64ddae8f5f 100644 --- a/src/api-service/__app__/onefuzzlib/proxy.py +++ b/src/api-service/__app__/onefuzzlib/proxy.py @@ -21,8 +21,7 @@ from .__version__ import __version__ from .azure.auth import build_auth -from .azure.containers import get_file_sas_url, save_blob -from .azure.creds import get_func_storage +from .azure.containers import StorageType, get_file_sas_url, save_blob from .azure.ip import get_public_ip from .azure.queue import get_queue_sas from .azure.vm import VM @@ -191,12 +190,12 @@ def save_proxy_config(self) -> None: url=get_file_sas_url( "proxy-configs", "%s/config.json" % self.region, - account_id=get_func_storage(), + StorageType.config, read=True, ), notification=get_queue_sas( "proxy", - account_id=get_func_storage(), + StorageType.config, add=True, ), forwards=forwards, @@ -207,7 +206,7 @@ def save_proxy_config(self) -> None: "proxy-configs", "%s/config.json" % self.region, proxy_config.json(), - account_id=get_func_storage(), + StorageType.config, ) @classmethod diff --git a/src/api-service/__app__/onefuzzlib/reports.py b/src/api-service/__app__/onefuzzlib/reports.py index a789275c4b..7255b25914 100644 --- a/src/api-service/__app__/onefuzzlib/reports.py +++ b/src/api-service/__app__/onefuzzlib/reports.py @@ -10,7 +10,7 @@ from onefuzztypes.models import Report from pydantic import ValidationError -from .azure.containers import get_blob +from .azure.containers import StorageType, get_blob def parse_report( @@ -50,7 +50,7 @@ def get_report(container: str, filename: str) -> Optional[Report]: logging.error("get_report invalid extension: %s", metadata) return None - blob = get_blob(container, filename) + blob = get_blob(container, filename, StorageType.corpus) if blob is None: logging.error("get_report invalid blob: %s", metadata) return None diff --git a/src/api-service/__app__/onefuzzlib/repro.py b/src/api-service/__app__/onefuzzlib/repro.py index f8df33db28..435aa0dcb8 100644 --- a/src/api-service/__app__/onefuzzlib/repro.py +++ b/src/api-service/__app__/onefuzzlib/repro.py @@ -14,8 +14,8 @@ from onefuzztypes.models import ReproConfig, TaskVm from .azure.auth import build_auth -from .azure.containers import save_blob -from .azure.creds import get_base_region, get_func_storage +from .azure.containers import StorageType, save_blob +from .azure.creds import get_base_region from .azure.ip import get_public_ip from .azure.vm import VM from .extension import repro_extensions @@ -205,7 +205,7 @@ def build_repro_script(self) -> Optional[Error]: "repro-scripts", "%s/%s" % (self.vm_id, filename), files[filename], - account_id=get_func_storage(), + StorageType.config, ) logging.info("saved repro script") diff --git a/src/api-service/__app__/onefuzzlib/tasks/config.py b/src/api-service/__app__/onefuzzlib/tasks/config.py index 3978096f43..78fb1aa3a8 100644 --- a/src/api-service/__app__/onefuzzlib/tasks/config.py +++ b/src/api-service/__app__/onefuzzlib/tasks/config.py @@ -11,13 +11,13 @@ from onefuzztypes.enums import Compare, ContainerPermission, ContainerType, TaskFeature from onefuzztypes.models import TaskConfig, TaskDefinition, TaskUnitConfig -from ..azure.containers import blob_exists, container_exists, get_container_sas_url -from ..azure.creds import ( - get_func_storage, - get_fuzz_storage, - get_instance_id, - get_instance_url, +from ..azure.containers import ( + StorageType, + blob_exists, + container_exists, + get_container_sas_url, ) +from ..azure.creds import get_instance_id, get_instance_url from ..azure.queue import get_queue_sas from .defs import TASK_DEFINITIONS @@ -68,7 +68,7 @@ def check_containers(definition: TaskDefinition, config: TaskConfig) -> None: containers: Dict[ContainerType, List[str]] = {} for container in config.containers: if container.name not in checked: - if not container_exists(container.name): + if not container_exists(container.name, StorageType.corpus): raise TaskConfigError("missing container: %s" % container.name) checked.add(container.name) @@ -137,7 +137,7 @@ def check_config(config: TaskConfig) -> None: if TaskFeature.target_exe in definition.features: container = [x for x in config.containers if x.type == ContainerType.setup][0] - if not blob_exists(container.name, config.task.target_exe): + if not blob_exists(container.name, config.task.target_exe, StorageType.corpus): err = "target_exe `%s` does not exist in the setup container `%s`" % ( config.task.target_exe, container.name, @@ -153,7 +153,7 @@ def check_config(config: TaskConfig) -> None: for tool_path in tools_paths: if config.task.generator_exe.startswith(tool_path): generator = config.task.generator_exe.replace(tool_path, "") - if not blob_exists(container.name, generator): + if not blob_exists(container.name, generator, StorageType.corpus): err = ( "generator_exe `%s` does not exist in the tools container `%s`" % ( @@ -188,7 +188,7 @@ def build_task_config( telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"), heartbeat_queue=get_queue_sas( "task-heartbeat", - account_id=get_func_storage(), + StorageType.config, add=True, ), back_channel_address="https://%s/api/back_channel" % (get_instance_url()), @@ -198,11 +198,11 @@ def build_task_config( if definition.monitor_queue: config.input_queue = get_queue_sas( task_id, + StorageType.corpus, add=True, read=True, update=True, process=True, - account_id=get_fuzz_storage(), ) for container_def in definition.containers: @@ -219,6 +219,7 @@ def build_task_config( "path": "_".join(["task", container_def.type.name, str(i)]), "url": get_container_sas_url( container.name, + StorageType.corpus, read=ContainerPermission.Read in container_def.permissions, write=ContainerPermission.Write in container_def.permissions, add=ContainerPermission.Add in container_def.permissions, diff --git a/src/api-service/__app__/onefuzzlib/tasks/main.py b/src/api-service/__app__/onefuzzlib/tasks/main.py index 42199f3509..570fe8bf78 100644 --- a/src/api-service/__app__/onefuzzlib/tasks/main.py +++ b/src/api-service/__app__/onefuzzlib/tasks/main.py @@ -18,7 +18,7 @@ WebhookEventTaskStopped, ) -from ..azure.creds import get_fuzz_storage +from ..azure.containers import StorageType from ..azure.image import get_os from ..azure.queue import create_queue, delete_queue from ..orm import MappingIntStrAny, ORMMixin, QueryFilter @@ -123,7 +123,7 @@ def event_include(self) -> Optional[MappingIntStrAny]: } def init(self) -> None: - create_queue(self.task_id, account_id=get_fuzz_storage()) + create_queue(self.task_id, StorageType.corpus) self.state = TaskState.waiting self.save() @@ -132,7 +132,7 @@ def stopping(self) -> None: logging.info("stopping task: %s:%s", self.job_id, self.task_id) ProxyForward.remove_forward(self.task_id) - delete_queue(str(self.task_id), account_id=get_fuzz_storage()) + delete_queue(str(self.task_id), StorageType.corpus) Node.stop_task(self.task_id) self.state = TaskState.stopped self.save() diff --git a/src/api-service/__app__/onefuzzlib/tasks/scheduler.py b/src/api-service/__app__/onefuzzlib/tasks/scheduler.py index b3203d3dd5..0a34f78c93 100644 --- a/src/api-service/__app__/onefuzzlib/tasks/scheduler.py +++ b/src/api-service/__app__/onefuzzlib/tasks/scheduler.py @@ -10,8 +10,12 @@ from onefuzztypes.enums import OS, PoolState, TaskState from onefuzztypes.models import WorkSet, WorkUnit -from ..azure.containers import blob_exists, get_container_sas_url, save_blob -from ..azure.creds import get_func_storage +from ..azure.containers import ( + StorageType, + blob_exists, + get_container_sas_url, + save_blob, +) from ..pools import Pool from .config import build_task_config, get_setup_container from .main import Task @@ -60,20 +64,26 @@ def schedule_tasks() -> None: agent_config = build_task_config(task.job_id, task.task_id, task.config) setup_container = get_setup_container(task.config) - setup_url = get_container_sas_url(setup_container, read=True, list=True) + setup_url = get_container_sas_url( + setup_container, StorageType.corpus, read=True, list=True + ) setup_script = None - if task.os == OS.windows and blob_exists(setup_container, "setup.ps1"): + if task.os == OS.windows and blob_exists( + setup_container, "setup.ps1", StorageType.corpus + ): setup_script = "setup.ps1" - if task.os == OS.linux and blob_exists(setup_container, "setup.sh"): + if task.os == OS.linux and blob_exists( + setup_container, "setup.sh", StorageType.corpus + ): setup_script = "setup.sh" save_blob( "task-configs", "%s/config.json" % task.task_id, agent_config.json(exclude_none=True), - account_id=get_func_storage(), + StorageType.config, ) reboot = False count = 1 diff --git a/src/api-service/__app__/onefuzzlib/updates.py b/src/api-service/__app__/onefuzzlib/updates.py index 2d0233e945..ab74872c34 100644 --- a/src/api-service/__app__/onefuzzlib/updates.py +++ b/src/api-service/__app__/onefuzzlib/updates.py @@ -10,7 +10,7 @@ from onefuzztypes.enums import UpdateType from pydantic import BaseModel -from .azure.creds import get_func_storage +from .azure.containers import StorageType from .azure.queue import queue_object @@ -46,7 +46,7 @@ def queue_update( if not queue_object( "update-queue", update, - account_id=get_func_storage(), + StorageType.config, visibility_timeout=visibility_timeout, ): logging.error("unable to queue update") diff --git a/src/api-service/__app__/onefuzzlib/webhooks.py b/src/api-service/__app__/onefuzzlib/webhooks.py index 12f57cc596..3b7102b44b 100644 --- a/src/api-service/__app__/onefuzzlib/webhooks.py +++ b/src/api-service/__app__/onefuzzlib/webhooks.py @@ -27,7 +27,7 @@ from pydantic import BaseModel from .__version__ import __version__ -from .azure.creds import get_func_storage +from .azure.containers import StorageType from .azure.queue import queue_object from .orm import ORMMixin @@ -135,8 +135,8 @@ def queue_webhook(self) -> None: queue_object( "webhooks", obj, + StorageType.config, visibility_timeout=visibility_timeout, - account_id=get_func_storage(), ) diff --git a/src/api-service/__app__/pool/__init__.py b/src/api-service/__app__/pool/__init__.py index 15b8722c90..64f4bbc2d9 100644 --- a/src/api-service/__app__/pool/__init__.py +++ b/src/api-service/__app__/pool/__init__.py @@ -12,9 +12,9 @@ from onefuzztypes.requests import PoolCreate, PoolSearch, PoolStop from onefuzztypes.responses import BoolResult +from ..onefuzzlib.azure.containers import StorageType from ..onefuzzlib.azure.creds import ( get_base_region, - get_func_storage, get_instance_id, get_instance_url, get_regions, @@ -33,7 +33,7 @@ def set_config(pool: Pool) -> Pool: telemetry_key=os.environ.get("ONEFUZZ_TELEMETRY"), heartbeat_queue=get_queue_sas( "node-heartbeat", - account_id=get_func_storage(), + StorageType.config, add=True, ), instance_id=get_instance_id(),