From 5ad7634021ccf444e6b321007dad24db64a3dd92 Mon Sep 17 00:00:00 2001
From: Ben Pedigo
Date: Fri, 16 Feb 2024 14:07:56 -0800
Subject: [PATCH] format

---
 caveclient/__init__.py              |   2 +-
 caveclient/base.py                  |  12 +-
 caveclient/datastack_lookup.py      |  34 +++--
 caveclient/endpoints.py             |   6 +-
 caveclient/format_utils.py          |  12 +-
 caveclient/materializationengine.py |   4 +-
 caveclient/session_config.py        |   2 +-
 caveclient/tools/table_manager.py   | 188 +++++++++++++++++++---------
 8 files changed, 176 insertions(+), 84 deletions(-)

diff --git a/caveclient/__init__.py b/caveclient/__init__.py
index a186d057..633bba82 100644
--- a/caveclient/__init__.py
+++ b/caveclient/__init__.py
@@ -2,4 +2,4 @@
 
 from .frameworkclient import CAVEclient
 
-__all__ = ["CAVEclient"]
\ No newline at end of file
+__all__ = ["CAVEclient"]
diff --git a/caveclient/base.py b/caveclient/base.py
index 4f47a936..b588bc8b 100644
--- a/caveclient/base.py
+++ b/caveclient/base.py
@@ -113,7 +113,9 @@ def _check_authorization_redirect(response):
         )
 
 
-def _api_versions(server_name, server_address, endpoints_common, auth_header, verify=True):
+def _api_versions(
+    server_name, server_address, endpoints_common, auth_header, verify=True
+):
     """Asks a server what API versions are available, if possible"""
     url_mapping = {server_name: server_address}
     url_base = endpoints_common.get("get_api_versions", None)
@@ -140,7 +142,11 @@ def _api_endpoints(
     if api_version == "latest":
         try:
             avail_vs_server = _api_versions(
-                server_name, server_address, endpoints_common, auth_header, verify=verify
+                server_name,
+                server_address,
+                endpoints_common,
+                auth_header,
+                verify=verify,
             )
             avail_vs_server = set(avail_vs_server)
         except:
@@ -241,7 +247,6 @@ def __init__(
         pool_block=None,
         over_client=None,
     ):
-
         super(ClientBaseWithDataset, self).__init__(
             server_address,
             auth_header,
@@ -276,7 +281,6 @@ def __init__(
         pool_block=None,
         over_client=None,
     ):
-
         super(ClientBaseWithDatastack, self).__init__(
             server_address,
             auth_header,
diff --git a/caveclient/datastack_lookup.py b/caveclient/datastack_lookup.py
index f9bd57c8..e8f64520 100644
--- a/caveclient/datastack_lookup.py
+++ b/caveclient/datastack_lookup.py
@@ -2,50 +2,57 @@
 import json
 from . import auth
 import logging
+
 logger = logging.getLogger(__name__)
 
 DEFAULT_LOCATION = auth.default_token_location
-DEFAULT_DATASTACK_FILE = 'cave_datastack_to_server_map.json'
+DEFAULT_DATASTACK_FILE = "cave_datastack_to_server_map.json"
+
 
-def read_map(filename = None):
+def read_map(filename=None):
     if filename is None:
         filename = os.path.join(DEFAULT_LOCATION, DEFAULT_DATASTACK_FILE)
     try:
-        with open(os.path.expanduser(filename), 'r') as f:
+        with open(os.path.expanduser(filename), "r") as f:
             data = json.load(f)
         return data
     except:
         return {}
 
+
 def is_writable(filename):
     # File exists but is not writeable
     if os.path.exists(os.path.expanduser(filename)):
        if not os.access(os.path.expanduser(filename), os.W_OK):
            return False
-    else: 
+    else:
        try:
            # File does not exist so make the directories if possible
            if not os.path.exists(os.path.expanduser(DEFAULT_LOCATION)):
-                os.makedirs(os.path.expanduser(DEFAULT_LOCATION)) 
-            with open(os.path.expanduser(filename), 'w') as f:
+                os.makedirs(os.path.expanduser(DEFAULT_LOCATION))
+            with open(os.path.expanduser(filename), "w") as f:
                if not f.writable():
                    return False
        except IOError:
            return False
     return True
 
 
-def write_map(data, filename = None):
+
+def write_map(data, filename=None):
     if filename is None:
         filename = os.path.join(DEFAULT_LOCATION, DEFAULT_DATASTACK_FILE)
-    if is_writable(filename): 
-        with open(os.path.expanduser(filename), 'w') as f:
+    if is_writable(filename):
+        with open(os.path.expanduser(filename), "w") as f:
             json.dump(data, f)
         return True
     else:
-        logging.warn(f'Did not write cache — file {os.path.expanduser(filename)} is not writeable')
+        logger.warning(
+            f"Did not write cache — file {os.path.expanduser(filename)} is not writeable"
+        )
         return False
 
+
 def handle_server_address(datastack, server_address, filename=None, write=False):
     data = read_map(filename)
     if server_address is not None:
@@ -53,14 +60,18 @@ def handle_server_address(datastack, server_address, filename=None, write=False)
         data[datastack] = server_address
         wrote = write_map(data, filename)
         if wrote:
-            logger.warning(f"Updated datastack-to-server cache — '{server_address}' will now be used by default for datastack '{datastack}'")
+            logger.warning(
+                f"Updated datastack-to-server cache — '{server_address}' will now be used by default for datastack '{datastack}'"
+            )
         return server_address
     else:
         return data.get(datastack)
 
+
 def get_datastack_cache(filename=None):
     return read_map(filename)
 
+
 def reset_server_address_cache(datastack, filename=None):
     """Remove one or more datastacks from the datastack-to-server cache.
@@ -78,4 +89,3 @@ def reset_server_address_cache(datastack, filename=None):
         data.pop(ds, None)
         logger.warning(f"Wiping '{ds}' from datastack-to-server cache")
     write_map(data, filename)
-    
\ No newline at end of file
diff --git a/caveclient/endpoints.py b/caveclient/endpoints.py
index d31c85ba..2bd158bb 100644
--- a/caveclient/endpoints.py
+++ b/caveclient/endpoints.py
@@ -273,6 +273,6 @@
 fallback_ngl_endpoint = "https://neuroglancer.neuvue.io/"
 
 ngl_endpoints_common = {
-    'get_info': "{ngl_url}/version.json",
-    'fallback_ngl_url': fallback_ngl_endpoint,
-}
\ No newline at end of file
+    "get_info": "{ngl_url}/version.json",
+    "fallback_ngl_url": fallback_ngl_endpoint,
+}
diff --git a/caveclient/format_utils.py b/caveclient/format_utils.py
index 41683be5..26a2007a 100644
--- a/caveclient/format_utils.py
+++ b/caveclient/format_utils.py
@@ -11,15 +11,17 @@ def format_precomputed_neuroglancer(objurl):
         objurl_out = None
     return objurl_out
 
+
 def format_neuroglancer(objurl):
     qry = urlparse(objurl)
-    if qry.scheme == 'graphene' or 'https':
+    if qry.scheme == "graphene" or qry.scheme == "https":
         return format_graphene(objurl)
-    elif qry.scheme == 'precomputed':
+    elif qry.scheme == "precomputed":
         return format_precomputed_neuroglancer(objurl)
     else:
         return format_raw(objurl)
 
+
 def format_precomputed_https(objurl):
     qry = urlparse(objurl)
     if qry.scheme == "gs":
@@ -41,6 +43,7 @@ def format_graphene(objurl):
         objurl_out = None
     return objurl_out
 
+
 def format_verbose_graphene(objurl):
     qry = urlparse(objurl)
     if qry.scheme == "http" or qry.scheme == "https":
@@ -49,6 +52,7 @@ def format_verbose_graphene(objurl):
         objurl_out = f"graphene://middleauth+{qry.netloc}{qry.path}"
     return objurl_out
 
+
 def format_cloudvolume(objurl):
     qry = urlparse(objurl)
     if qry.scheme == "graphene":
@@ -58,14 +62,16 @@ def format_cloudvolume(objurl):
     else:
         return None
 
+
 def format_raw(objurl):
     return objurl
 
+
 def format_cave_explorer(objurl):
     qry = urlparse(objurl)
     if qry.scheme == "graphene" or qry.scheme == "https":
         return format_verbose_graphene(objurl)
-    elif qry.scheme == 'precomputed':
+    elif qry.scheme == "precomputed":
         return format_precomputed_neuroglancer(objurl)
     else:
         return None
diff --git a/caveclient/materializationengine.py b/caveclient/materializationengine.py
index 5f63399b..276053cc 100644
--- a/caveclient/materializationengine.py
+++ b/caveclient/materializationengine.py
@@ -2369,7 +2369,9 @@ def query_view(
         else:
             return response.json()
 
-    def get_unique_string_values(self, table: str, datastack_name: Optional[str] = None):
+    def get_unique_string_values(
+        self, table: str, datastack_name: Optional[str] = None
+    ):
         """Get unique string values for a table
 
         Parameters
diff --git a/caveclient/session_config.py b/caveclient/session_config.py
index 7e9d6b25..31e5d8de 100644
--- a/caveclient/session_config.py
+++ b/caveclient/session_config.py
@@ -46,4 +46,4 @@ def patch_session(
     session.mount("http://", http)
     session.mount("https://", http)
 
-    pass
\ No newline at end of file
+    pass
diff --git a/caveclient/tools/table_manager.py b/caveclient/tools/table_manager.py
index abca4976..68d17f4e 100644
--- a/caveclient/tools/table_manager.py
+++ b/caveclient/tools/table_manager.py
@@ -3,20 +3,24 @@
 import re
 from cachetools import cached, TTLCache, keys
 import logging
+
 logger = logging.getLogger(__name__)
 
 # json schema column types that can act as potential columns for looking at tables
 ALLOW_COLUMN_TYPES = ["integer", "boolean", "string", "float"]
-SPATIAL_POINT_TYPES = ['SpatialPoint']
+SPATIAL_POINT_TYPES = ["SpatialPoint"]
 
 # Helper functions for turning schema field names to column names
 
+
 def bound_pt_position(pt):
     return f"{pt}_position"
 
+
 def bound_pt_root_id(pt):
     return f"{pt}_root_id"
 
+
 def add_with_suffix(namesA, namesB, suffix):
     all_names = []
     rename_map = {}
@@ -29,6 +33,7 @@ def add_with_suffix(namesA, namesB, suffix):
         all_names.append(name)
     return all_names, rename_map
 
+
 def pop_empty(filter_dict):
     keys_to_pop = []
     for k in filter_dict.keys():
@@ -38,6 +43,7 @@ def pop_empty(filter_dict):
         filter_dict.pop(k)
     return filter_dict
 
+
 def combine_names(tableA, namesA, tableB, namesB, suffixes):
     table_map = {}
     final_namesA, rename_mapA = add_with_suffix(namesA, namesB, suffixes[0])
@@ -50,24 +56,24 @@ def combine_names(tableA, namesA, tableB, namesB, suffixes):
     return final_namesA + final_namesB, table_map, rename_map
 
+
 def get_all_table_metadata(client):
     meta = client.materialize.get_tables_metadata()
     tables = []
     for m in meta:
-        if m.get('annotation_table'):
-            tables.append(m['annotation_table'])
+        if m.get("annotation_table"):
+            tables.append(m["annotation_table"])
         else:
-            tables.append(m['table_name'])
-    return {
-        tn: md
-        for tn, md in zip(tables, meta)
-    }
+            tables.append(m["table_name"])
+    return {tn: md for tn, md in zip(tables, meta)}
 
+
 def get_all_view_metadata(client):
     views = client.materialize.get_views()
     view_schema = client.materialize.get_view_schemas()
     return views, view_schema
 
+
 def is_list_like(x):
     if isinstance(x, str):
         return False
     try:
@@ -77,13 +83,15 @@ def is_list_like(x):
     except:
         return False
 
+
 def update_spatial_dict(spatial_dict):
     new_dict = {}
     for k in spatial_dict:
-        nm = re.match('(.*)_bbox$', k).groups()[0]
+        nm = re.match("(.*)_bbox$", k).groups()[0]
         new_dict[nm] = spatial_dict[k]
     return new_dict
 
+
 def filter_empty(filter_dict):
     new_dict = {}
     for k, v in filter_dict.items():
@@ -92,12 +100,14 @@ def filter_empty(filter_dict):
             new_dict[k] = v
     return new_dict
 
+
 def replace_empty_with_none(filter_dict):
     if len(filter_dict) == 0:
         return None
     else:
         return filter_dict
 
+
 _schema_cache = TTLCache(maxsize=128, ttl=86_400)
 
 
@@ -142,6 +152,7 @@ def get_col_info(
 
 _table_cache = TTLCache(maxsize=128, ttl=86_400)
 
+
 def _table_key(table_name, meta, client, **kwargs):
     merge_schema = kwargs.get("merge_schema", True)
     allow_types = kwargs.get("allow_types", ALLOW_COLUMN_TYPES)
@@ -149,7 +160,13 @@ def _table_key(table_name, meta, client, **kwargs):
     return key
 
 
-def get_view_info(view_name, meta, schema, allow_types=ALLOW_COLUMN_TYPES, spatial_types=SPATIAL_POINT_TYPES):
+def get_view_info(
+    view_name,
+    meta,
+    schema,
+    allow_types=ALLOW_COLUMN_TYPES,
+    spatial_types=SPATIAL_POINT_TYPES,
+):
     """Assemble
 
     Parameters
@@ -169,8 +186,8 @@ def get_view_info(
     desc = meta.get("description", "")
     is_live = meta.get("live_compatible", False)
     pts = []
-    vals = [k for k,v in schema.items() if v['type'] in allow_types]
-    unbd_pts = [k for k,v in schema.items() if v['type'] in spatial_types]
+    vals = [k for k, v in schema.items() if v["type"] in allow_types]
+    unbd_pts = [k for k, v in schema.items() if v["type"] in spatial_types]
     column_map = {k: view_name for k in vals + unbd_pts}
     rename_map = {}
     return (
@@ -181,13 +198,18 @@ def get_view_info(
         rename_map,
         [view_name, None],
         desc,
-        is_live
+        is_live,
     )
-    
+
 
 @cached(cache=_table_cache, key=_table_key)
 def get_table_info(
-    tn, meta, client, allow_types=ALLOW_COLUMN_TYPES, merge_schema=True, suffixes=["", "_ref"]
+    tn,
+    meta,
+    client,
+    allow_types=ALLOW_COLUMN_TYPES,
+    merge_schema=True,
+    suffixes=["", "_ref"],
 ):
     """Get the point column and additional columns from a table
 
@@ -247,7 +269,7 @@ def get_table_info(
         column_map,
         rename_map,
         [name_base, name_ref],
-        meta.get('description'),
+        meta.get("description"),
     )
 
 
@@ -270,7 +292,9 @@ def table_metadata(table_name, client):
     return meta
 
 
-def make_class_vals(pts, val_cols, unbd_pts, table_map, rename_map, table_list, raw_points=False):
+def make_class_vals(
+    pts, val_cols, unbd_pts, table_map, rename_map, table_list, raw_points=False
+):
     class_vals = {
         "_reference_table": attrs.field(
             init=False, default=table_list[1], metadata={"is_meta": True}
@@ -326,53 +350,68 @@ def __attrs_post_init__(self):
                 ]
             )
             filter_equal_dict = {
-                tn: filter_empty(attrs.asdict(
-                    self,
-                    filter=lambda a, v: is_list_like(v) == False
-                    and v is not None
-                    and a.metadata.get("is_bbox", False) == False
-                    and a.metadata.get("is_meta", False) == False
-                    and a.metadata.get("table") == tn,
-                ))
+                tn: filter_empty(
+                    attrs.asdict(
+                        self,
+                        filter=lambda a, v: is_list_like(v) == False
+                        and v is not None
+                        and a.metadata.get("is_bbox", False) == False
+                        and a.metadata.get("is_meta", False) == False
+                        and a.metadata.get("table") == tn,
+                    )
+                )
                 for tn in tables
             }
             filter_in_dict = {
-                tn: filter_empty(attrs.asdict(
-                    self,
-                    filter=lambda a, v: is_list_like(v) == True
-                    and v is not None
-                    and a.metadata.get("is_bbox", False) == False
-                    and a.metadata.get("is_meta", False) == False
-                    and a.metadata.get("table") == tn,
-                ))
+                tn: filter_empty(
+                    attrs.asdict(
+                        self,
+                        filter=lambda a, v: is_list_like(v) == True
+                        and v is not None
+                        and a.metadata.get("is_bbox", False) == False
+                        and a.metadata.get("is_meta", False) == False
+                        and a.metadata.get("table") == tn,
+                    )
+                )
                 for tn in tables
             }
             spatial_dict = {
                 tn: update_spatial_dict(
-                    attrs.asdict(
-                    self,
-                    filter=lambda a, v: a.metadata.get("is_bbox", False)
-                    and v is not None
-                    and a.metadata.get("is_meta", False) == False
-                    and a.metadata.get("table") == tn,
+                    attrs.asdict(
+                        self,
+                        filter=lambda a, v: a.metadata.get("is_bbox", False)
+                        and v is not None
+                        and a.metadata.get("is_meta", False) == False
+                        and a.metadata.get("table") == tn,
                     )
                 )
                 for tn in tables
             }
             self.filter_kwargs_live = {
-                "filter_equal_dict": replace_empty_with_none(filter_empty(filter_equal_dict)),
+                "filter_equal_dict": replace_empty_with_none(
+                    filter_empty(filter_equal_dict)
+                ),
                 "filter_in_dict": replace_empty_with_none(filter_empty(filter_in_dict)),
-                "filter_spatial_dict": replace_empty_with_none(filter_empty(spatial_dict)),
+                "filter_spatial_dict": replace_empty_with_none(
+                    filter_empty(spatial_dict)
+                ),
             }
 
-            if len(tables)==2:
+            if len(tables) == 2:
                 self.filter_kwargs_mat = self.filter_kwargs_live
             else:
                 self.filter_kwargs_mat = {
-                    k: replace_empty_with_none(self.filter_kwargs_live[k].get(list(tables)[0],[]))
-                    for k in ["filter_equal_dict", "filter_in_dict", "filter_spatial_dict"] if self.filter_kwargs_live[k] is not None
+                    k: replace_empty_with_none(
+                        self.filter_kwargs_live[k].get(list(tables)[0], [])
+                    )
+                    for k in [
+                        "filter_equal_dict",
+                        "filter_in_dict",
+                        "filter_spatial_dict",
+                    ]
+                    if self.filter_kwargs_live[k] is not None
                 }
-            
+
             pop_empty(self.filter_kwargs_live)
             pop_empty(self.filter_kwargs_mat)
@@ -391,6 +430,7 @@ def __attrs_post_init__(self):
         ]
 
     if not is_view:
+
         class TableQueryKwargs(BaseQueryKwargs):
             def query(
                 self,
@@ -478,8 +518,10 @@ def live_query(
                     **self.filter_kwargs_live,
                     **self.joins_kwargs,
                 )
+
         return TableQueryKwargs
     else:
+
         class ViewQueryKwargs(BaseQueryKwargs):
             def query(
                 self,
@@ -503,16 +545,24 @@ def query(
                     split_positions=split_positions,
                     limit=limit,
                     offset=offset,
-                    select_columns = select_columns,
+                    select_columns=select_columns,
                     get_counts=get_counts,
                     **self.filter_kwargs_mat,
                 )
+
         return ViewQueryKwargs
 
+
 def make_query_filter(table_name, meta, client):
-    pts, val_cols, all_unbd_pts, table_map, rename_map, table_list, desc = get_table_info(
-        table_name, meta, client
-    )
+    (
+        pts,
+        val_cols,
+        all_unbd_pts,
+        table_map,
+        rename_map,
+        table_list,
+        desc,
+    ) = get_table_info(table_name, meta, client)
     class_vals = make_class_vals(
         pts, val_cols, all_unbd_pts, table_map, rename_map, table_list
     )
@@ -522,45 +572,65 @@ def make_query_filter(table_name, meta, client):
     QueryFilter.__doc__ = desc
     return QueryFilter
 
+
 def make_query_filter_view(view_name, meta, schema, client):
-    pts, val_cols, all_unbd_pts, table_map, rename_map, table_list, desc, live_compatible= get_view_info(
-        view_name, meta, schema
-    )
+    (
+        pts,
+        val_cols,
+        all_unbd_pts,
+        table_map,
+        rename_map,
+        table_list,
+        desc,
+        live_compatible,
+    ) = get_view_info(view_name, meta, schema)
     class_vals = make_class_vals(
         pts, val_cols, all_unbd_pts, table_map, rename_map, table_list
     )
     ViewQueryFilter = attrs.make_class(
-        view_name, class_vals, bases=(make_kwargs_mixin(client, is_view=True, live_compatible=live_compatible),)
+        view_name,
+        class_vals,
+        bases=(
+            make_kwargs_mixin(client, is_view=True, live_compatible=live_compatible),
+        ),
     )
     ViewQueryFilter.__doc__ = desc
     return ViewQueryFilter
 
+
 class TableManager(object):
-    """Use schema definitions to generate query filters for each table.
-    """
+    """Use schema definitions to generate query filters for each table."""
+
     def __init__(self, client):
         self._client = client
         self._table_metadata = get_all_table_metadata(self._client)
         self._tables = sorted(list(self._table_metadata.keys()))
         for tn in self._tables:
-            setattr(self, tn, make_query_filter(tn, self._table_metadata[tn], client)) 
+            setattr(self, tn, make_query_filter(tn, self._table_metadata[tn], client))
 
     def __getitem__(self, key):
         return getattr(self, key)
-    
+
     def __repr__(self):
         return str(self._tables)
 
+
 class ViewManager(object):
     def __init__(self, client):
         self._client = client
         self._view_metadata, view_schema = get_all_view_metadata(self._client)
         self._views = sorted(list(self._view_metadata.keys()))
         for vn in self._views:
-            setattr(self, vn, make_query_filter_view(vn, self._view_metadata[vn], view_schema[vn], client))
+            setattr(
+                self,
+                vn,
+                make_query_filter_view(
+                    vn, self._view_metadata[vn], view_schema[vn], client
+                ),
+            )
 
     def __getitem__(self, key):
         return getattr(self, key)
-    
+
     def __repr__(self):
         return str(self._views)
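
A short usage sketch of the TableManager machinery reformatted above: it
builds one attrs-based filter class per annotation table, and keyword
arguments on that class become the filter_equal_dict (scalar values) or
filter_in_dict (list-like values) kwargs of the underlying materialization
query, as sorted out by is_list_like() in __attrs_post_init__. The datastack
name, table name, and filter fields below are hypothetical, and this assumes
the client wires the manager up as client.materialize.tables (that wiring is
not shown in this diff):

    from caveclient import CAVEclient

    # Hypothetical datastack and table names, for illustration only.
    client = CAVEclient("my_datastack")

    # Look up the generated filter class for one table; __getitem__ on
    # TableManager is an alias for attribute access.
    MyTable = client.materialize.tables["my_table"]

    # Scalar kwargs become equality filters, list kwargs become "in" filters,
    # and .query() runs the materialized query with those filters applied.
    df = MyTable(cell_type="pyramidal", pt_root_id=[123, 456]).query()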
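Likewise, a minimal sketch of the datastack-to-server cache touched in
datastack_lookup.py: it is a small JSON map kept in DEFAULT_DATASTACK_FILE
under the auth token directory (DEFAULT_LOCATION). The datastack name and
server URL below are made up for illustration:

    from caveclient import datastack_lookup

    # write=True persists the mapping via write_map(), so later lookups can
    # resolve this datastack without an explicit server address.
    datastack_lookup.handle_server_address(
        "my_datastack", "https://global.example.com", write=True
    )

    # With server_address=None, the cached value (or None) is returned.
    server = datastack_lookup.handle_server_address("my_datastack", None)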