From c95f2d1424688fad7ac9ecd333fd991befdeca37 Mon Sep 17 00:00:00 2001 From: iris <84595986+iris-garden@users.noreply.github.com> Date: Wed, 13 Sep 2023 13:34:16 -0400 Subject: [PATCH] [batch] restricts operations to GCS buckets using hot storage (#13200) Closes https://github.com/hail-is/hail/issues/13003. --- hail/.gitignore | 2 + hail/python/hail/backend/backend.py | 6 +- hail/python/hail/backend/local_backend.py | 7 +- hail/python/hail/backend/service_backend.py | 22 +++-- hail/python/hail/backend/spark_backend.py | 12 ++- hail/python/hail/context.py | 22 +++-- .../hail/docs/configuration_reference.rst | 53 +++++++++++ hail/python/hail/docs/index.rst | 1 + hail/python/hail/linalg/blockmatrix.py | 10 +- hail/python/hail/matrixtable.py | 4 +- hail/python/hail/methods/impex.py | 8 +- hail/python/hail/table.py | 8 +- .../aiogoogle/client/storage_client.py | 38 +++++++- hail/python/hailtop/aiotools/router_fs.py | 22 +++-- hail/python/hailtop/aiotools/validators.py | 51 ++++++++++ hail/python/hailtop/batch/backend.py | 85 +++++++++++------ hail/python/hailtop/batch/batch.py | 28 +++--- .../docs/_templates/_autosummary/class.rst | 4 +- hail/python/hailtop/batch/docs/conf.py | 6 +- .../batch/docs/configuration_reference.rst | 6 ++ hail/python/hailtop/batch/docs/index.rst | 1 + hail/python/hailtop/config/variables.py | 1 + .../hailctl/config/config_variables.py | 10 ++ .../test/hail/fs/test_worker_driver_fs.py | 8 +- hail/python/test/hailtop/aiotools/__init__.py | 0 hail/python/test/hailtop/batch/test_batch.py | 92 +++++++++++++++++-- 26 files changed, 400 insertions(+), 107 deletions(-) create mode 100644 hail/python/hail/docs/configuration_reference.rst create mode 100644 hail/python/hailtop/aiotools/validators.py create mode 100644 hail/python/hailtop/batch/docs/configuration_reference.rst delete mode 100644 hail/python/test/hailtop/aiotools/__init__.py diff --git a/hail/.gitignore b/hail/.gitignore index 4752314acb7..7c5c19cdbe2 100644 --- a/hail/.gitignore +++ b/hail/.gitignore @@ -52,12 +52,14 @@ python/hail/docs/vds/hail.vds.filter_chromosomes.rst python/hail/docs/vds/hail.vds.filter_intervals.rst python/hail/docs/vds/hail.vds.filter_samples.rst python/hail/docs/vds/hail.vds.filter_variants.rst +python/hail/docs/vds/hail.vds.impute_sex_chr_ploidy_from_interval_coverage.rst python/hail/docs/vds/hail.vds.impute_sex_chromosome_ploidy.rst python/hail/docs/vds/hail.vds.interval_coverage.rst python/hail/docs/vds/hail.vds.lgt_to_gt.rst python/hail/docs/vds/hail.vds.read_vds.rst python/hail/docs/vds/hail.vds.sample_qc.rst python/hail/docs/vds/hail.vds.split_multi.rst +python/hail/docs/vds/hail.vds.store_ref_block_max_length.rst python/hail/docs/vds/hail.vds.to_dense_mt.rst python/hail/docs/vds/hail.vds.to_merged_sparse_mt.rst python/hail/docs/vds/hail.vds.local_to_global.rst diff --git a/hail/python/hail/backend/backend.py b/hail/python/hail/backend/backend.py index d193b2fb877..eef2a3d5bb6 100644 --- a/hail/python/hail/backend/backend.py +++ b/hail/python/hail/backend/backend.py @@ -69,11 +69,11 @@ def __init__(self): self._references = {} @abc.abstractmethod - def stop(self): - pass + def validate_file(self, uri: str): + raise NotImplementedError @abc.abstractmethod - def validate_file_scheme(self, url): + def stop(self): pass @abc.abstractmethod diff --git a/hail/python/hail/backend/local_backend.py b/hail/python/hail/backend/local_backend.py index cd1bfdd655d..ec838b7d520 100644 --- a/hail/python/hail/backend/local_backend.py +++ b/hail/python/hail/backend/local_backend.py @@ -19,6 +19,7 @@ 
from hailtop.utils import find_spark_home from hailtop.fs.router_fs import RouterFS +from hailtop.aiotools.validators import validate_file _installed = False @@ -181,6 +182,9 @@ def __init__(self, tmpdir, log, quiet, append, branching_factor, self._initialize_flags({}) + def validate_file(self, uri: str) -> None: + validate_file(uri, self._fs.afs) + def jvm(self): return self._jvm @@ -213,9 +217,6 @@ def register_ir_function(self, def _is_registered_ir_function_name(self, name: str) -> bool: return name in self._registered_ir_function_names - def validate_file_scheme(self, url): - pass - def stop(self): self._jhc.stop() self._jhc = None diff --git a/hail/python/hail/backend/service_backend.py b/hail/python/hail/backend/service_backend.py index 4478e2efe85..975194add34 100644 --- a/hail/python/hail/backend/service_backend.py +++ b/hail/python/hail/backend/service_backend.py @@ -36,6 +36,7 @@ from ..builtin_references import BUILTIN_REFERENCES from ..ir import BaseIR from ..utils import ANY_REGION +from hailtop.aiotools.validators import validate_file ReferenceGenomeConfig = Dict[str, Any] @@ -205,7 +206,8 @@ async def create(*, name_prefix: Optional[str] = None, token: Optional[str] = None, regions: Optional[List[str]] = None, - gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None): + gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, + gcs_bucket_allow_list: Optional[List[str]] = None): billing_project = configuration_of(ConfigVariable.BATCH_BILLING_PROJECT, billing_project, None) if billing_project is None: raise ValueError( @@ -216,7 +218,10 @@ async def create(*, gcs_requester_pays_configuration = get_gcs_requester_pays_configuration( gcs_requester_pays_configuration=gcs_requester_pays_configuration, ) - async_fs = RouterAsyncFS(gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration}) + async_fs = RouterAsyncFS( + gcs_kwargs={'gcs_requester_pays_configuration': gcs_requester_pays_configuration}, + gcs_bucket_allow_list=gcs_bucket_allow_list + ) sync_fs = RouterFS(async_fs) if batch_client is None: batch_client = await aiohb.BatchClient.create(billing_project, _token=token) @@ -279,7 +284,7 @@ async def create(*, worker_cores=worker_cores, worker_memory=worker_memory, name_prefix=name_prefix or '', - regions=regions, + regions=regions ) sb._initialize_flags(flags) return sb @@ -321,6 +326,9 @@ def __init__(self, self.name_prefix = name_prefix self.regions = regions + def validate_file(self, uri: str) -> None: + validate_file(uri, self._async_fs, validate_scheme=True) + def debug_info(self) -> Dict[str, Any]: return { 'jar_spec': str(self.jar_spec), @@ -343,12 +351,6 @@ def fs(self) -> FS: def logger(self): return log - def validate_file_scheme(self, url): - assert isinstance(self._async_fs, RouterAsyncFS) - if self._async_fs.get_scheme(url) == 'file': - raise ValueError( - f'Found local filepath {url} when using Query on Batch. 
Specify a remote filepath instead.') - def stop(self): async_to_blocking(self._async_fs.close()) async_to_blocking(self.async_bc.close()) @@ -663,7 +665,7 @@ async def inputs(infile, _): def add_sequence(self, name, fasta_file, index_file): # pylint: disable=unused-argument # FIXME Not only should this be in the cloud, it should be in the *right* cloud for blob in (fasta_file, index_file): - self.validate_file_scheme(blob) + self.validate_file(blob) def remove_sequence(self, name): # pylint: disable=unused-argument pass diff --git a/hail/python/hail/backend/spark_backend.py b/hail/python/hail/backend/spark_backend.py index 7a1e063d13a..7d93637ce3a 100644 --- a/hail/python/hail/backend/spark_backend.py +++ b/hail/python/hail/backend/spark_backend.py @@ -18,6 +18,8 @@ from hail.ir.renderer import CSERenderer from hail.table import Table from hail.matrixtable import MatrixTable +from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.aiotools.validators import validate_file from .py4j_backend import Py4JBackend, handle_java_exception from ..hail_logging import Logger @@ -242,6 +244,13 @@ def __init__(self, idempotent, sc, spark_conf, app_name, master, self._initialize_flags({}) + self._router_async_fs = RouterAsyncFS( + gcs_kwargs={"gcs_requester_pays_configuration": gcs_requester_pays_project} + ) + + def validate_file(self, uri: str) -> None: + validate_file(uri, self._router_async_fs) + def jvm(self): return self._jvm @@ -251,9 +260,6 @@ def hail_package(self): def utils_package_object(self): return self._utils_package_object - def validate_file_scheme(self, url): - pass - def stop(self): self._jbackend.close() self._jhc.stop() diff --git a/hail/python/hail/context.py b/hail/python/hail/context.py index 52a9f150310..c1a72c0fedc 100644 --- a/hail/python/hail/context.py +++ b/hail/python/hail/context.py @@ -170,7 +170,8 @@ def stop(self): worker_cores=nullable(oneof(str, int)), worker_memory=nullable(str), gcs_requester_pays_configuration=nullable(oneof(str, sized_tupleof(str, sequenceof(str)))), - regions=nullable(sequenceof(str))) + regions=nullable(sequenceof(str)), + gcs_bucket_allow_list=nullable(sequenceof(str))) def init(sc=None, app_name=None, master=None, @@ -195,7 +196,8 @@ def init(sc=None, worker_cores=None, worker_memory=None, gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, - regions: Optional[List[str]] = None): + regions: Optional[List[str]] = None, + gcs_bucket_allow_list: Optional[List[str]] = None): """Initialize and configure Hail. This function will be called with default arguments if any Hail functionality is used. If you @@ -311,7 +313,9 @@ def init(sc=None, List of regions to run jobs in when using the Batch backend. Use :data:`.ANY_REGION` to specify any region is allowed or use `None` to use the underlying default regions from the hailctl environment configuration. For example, use `hailctl config set batch/regions region1,region2` to set the default regions to use. - + gcs_bucket_allow_list: + A list of buckets that Hail should be permitted to read from or write to, even if their default policy is to + use "cold" storage. Should look like ``["bucket1", "bucket2"]``. 
""" if Env._hc: if idempotent: @@ -347,7 +351,8 @@ def init(sc=None, worker_memory=worker_memory, name_prefix=app_name, gcs_requester_pays_configuration=gcs_requester_pays_configuration, - regions=regions + regions=regions, + gcs_bucket_allow_list=gcs_bucket_allow_list )) if backend == 'spark': return init_spark( @@ -469,7 +474,8 @@ def init_spark(sc=None, name_prefix=nullable(str), token=nullable(str), gcs_requester_pays_configuration=nullable(oneof(str, sized_tupleof(str, sequenceof(str)))), - regions=nullable(sequenceof(str)) + regions=nullable(sequenceof(str)), + gcs_bucket_allow_list=nullable(sequenceof(str)) ) async def init_batch( *, @@ -492,6 +498,7 @@ async def init_batch( token: Optional[str] = None, gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, regions: Optional[List[str]] = None, + gcs_bucket_allow_list: Optional[List[str]] = None ): from hail.backend.service_backend import ServiceBackend # FIXME: pass local_tmpdir and use on worker and driver @@ -506,7 +513,8 @@ async def init_batch( name_prefix=name_prefix, token=token, regions=regions, - gcs_requester_pays_configuration=gcs_requester_pays_configuration) + gcs_requester_pays_configuration=gcs_requester_pays_configuration, + gcs_bucket_allow_list=gcs_bucket_allow_list) log = _get_log(log) if tmpdir is None: diff --git a/hail/python/hail/docs/configuration_reference.rst b/hail/python/hail/docs/configuration_reference.rst new file mode 100644 index 00000000000..2beba42d278 --- /dev/null +++ b/hail/python/hail/docs/configuration_reference.rst @@ -0,0 +1,53 @@ +.. role:: python(code) + :language: python + :class: highlight + +.. role:: bash(code) + :language: bash + :class: highlight + +.. _sec-configuration-reference: + +Configuration Reference +======================= + +Configuration variables can be set for Hail Query by: + +#. passing them as keyword arguments to :func:`.init`, +#. running a command of the form :bash:`hailctl config set ` from the command line, or +#. setting them as shell environment variables by running a command of the form + :bash:`export =` in a terminal, which will set the variable for the current terminal + session. + +Each method for setting configuration variables listed above overrides variables set by any and all methods below it. +For example, setting a configuration variable by passing it to :func:`.init` will override any values set for the +variable using either :bash:`hailctl` or shell environment variables. + +.. warning:: + Some environment variables are shared between Hail Query and Hail Batch. Setting one of these variables via + :func:`.init`, :bash:`hailctl`, or environment variables will affect both Query and Batch. However, when + instantiating a class specific to one of the two, passing configuration to that class will not affect the other. + For example, if one value for :python:`gcs_bucket_allow_list` is passed to :func:`.init`, a different value + may be passed to the constructor for Batch's :python:`ServiceBackend`, which will only affect that instance of the + class (which can only be used within Batch), and won't affect Query. + +Supported Configuration Variables +--------------------------------- + +.. 
list-table:: GCS Bucket Allowlist + :widths: 50 50 + + * - Keyword Argument Name + - :python:`gcs_bucket_allow_list` + * - Keyword Argument Format + - :python:`["bucket1", "bucket2"]` + * - :bash:`hailctl` Variable Name + - :bash:`gcs/bucket_allow_list` + * - Environment Variable Name + - :bash:`HAIL_GCS_BUCKET_ALLOW_LIST` + * - :bash:`hailctl` and Environment Variable Format + - :bash:`bucket1,bucket2` + * - Effect + - Prevents Hail Query from erroring if the default storage policy for any of the given locations is to use cold storage. + * - Shared between Query and Batch + - Yes diff --git a/hail/python/hail/docs/index.rst b/hail/python/hail/docs/index.rst index 6275a362de7..ee85d26b2ac 100644 --- a/hail/python/hail/docs/index.rst +++ b/hail/python/hail/docs/index.rst @@ -20,6 +20,7 @@ Contents Hail on the Cloud Tutorials Reference (Python API) + Configuration Reference Overview How-To Guides Cheatsheets diff --git a/hail/python/hail/linalg/blockmatrix.py b/hail/python/hail/linalg/blockmatrix.py index 70f959ffe90..034e54cb07e 100644 --- a/hail/python/hail/linalg/blockmatrix.py +++ b/hail/python/hail/linalg/blockmatrix.py @@ -619,7 +619,7 @@ def write(self, path, overwrite=False, force_row_major=False, stage_locally=Fals If ``True``, major output will be written to temporary local storage before being copied to ``output``. """ - hl.current_backend().validate_file_scheme(path) + hl.current_backend().validate_file(path) writer = BlockMatrixNativeWriter(path, overwrite, force_row_major, stage_locally) Env.backend().execute(BlockMatrixWrite(self._bmir, writer)) @@ -647,7 +647,7 @@ def checkpoint(self, path, overwrite=False, force_row_major=False, stage_locally If ``True``, major output will be written to temporary local storage before being copied to ``output``. """ - hl.current_backend().validate_file_scheme(path) + hl.current_backend().validate_file(path) self.write(path, overwrite, force_row_major, stage_locally) return BlockMatrix.read(path, _assert_type=self._bmir._type) @@ -729,7 +729,7 @@ def write_from_entry_expr(entry_expr, path, overwrite=False, mean_impute=False, block_size: :obj:`int`, optional Block size. Default given by :meth:`.BlockMatrix.default_block_size`. """ - hl.current_backend().validate_file_scheme(path) + hl.current_backend().validate_file(path) if not block_size: block_size = BlockMatrix.default_block_size() @@ -1193,7 +1193,7 @@ def tofile(self, uri): -------- :meth:`.to_numpy` """ - hl.current_backend().validate_file_scheme(uri) + hl.current_backend().validate_file(uri) _check_entries_size(self.n_rows, self.n_cols) @@ -1975,7 +1975,7 @@ def export(path_in, path_out, delimiter='\t', header=None, add_index=False, para Describes which entries to export. One of: ``'full'``, ``'lower'``, ``'strict_lower'``, ``'upper'``, ``'strict_upper'``. 
""" - hl.current_backend().validate_file_scheme(path_out) + hl.current_backend().validate_file(path_out) export_type = ExportType.default(parallel) diff --git a/hail/python/hail/matrixtable.py b/hail/python/hail/matrixtable.py index e2f027f7938..a3929b22ee7 100644 --- a/hail/python/hail/matrixtable.py +++ b/hail/python/hail/matrixtable.py @@ -2677,7 +2677,7 @@ def checkpoint(self, output: str, overwrite: bool = False, stage_locally: bool = -------- >>> dataset = dataset.checkpoint('output/dataset_checkpoint.mt') """ - hl.current_backend().validate_file_scheme(output) + hl.current_backend().validate_file(output) if not _read_if_exists or not hl.hadoop_exists(f'{output}/_SUCCESS'): self.write(output=output, overwrite=overwrite, stage_locally=stage_locally, _codec_spec=_codec_spec) @@ -2727,7 +2727,7 @@ def write(self, output: str, overwrite: bool = False, stage_locally: bool = Fals If ``True``, overwrite an existing file at the destination. """ - hl.current_backend().validate_file_scheme(output) + hl.current_backend().validate_file(output) if _partitions is not None: _partitions, _partitions_type = hl.utils._dumps_partitions(_partitions, self.row_key.dtype) diff --git a/hail/python/hail/methods/impex.py b/hail/python/hail/methods/impex.py index 532b8e4ff23..94a59585464 100644 --- a/hail/python/hail/methods/impex.py +++ b/hail/python/hail/methods/impex.py @@ -121,7 +121,7 @@ def export_gen(dataset, output, precision=4, gp=None, id1=None, id2=None, require_biallelic(dataset, 'export_gen') - hl.current_backend().validate_file_scheme(output) + hl.current_backend().validate_file(output) if gp is None: if 'GP' in dataset.entry and dataset.GP.dtype == tarray(tfloat64): @@ -238,7 +238,7 @@ def export_bgen(mt, output, gp=None, varid=None, rsid=None, parallel=None, compr require_row_key_variant(mt, 'export_bgen') require_col_key_str(mt, 'export_bgen') - hl.current_backend().validate_file_scheme(output) + hl.current_backend().validate_file(output) if gp is None: if 'GP' in mt.entry and mt.GP.dtype == tarray(tfloat64): @@ -364,7 +364,7 @@ def export_plink(dataset, output, call=None, fam_id=None, ind_id=None, pat_id=No require_biallelic(dataset, 'export_plink', tolerate_generic_locus=True) - hl.current_backend().validate_file_scheme(output) + hl.current_backend().validate_file(output) if ind_id is None: require_col_key_str(dataset, "export_plink") @@ -539,7 +539,7 @@ def export_vcf(dataset, output, append_to_header=None, parallel=None, metadata=N **Note**: This feature is experimental, and the interface and defaults may change in future versions. """ - hl.current_backend().validate_file_scheme(output) + hl.current_backend().validate_file(output) _, ext = os.path.splitext(output) if ext == '.gz': diff --git a/hail/python/hail/table.py b/hail/python/hail/table.py index 2865491aeed..8bd9b689af5 100644 --- a/hail/python/hail/table.py +++ b/hail/python/hail/table.py @@ -1146,7 +1146,7 @@ def export(self, output, types_file=None, header=True, parallel=None, delimiter= delimiter : :class:`str` Field delimiter. 
""" - hl.current_backend().validate_file_scheme(output) + hl.current_backend().validate_file(output) parallel = ir.ExportType.default(parallel) Env.backend().execute( @@ -1325,7 +1325,7 @@ def checkpoint(self, output: str, overwrite: bool = False, stage_locally: bool = >>> table1 = table1.checkpoint('output/table_checkpoint.ht') """ - hl.current_backend().validate_file_scheme(output) + hl.current_backend().validate_file(output) if not _read_if_exists or not hl.hadoop_exists(f'{output}/_SUCCESS'): self.write(output=output, overwrite=overwrite, stage_locally=stage_locally, _codec_spec=_codec_spec) @@ -1372,7 +1372,7 @@ def write(self, output: str, overwrite=False, stage_locally: bool = False, If ``True``, overwrite an existing file at the destination. """ - hl.current_backend().validate_file_scheme(output) + hl.current_backend().validate_file(output) Env.backend().execute(ir.TableWrite(self._tir, ir.TableNativeWriter(output, overwrite, stage_locally, _codec_spec))) @@ -1500,7 +1500,7 @@ def write_many(self, If ``True``, overwrite an existing file at the destination. """ - hl.current_backend().validate_file_scheme(output) + hl.current_backend().validate_file(output) Env.backend().execute( ir.TableWrite( diff --git a/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py b/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py index 358ddff71b2..1c2e4b0f0d9 100644 --- a/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py +++ b/hail/python/hailtop/aiocloud/aiogoogle/client/storage_client.py @@ -307,10 +307,18 @@ def __init__(self, gcs_requester_pays_configuration: Optional[GCSRequesterPaysCo # Around May 2022, GCS started timing out a lot with our default 5s timeout kwargs['timeout'] = aiohttp.ClientTimeout(total=20) super().__init__('https://storage.googleapis.com/storage/v1', **kwargs) - gcs_requester_pays_configuration = get_gcs_requester_pays_configuration( - gcs_requester_pays_configuration=gcs_requester_pays_configuration, + self._gcs_requester_pays_configuration = get_gcs_requester_pays_configuration( + gcs_requester_pays_configuration=gcs_requester_pays_configuration ) - self._gcs_requester_pays_configuration = gcs_requester_pays_configuration + + async def bucket_info(self, bucket: str) -> Dict[str, Any]: + """ + See `the GCS API docs https://cloud.google.com/storage/docs/json_api/v1/buckets`_ for the list of bucket + properties in the response. + """ + kwargs: Dict[str, Any] = {} + self._update_params_with_user_project(kwargs, bucket) + return await self.get(f'/b/{bucket}', **kwargs) # docs: # https://cloud.google.com/storage/docs/json_api/v1 @@ -588,10 +596,34 @@ class GoogleStorageAsyncFS(AsyncFS): def __init__(self, *, storage_client: Optional[GoogleStorageClient] = None, + bucket_allow_list: Optional[List[str]] = None, **kwargs): if not storage_client: storage_client = GoogleStorageClient(**kwargs) self._storage_client = storage_client + if bucket_allow_list is None: + bucket_allow_list = [] + self.allowed_storage_locations = bucket_allow_list + + def storage_location(self, uri: str) -> str: + return self.get_bucket_and_name(uri)[0] + + async def is_hot_storage(self, location: str) -> bool: + """ + See `the GCS API docs https://cloud.google.com/storage/docs/storage-classes`_ for a list of possible storage + classes. + + Raises + ------ + :class:`aiohttp.ClientResponseError` + If the specified bucket does not exist, or if the account being used to access GCS does not have permission + to read the bucket's default storage policy. 
+ """ + return (await self._storage_client.bucket_info(location))["storageClass"].lower() in ( + "standard", + "regional", + "multi_regional", + ) @staticmethod def valid_url(url: str) -> bool: diff --git a/hail/python/hailtop/aiotools/router_fs.py b/hail/python/hailtop/aiotools/router_fs.py index f2cecffe978..4743e5bd202 100644 --- a/hail/python/hailtop/aiotools/router_fs.py +++ b/hail/python/hailtop/aiotools/router_fs.py @@ -1,12 +1,13 @@ from typing import Any, Optional, List, Set, AsyncIterator, Dict, AsyncContextManager, Callable import asyncio -import urllib.parse from ..aiocloud import aioaws, aioazure, aiogoogle from .fs import (AsyncFS, MultiPartCreate, FileStatus, FileListEntry, ReadableStream, WritableStream, AsyncFSURL) from .local_fs import LocalAsyncFS +from hailtop.config import ConfigVariable, configuration_of + class RouterAsyncFS(AsyncFS): def __init__(self, @@ -15,18 +16,18 @@ def __init__(self, local_kwargs: Optional[Dict[str, Any]] = None, gcs_kwargs: Optional[Dict[str, Any]] = None, azure_kwargs: Optional[Dict[str, Any]] = None, - s3_kwargs: Optional[Dict[str, Any]] = None): + s3_kwargs: Optional[Dict[str, Any]] = None, + gcs_bucket_allow_list: Optional[List[str]] = None): self._filesystems = [] if filesystems is None else filesystems self._local_kwargs = local_kwargs or {} self._gcs_kwargs = gcs_kwargs or {} self._azure_kwargs = azure_kwargs or {} self._s3_kwargs = s3_kwargs or {} - - def get_scheme(self, uri: str) -> str: - scheme = urllib.parse.urlparse(uri).scheme or 'file' - if not scheme: - raise ValueError(f"no default scheme and URL has no scheme: {uri}") - return scheme + self._gcs_bucket_allow_list = ( + gcs_bucket_allow_list + if gcs_bucket_allow_list is not None + else configuration_of(ConfigVariable.GCS_BUCKET_ALLOW_LIST, None, fallback="").split(",") + ) def parse_url(self, url: str) -> AsyncFSURL: return self._get_fs(url).parse_url(url) @@ -50,7 +51,10 @@ def _load_fs(self, uri: str): if LocalAsyncFS.valid_url(uri): fs = LocalAsyncFS(**self._local_kwargs) elif aiogoogle.GoogleStorageAsyncFS.valid_url(uri): - fs = aiogoogle.GoogleStorageAsyncFS(**self._gcs_kwargs) + fs = aiogoogle.GoogleStorageAsyncFS( + **self._gcs_kwargs, + bucket_allow_list = self._gcs_bucket_allow_list.copy() + ) elif aioazure.AzureAsyncFS.valid_url(uri): fs = aioazure.AzureAsyncFS(**self._azure_kwargs) elif aioaws.S3AsyncFS.valid_url(uri): diff --git a/hail/python/hailtop/aiotools/validators.py b/hail/python/hailtop/aiotools/validators.py new file mode 100644 index 00000000000..7f52e6751dc --- /dev/null +++ b/hail/python/hailtop/aiotools/validators.py @@ -0,0 +1,51 @@ +from hailtop.aiocloud.aiogoogle.client.storage_client import GoogleStorageAsyncFS +from hailtop.aiotools.router_fs import RouterAsyncFS +from hailtop.utils import async_to_blocking +from textwrap import dedent +from typing import Optional +from urllib.parse import urlparse + + +def validate_file( + uri: str, + router_async_fs: RouterAsyncFS, + *, + validate_scheme: Optional[bool] = False +) -> None: + """ + Validates a URI's scheme if a file scheme cache was provided, and its cloud location's default storage policy if + the URI points to a cloud with an ``AsyncFS`` implementation that supports checking that policy. + + Raises + ------ + :class:`ValueError` + If one of the validation steps fails. + """ + if validate_scheme: + scheme = urlparse(uri).scheme + if not scheme or scheme == "file": + raise ValueError( + f"Local filepath detected: '{uri}'. 
The Hail Batch Service does not support the use of local " + "filepaths. Please specify a remote URI instead (e.g. 'gs://bucket/folder')." + ) + fs = router_async_fs._get_fs(uri) + if isinstance(fs, GoogleStorageAsyncFS): + location = fs.storage_location(uri) + if location not in fs.allowed_storage_locations: + if not async_to_blocking(fs.is_hot_storage(location)): + raise ValueError( + dedent( + f"""\ + GCS Bucket '{location}' is configured to use cold storage by default. Accessing the blob + '{uri}' would incur egress charges. Either + + * avoid the increased cost by changing the default storage policy for the bucket + (https://cloud.google.com/storage/docs/changing-default-storage-class) and the individual + blobs in it (https://cloud.google.com/storage/docs/changing-storage-classes) to 'Standard', or + + * accept the increased cost by adding '{location}' to the 'gcs_bucket_allow_list' configuration + variable (https://hail.is/docs/0.2/configuration_reference.html). + """ + ) + ) + fs.allowed_storage_locations.append(location) diff --git a/hail/python/hailtop/batch/backend.py b/hail/python/hailtop/batch/backend.py index 48886d7a115..766e808259c 100644 --- a/hail/python/hailtop/batch/backend.py +++ b/hail/python/hailtop/batch/backend.py @@ -23,13 +23,13 @@ from hailtop.batch_client.parse import parse_cpu_in_mcpu import hailtop.batch_client.client as bc from hailtop.batch_client.client import BatchClient -from hailtop.aiotools import AsyncFS from hailtop.aiotools.router_fs import RouterAsyncFS from hailtop.aiocloud.aiogoogle import GCSRequesterPaysConfiguration from . import resource, batch, job as _job # pylint: disable=unused-import from .exceptions import BatchException from .globals import DEFAULT_SHELL +from hailtop.aiotools.validators import validate_file HAIL_GENETICS_HAILTOP_IMAGE = os.environ.get('HAIL_GENETICS_HAILTOP_IMAGE', f'hailgenetics/hailtop:{pip_version()}') @@ -51,6 +51,29 @@ class Backend(abc.ABC, Generic[RunningBatchType]): _closed = False + def __init__(self): + self._requester_pays_fses: Dict[GCSRequesterPaysConfiguration, RouterAsyncFS] = {} + + def requester_pays_fs(self, requester_pays_config: GCSRequesterPaysConfiguration) -> RouterAsyncFS: + try: + return self._requester_pays_fses[requester_pays_config] + except KeyError: + if requester_pays_config is not None: + self._requester_pays_fses[requester_pays_config] = RouterAsyncFS( + gcs_kwargs={"gcs_requester_pays_configuration": requester_pays_config} + ) + return self._requester_pays_fses[requester_pays_config] + return self._fs + + def validate_file(self, uri: str, requester_pays_config: Optional[GCSRequesterPaysConfiguration] = None) -> None: + self._validate_file( + uri, self.requester_pays_fs(requester_pays_config) if requester_pays_config is not None else self._fs + ) + + @abc.abstractmethod + def _validate_file(self, uri: str, fs: RouterAsyncFS) -> None: + raise NotImplementedError + @abc.abstractmethod def _run(self, batch, dry_run, verbose, delete_scratch_on_exit, **backend_kwargs) -> RunningBatchType: """ @@ -64,7 +87,7 @@ def _run(self, batch, dry_run, verbose, delete_scratch_on_exit, **backend_kwargs @property @abc.abstractmethod - def _fs(self) -> AsyncFS: + def _fs(self) -> RouterAsyncFS: raise NotImplementedError() def _close(self): @@ -83,9 +106,6 @@ def close(self): self._close() self._closed = True - def validate_file_scheme(self, uri: str) -> None: - pass - def __del__(self): self.close() @@ -124,6 +144,7 @@ def __init__(self, tmp_dir: str = '/tmp/', gsa_key_file: Optional[str] = None, 
extra_docker_run_flags: Optional[str] = None): + super().__init__() self._tmp_dir = tmp_dir.rstrip('/') flags = '' @@ -139,12 +160,15 @@ def __init__(self, flags += f' -v {gsa_key_file}:/gsa-key/key.json' self._extra_docker_run_flags = flags - self.__fs: AsyncFS = RouterAsyncFS() + self.__fs = RouterAsyncFS() @property - def _fs(self): + def _fs(self) -> RouterAsyncFS: return self.__fs + def _validate_file(self, uri: str, fs: RouterAsyncFS) -> None: + validate_file(uri, fs) + def _run(self, batch: 'batch.Batch', dry_run: bool, @@ -429,7 +453,9 @@ class ServiceBackend(Backend[bc.Batch]): available regions to choose from. Use py:attribute:`.ServiceBackend.ANY_REGION` to signify the default is jobs can run in any available region. The default is jobs can run in any region unless a default value has been set with hailctl. An example invocation is `hailctl config set batch/regions "us-central1,us-east1"`. - + gcs_bucket_allow_list: + A list of buckets that the :class:`.ServiceBackend` should be permitted to read from or write to, even if their + default policy is to use "cold" storage. Should look like ``["bucket1", "bucket2"]``. """ @staticmethod @@ -448,17 +474,22 @@ def supported_regions(): with BatchClient('dummy') as dummy_client: return dummy_client.supported_regions() - def __init__(self, - *args, - billing_project: Optional[str] = None, - bucket: Optional[str] = None, - remote_tmpdir: Optional[str] = None, - google_project: Optional[str] = None, - token: Optional[str] = None, - regions: Optional[List[str]] = None, - gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, - ): + def __init__( + self, + *args, + billing_project: Optional[str] = None, + bucket: Optional[str] = None, + remote_tmpdir: Optional[str] = None, + google_project: Optional[str] = None, + token: Optional[str] = None, + regions: Optional[List[str]] = None, + gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None, + gcs_bucket_allow_list: Optional[List[str]] = None, + ): + import nest_asyncio # pylint: disable=import-outside-toplevel + + super().__init__() nest_asyncio.apply() if len(args) > 2: @@ -498,7 +529,9 @@ def __init__(self, gcs_kwargs = {'gcs_requester_pays_configuration': google_project} else: gcs_kwargs = {'gcs_requester_pays_configuration': gcs_requester_pays_configuration} - self.__fs: RouterAsyncFS = RouterAsyncFS(gcs_kwargs=gcs_kwargs) + self.__fs = RouterAsyncFS(gcs_kwargs=gcs_kwargs, gcs_bucket_allow_list=gcs_bucket_allow_list) + + self.validate_file(self.remote_tmpdir) if regions is None: regions_from_conf = configuration_of(ConfigVariable.BATCH_REGIONS, None, None) @@ -510,9 +543,12 @@ def __init__(self, self.regions = regions @property - def _fs(self): + def _fs(self) -> RouterAsyncFS: return self.__fs + def _validate_file(self, uri: str, fs: RouterAsyncFS) -> None: + validate_file(uri, fs, validate_scheme=True) + def _close(self): if hasattr(self, '_batch_client'): self._batch_client.close() @@ -807,12 +843,3 @@ async def compile_job(job): batch._python_function_defs.clear() batch._python_function_files.clear() return batch_handle - - def validate_file_scheme(self, uri: str) -> None: - scheme = self.__fs.get_scheme(uri) - if scheme == "file": - raise ValueError( - f"Local filepath detected: '{uri}'. " - "ServiceBackend does not support the use of local filepaths. " - "Please specify a remote URI instead (e.g. gs://bucket/folder)." 
- ) diff --git a/hail/python/hailtop/batch/batch.py b/hail/python/hailtop/batch/batch.py index 971a42d258f..2ca14bd3bad 100644 --- a/hail/python/hailtop/batch/batch.py +++ b/hail/python/hailtop/batch/batch.py @@ -407,7 +407,7 @@ def _new_job_resource_file(self, source, value=None): return jrf def _new_input_resource_file(self, input_path, root=None): - self._backend.validate_file_scheme(input_path) + self._backend.validate_file(input_path, self.requester_pays_project) # Take care not to include an Azure SAS token query string in the local name. if AzureAsyncFS.valid_url(input_path): @@ -559,19 +559,23 @@ def write_output(self, resource: _resource.Resource, dest: str): Write a single job intermediate to a permanent location in GCS: - >>> b = Batch() - >>> j = b.new_job() - >>> j.command(f'echo "hello" > {j.ofile}') - >>> b.write_output(j.ofile, 'gs://mybucket/output/hello.txt') - >>> b.run() # doctest: +SKIP + .. code-block:: python + + b = Batch() + j = b.new_job() + j.command(f'echo "hello" > {j.ofile}') + b.write_output(j.ofile, 'gs://mybucket/output/hello.txt') + b.run() Write a single job intermediate to a permanent location in Azure: - >>> b = Batch() - >>> j = b.new_job() - >>> j.command(f'echo "hello" > {j.ofile}') - >>> b.write_output(j.ofile, 'https://my-account.blob.core.windows.net/my-container/output/hello.txt') - >>> b.run() # doctest: +SKIP + .. code-block:: python + + b = Batch() + j = b.new_job() + j.command(f'echo "hello" > {j.ofile}') + b.write_output(j.ofile, 'https://my-account.blob.core.windows.net/my-container/output/hello.txt') + b.run() .. warning:: @@ -619,6 +623,8 @@ def write_output(self, resource: _resource.Resource, dest: str): if dest_scheme == '': dest = os.path.abspath(os.path.expanduser(dest)) + self._backend.validate_file(dest, self.requester_pays_project) + resource._add_output_path(dest) def select_jobs(self, pattern: str) -> List[job.Job]: diff --git a/hail/python/hailtop/batch/docs/_templates/_autosummary/class.rst b/hail/python/hailtop/batch/docs/_templates/_autosummary/class.rst index c72f9cf3767..2eb1c1c9ca2 100644 --- a/hail/python/hailtop/batch/docs/_templates/_autosummary/class.rst +++ b/hail/python/hailtop/batch/docs/_templates/_autosummary/class.rst @@ -8,14 +8,16 @@ :no-inherited-members: {% block attributes %} - {% if attributes %} + {% if (attributes | reject('in', inherited_members) | list | count) != 0 %} .. rubric:: Attributes .. 
autosummary:: :nosignatures: {% for item in attributes %} + {% if item not in inherited_members %} ~{{ name }}.{{ item }} + {% endif %} {%- endfor %} {% endif %} {% endblock %} diff --git a/hail/python/hailtop/batch/docs/conf.py b/hail/python/hailtop/batch/docs/conf.py index e323ff78709..72fb60ca3de 100644 --- a/hail/python/hailtop/batch/docs/conf.py +++ b/hail/python/hailtop/batch/docs/conf.py @@ -29,7 +29,11 @@ # The full version, including alpha/beta/rc tags release = '' nitpicky = True -nitpick_ignore = [('py:class', 'hailtop.batch_client.client.Batch'), ('py:class', 'typing.Self')] +nitpick_ignore = [ + ('py:class', 'hailtop.batch_client.client.Batch'), + ('py:class', 'hailtop.aiotools.router_fs.RouterAsyncFS'), + ('py:class', 'typing.Self'), +] # -- General configuration --------------------------------------------------- diff --git a/hail/python/hailtop/batch/docs/configuration_reference.rst b/hail/python/hailtop/batch/docs/configuration_reference.rst new file mode 100644 index 00000000000..9a88c48d2e7 --- /dev/null +++ b/hail/python/hailtop/batch/docs/configuration_reference.rst @@ -0,0 +1,6 @@ +.. _sec-configuration-reference: + +Configuration Reference +======================= + +See `the query documentation <https://hail.is/docs/0.2/configuration_reference.html>`__. diff --git a/hail/python/hailtop/batch/docs/index.rst b/hail/python/hailtop/batch/docs/index.rst index b5f4e21ff8a..46b252f82b6 100644 --- a/hail/python/hailtop/batch/docs/index.rst +++ b/hail/python/hailtop/batch/docs/index.rst @@ -22,6 +22,7 @@ Contents Batch Service Cookbooks Reference (Python API) + Configuration Reference Advanced UI Search Help Change Log And Version Policy diff --git a/hail/python/hailtop/config/variables.py b/hail/python/hailtop/config/variables.py index 3797679b13c..cfd82a3d774 100644 --- a/hail/python/hailtop/config/variables.py +++ b/hail/python/hailtop/config/variables.py @@ -5,6 +5,7 @@ class ConfigVariable(str, Enum): DOMAIN = 'domain' GCS_REQUESTER_PAYS_PROJECT = 'gcs_requester_pays/project' GCS_REQUESTER_PAYS_BUCKETS = 'gcs_requester_pays/buckets' + GCS_BUCKET_ALLOW_LIST = 'gcs/bucket_allow_list' BATCH_BUCKET = 'batch/bucket' BATCH_REMOTE_TMPDIR = 'batch/remote_tmpdir' BATCH_REGIONS = 'batch/regions' diff --git a/hail/python/hailtop/hailctl/config/config_variables.py b/hail/python/hailtop/hailctl/config/config_variables.py index e8941295749..966a008a854 100644 --- a/hail/python/hailtop/hailctl/config/config_variables.py +++ b/hail/python/hailtop/hailctl/config/config_variables.py @@ -31,6 +31,16 @@ def config_variables(): lambda x: re.fullmatch(r'[^:/\s]+(,[^:/\s]+)*', x) is not None, 'should be comma separated list of bucket names'), ), + ConfigVariable.GCS_BUCKET_ALLOW_LIST: ConfigVariableInfo( + help_msg=( + 'Allows Hail to access the given buckets, even if their default policy is to use cold storage.' + ), + validation=( + # See https://cloud.google.com/storage/docs/buckets#naming for bucket naming requirements. + lambda x: re.fullmatch(r'^[-\.\w]+(,[-\.\w]+)*$', x) is not None, + "should match the pattern 'bucket1,bucket2,bucket3'." 
+ ), + ), ConfigVariable.BATCH_BUCKET: ConfigVariableInfo( help_msg='Deprecated - Name of GCS bucket to use as a temporary scratch directory', validation=(lambda x: re.fullmatch(r'[^:/\s]+', x) is not None, diff --git a/hail/python/test/hail/fs/test_worker_driver_fs.py b/hail/python/test/hail/fs/test_worker_driver_fs.py index 6863d72def7..7a2b695546d 100644 --- a/hail/python/test/hail/fs/test_worker_driver_fs.py +++ b/hail/python/test/hail/fs/test_worker_driver_fs.py @@ -14,7 +14,7 @@ def test_requester_pays_no_settings(): try: hl.import_table('gs://hail-test-requester-pays-fds32/hello') except Exception as exc: - assert "Bucket is a requester pays bucket but no user project provided" in exc.args[0] + assert "Bucket is a requester pays bucket but no user project provided" in str(exc) else: assert False @@ -25,7 +25,7 @@ def test_requester_pays_write_no_settings(): try: hl.utils.range_table(4, n_partitions=4).write(random_filename, overwrite=True) except Exception as exc: - assert "Bucket is a requester pays bucket but no user project provided" in exc.args[0] + assert "Bucket is a requester pays bucket but no user project provided" in str(exc) else: hl.current_backend().fs.rmtree(random_filename) assert False @@ -63,7 +63,7 @@ def test_requester_pays_with_project(): try: hl.import_table('gs://hail-test-requester-pays-fds32/hello') except Exception as exc: - assert "Bucket is a requester pays bucket but no user project provided" in exc.args[0] + assert "Bucket is a requester pays bucket but no user project provided" in str(exc) else: assert False @@ -114,7 +114,7 @@ def test_requester_pays_with_project_more_than_one_partition(): try: hl.import_table('gs://hail-test-requester-pays-fds32/zero-to-nine', min_partitions=8) except Exception as exc: - assert "Bucket is a requester pays bucket but no user project provided" in exc.args[0] + assert "Bucket is a requester pays bucket but no user project provided" in str(exc) else: assert False diff --git a/hail/python/test/hailtop/aiotools/__init__.py b/hail/python/test/hailtop/aiotools/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/hail/python/test/hailtop/batch/test_batch.py b/hail/python/test/hailtop/batch/test_batch.py index 25916667e28..1f8f2efc68e 100644 --- a/hail/python/test/hailtop/batch/test_batch.py +++ b/hail/python/test/hailtop/batch/test_batch.py @@ -26,7 +26,11 @@ from hailtop.test_utils import skip_in_azure from hailtop.httpx import ClientResponseError +from configparser import ConfigParser +from hailtop.config import get_user_config, user_config from hailtop.config.variables import ConfigVariable +from hailtop.aiocloud.aiogoogle.client.storage_client import GoogleStorageAsyncFS +from _pytest.monkeypatch import MonkeyPatch DOCKER_ROOT_IMAGE = os.environ.get('DOCKER_ROOT_IMAGE', 'ubuntu:22.04') @@ -491,6 +495,9 @@ def test_failed_jobs_dont_stop_always_run_jobs(self): class ServiceTests(unittest.TestCase): def setUp(self): + # https://stackoverflow.com/questions/42332030/pytest-monkeypatch-setattr-inside-of-test-class-method + self.monkeypatch = MonkeyPatch() + self.backend = ServiceBackend() remote_tmpdir = get_remote_tmpdir('hailtop_test_batch_service_tests') @@ -681,14 +688,6 @@ def test_file_name_space(self): res_status = res.status() assert res_status['state'] == 'success', str((res_status, res.debug_info())) - def test_local_paths_error(self): - b = self.batch() - b.new_job() - for input in ["hi.txt", "~/hello.csv", "./hey.tsv", "/sup.json", "file://yo.yaml"]: - with pytest.raises(ValueError) as e: 
- b.read_input(input) - assert str(e.value).startswith("Local filepath detected") - def test_dry_run(self): b = self.batch() j = b.new_job() @@ -1424,3 +1423,80 @@ def test_non_spot_batch(self): assert res.get_job(1).status()['spec']['resources']['preemptible'] == False assert res.get_job(2).status()['spec']['resources']['preemptible'] == False assert res.get_job(3).status()['spec']['resources']['preemptible'] == True + + def test_local_file_paths_error(self): + b = self.batch() + j = b.new_job() + for input in ["hi.txt", "~/hello.csv", "./hey.tsv", "/sup.json", "file://yo.yaml"]: + with pytest.raises(ValueError) as e: + b.read_input(input) + assert str(e.value).startswith("Local filepath detected") + + @skip_in_azure + def test_validate_cloud_storage_policy(self): + # buckets do not exist (bucket names can't contain the string "google" per + # https://cloud.google.com/storage/docs/buckets) + fake_bucket1 = "google" + fake_bucket2 = "google1" + no_bucket_error = "bucket does not exist" + # bucket exists, but account does not have permissions on it + no_perms_bucket = "test" + no_perms_error = "does not have storage.buckets.get access" + # bucket exists and account has permissions, but is set to use cold storage by default + cold_bucket = "hail-test-cold-storage" + cold_error = "configured to use cold storage by default" + fake_uri1, fake_uri2, no_perms_uri, cold_uri = [ + f"gs://{bucket}/test" for bucket in [fake_bucket1, fake_bucket2, no_perms_bucket, cold_bucket] + ] + + def _test_raises(exception_type, exception_msg, func): + with pytest.raises(exception_type) as e: + func() + assert exception_msg in str(e.value) + + def _test_raises_no_bucket_error(remote_tmpdir, arg = None): + _test_raises(ClientResponseError, no_bucket_error, lambda: ServiceBackend(remote_tmpdir=remote_tmpdir, gcs_bucket_allow_list=arg)) + + def _test_raises_cold_error(func): + _test_raises(ValueError, cold_error, func) + + # no configuration, nonexistent buckets error + _test_raises_no_bucket_error(fake_uri1) + _test_raises_no_bucket_error(fake_uri2) + + # no configuration, no perms bucket errors + _test_raises(ClientResponseError, no_perms_error, lambda: ServiceBackend(remote_tmpdir=no_perms_uri)) + + # no configuration, cold bucket errors + _test_raises_cold_error(lambda: ServiceBackend(remote_tmpdir=cold_uri)) + b = self.batch() + _test_raises_cold_error(lambda: b.read_input(cold_uri)) + j = b.new_job() + j.command(f"echo hello > {j.ofile}") + _test_raises_cold_error(lambda: b.write_output(j.ofile, cold_uri)) + + # hailctl config, allowlisted nonexistent buckets don't error + base_config = get_user_config() + local_config = ConfigParser() + local_config.read_dict({ + **{ + section: {key: val for key, val in base_config[section].items()} + for section in base_config.sections() + }, + **{"gcs": {"bucket_allow_list": f"{fake_bucket1},{fake_bucket2}"}} + }) + def _get_user_config(): + return local_config + self.monkeypatch.setattr(user_config, "get_user_config", _get_user_config) + ServiceBackend(remote_tmpdir=fake_uri1) + ServiceBackend(remote_tmpdir=fake_uri2) + + # environment variable config, only allowlisted nonexistent buckets don't error + self.monkeypatch.setenv("HAIL_GCS_BUCKET_ALLOW_LIST", fake_bucket2) + _test_raises_no_bucket_error(fake_uri1) + ServiceBackend(remote_tmpdir=fake_uri2) + + # arg to constructor config, only allowlisted nonexistent buckets don't error + arg = [fake_bucket1] + ServiceBackend(remote_tmpdir=fake_uri1, gcs_bucket_allow_list=arg) + _test_raises_no_bucket_error(fake_uri2, arg)
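
Taken together with the tests above, the new allowlist has three entry points, resolved by `configuration_of` in the order shown. The sketch below illustrates all three; `my-cold-bucket` is a hypothetical bucket name, not one used in this patch:

    import hail as hl

    # 1. Keyword argument: highest precedence, applies to this session only.
    hl.init(backend='batch', gcs_bucket_allow_list=['my-cold-bucket'])

    # 2. Environment variable: overrides any value persisted with hailctl, as the
    #    last test above demonstrates. In a shell:
    #      $ export HAIL_GCS_BUCKET_ALLOW_LIST=my-cold-bucket

    # 3. hailctl: persisted in the user configuration file and read by
    #    RouterAsyncFS via configuration_of(ConfigVariable.GCS_BUCKET_ALLOW_LIST, ...):
    #      $ hailctl config set gcs/bucket_allow_list my-cold-bucket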
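
For reviewers tracing the validation path: a URI handed to a backend flows through `Backend.validate_file` into `hailtop.aiotools.validators.validate_file`, which consults `GoogleStorageAsyncFS.is_hot_storage` once per bucket (hot buckets are then memoized in `allowed_storage_locations`). A minimal sketch of that round trip, assuming application-default GCS credentials and a hypothetical bucket name:

    import asyncio
    from hailtop.aiocloud.aiogoogle.client.storage_client import GoogleStorageAsyncFS

    async def main():
        fs = GoogleStorageAsyncFS()
        location = fs.storage_location('gs://my-bucket/path/to/blob')  # -> 'my-bucket'
        # One buckets.get call; raises aiohttp.ClientResponseError if the bucket
        # is missing or the caller lacks storage.buckets.get on it.
        info = await fs._storage_client.bucket_info(location)
        print(info['storageClass'])               # e.g. 'STANDARD' or 'COLDLINE'
        print(await fs.is_hot_storage(location))  # True only for standard/regional/multi_regional
        await fs.close()

    asyncio.run(main())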
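
The user-facing behavior the new tests pin down, written out as a usage sketch (the billing project and bucket names are placeholders):

    import hailtop.batch as hb

    # Construction already validates remote_tmpdir, so a cold (or nonexistent)
    # scratch bucket fails fast rather than at job submission.
    backend = hb.ServiceBackend(billing_project='my-project',
                                remote_tmpdir='gs://my-hot-bucket/tmp')

    b = hb.Batch(backend=backend)
    try:
        b.read_input('gs://my-cold-bucket/data.csv')
    except ValueError as err:
        assert 'cold storage' in str(err)

    # Accept the egress cost for this one bucket:
    allowing = hb.ServiceBackend(billing_project='my-project',
                                 remote_tmpdir='gs://my-hot-bucket/tmp',
                                 gcs_bucket_allow_list=['my-cold-bucket'])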