Skip to content

Commit

Permalink
Multischema endpoint (#158)
Browse files Browse the repository at this point in the history
* add multi-schema endpoints

* fix data formatting for multischema endpoint

* tweak docstring

* Use multischema endpoints in table interface

* multithreaded materialization initialization

* tweek type hinting for older pythons

* add table manager tests, __contains__ and __len__

* move table testing to responses-based
  • Loading branch information
ceesem authored Mar 13, 2024
1 parent 748659e commit d80dde9
Show file tree
Hide file tree
Showing 6 changed files with 436 additions and 39 deletions.
39 changes: 36 additions & 3 deletions caveclient/emannotationschemas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
from .base import ClientBase, _api_endpoints, handle_response
from .endpoints import schema_api_versions, schema_endpoints_common
from .auth import AuthClient
Expand Down Expand Up @@ -65,7 +66,7 @@ def __init__(
over_client=over_client,
)

def get_schemas(self):
def get_schemas(self) -> list[str]:
"""Get the available schema types
Returns
Expand All @@ -77,8 +78,8 @@ def get_schemas(self):
url = self._endpoints["schema"].format_map(endpoint_mapping)
response = self.session.get(url)
return handle_response(response)

def schema_definition(self, schema_type):
def schema_definition(self, schema_type: str) -> dict[str]:
"""Get the definition of a specified schema_type
Parameters
Expand All @@ -97,6 +98,38 @@ def schema_definition(self, schema_type):
response = self.session.get(url)
return handle_response(response)

def schema_definition_multi(self, schema_types: list[str]) -> dict:
"""Get the definition of multiple schema_types
Parameters
----------
schema_types : list
List of schema names
Returns
-------
dict
Dictionary of schema definitions. Keys are schema names, values are definitions.
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["schema_definition_multi"].format_map(endpoint_mapping)
data={'schema_names': ','.join(schema_types)}
response = self.session.post(url, params=data)
return handle_response(response)

def schema_definition_all(self) -> dict[str]:
"""Get the definition of all schema_types
Returns
-------
dict
Dictionary of schema definitions. Keys are schema names, values are definitions.
"""
endpoint_mapping = self.default_url_mapping
url = self._endpoints["schema_definition_all"].format_map(endpoint_mapping)
response = self.session.get(url)
return handle_response(response)


client_mapping = {
1: SchemaClientLegacy,
Expand Down
2 changes: 2 additions & 0 deletions caveclient/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@
schema_endpoints_v2 = {
"schema": schema_v2 + "/type",
"schema_definition": schema_v2 + "/type/{schema_type}",
"schema_definition_multi": schema_v2 + "/types",
"schema_definition_all": schema_v2 + "/types_all",
}

schema_api_versions = {1: schema_endpoints_v1, 2: schema_endpoints_v2}
Expand Down
67 changes: 46 additions & 21 deletions caveclient/materializationengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
import pyarrow as pa
import pytz
from concurrent.futures import ThreadPoolExecutor
from cachetools import TTLCache, cached
from IPython.display import HTML

Expand Down Expand Up @@ -247,25 +248,28 @@ def __init__(
over_client=over_client,
)
self._datastack_name = datastack_name
if version is None:
version = self.most_recent_version()
self._version = version
if cg_client is None:
if self.fc is not None:
self.cg_client = self.fc.chunkedgraph
else:
self.cg_client = cg_client
self._cg_client = cg_client
self.synapse_table = synapse_table
self.desired_resolution = desired_resolution
self._tables = None
self._views = None

@property
def datastack_name(self):
return self._datastack_name

@property
def cg_client(self):
if self._cg_client is None:
if self.fc is not None:
self._cg_client = self.fc.chunkedgraph
else:
raise ValueError("No chunkedgraph client specified")
return self._cg_client

@property
def version(self):
if self._version is None:
self._version = self.most_recent_version()
return self._version

@property
Expand Down Expand Up @@ -328,18 +332,6 @@ def get_versions(self, datastack_name=None, expired=False):
self.raise_for_status(response)
return response.json()

@property
def tables(self):
if self._tables is None:
self._tables = TableManager(self.fc)
return self._tables

@property
def views(self):
if self._views is None:
self._views = ViewManager(self.fc)
return self._views

def get_tables(self, datastack_name=None, version=None):
"""Gets a list of table names for a datastack
Expand Down Expand Up @@ -1876,6 +1868,39 @@ def _assemble_attributes(
class MaterializationClientV3(MaterializationClientV2):
def __init__(self, *args, **kwargs):
super(MaterializationClientV3, self).__init__(*args, **kwargs)
metadata = []
with ThreadPoolExecutor(max_workers=4) as executor:
metadata.append(
executor.submit(
self.get_tables_metadata,
)
)
metadata.append(
executor.submit(
self.fc.schema.schema_definition_all
)
)
metadata.append(
executor.submit(
self.get_views
)
)
metadata.append(
executor.submit(
self.get_view_schemas
)
)
if self.fc is not None:
tables = TableManager(self.fc, metadata[0].result(), metadata[1].result())
else:
tables = None
self.tables = tables

if self.fc is not None:
views = ViewManager(self.fc, metadata[2].result(), metadata[3].result())
else:
views = None
self.views = views

@cached(cache=TTLCache(maxsize=100, ttl=60 * 60 * 12))
def get_tables_metadata(
Expand Down
90 changes: 79 additions & 11 deletions caveclient/tools/table_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,9 @@ def combine_names(tableA, namesA, tableB, namesB, suffixes):
return final_namesA + final_namesB, table_map, rename_map


def get_all_table_metadata(client):
meta = client.materialize.get_tables_metadata()
def get_all_table_metadata(client, meta=None):
if meta is None:
meta = client.materialize.get_tables_metadata()
tables = []
for m in meta:
if m.get("annotation_table"):
Expand Down Expand Up @@ -116,6 +117,20 @@ def _schema_key(schema_name, client, **kwargs):
key = keys.hashkey(schema_name, str(allow_types))
return key

def populate_schema_cache(client, schema_definitions=None):
if schema_definitions is None:
try:
schema_definitions = client.schema.schema_definition_all()
except:
schema_definitions = {sn:None for sn in client.schema.get_schemas()}
for schema_name, schema_definition in schema_definitions.items():
get_col_info(schema_name, client, schema_definition=schema_definition)

def populate_table_cache(client, metadata=None):
if metadata is None:
metadata = get_all_table_metadata(client)
for tn, meta in metadata.items():
table_metadata(tn, client, meta=meta)

@cached(cache=_schema_cache, key=_schema_key)
def get_col_info(
Expand All @@ -126,8 +141,12 @@ def get_col_info(
allow_types=ALLOW_COLUMN_TYPES,
add_fields=["id"],
omit_fields=[],
schema_definition=None,
):
schema = client.schema.schema_definition(schema_name)
if schema_definition is None:
schema = client.schema.schema_definition(schema_name)
else:
schema = schema_definition.copy()
sp_name = f"#/definitions/{spatial_point}"
unbd_sp_name = f"#/definitions/{unbound_spatial_point}"
n_sp = 0
Expand Down Expand Up @@ -275,18 +294,18 @@ def get_table_info(

_metadata_cache = TTLCache(maxsize=128, ttl=86_400)


def _metadata_key(tn, client):
def _metadata_key(tn, client, **kwargs):
key = keys.hashkey(tn)
return key


@cached(cache=_metadata_cache, key=_metadata_key)
def table_metadata(table_name, client):
def table_metadata(table_name, client, meta=None):
"Caches getting table metadata"
with warnings.catch_warnings():
warnings.simplefilter(action="ignore")
meta = client.materialize.get_table_metadata(table_name)
if meta is None:
meta = client.materialize.get_table_metadata(table_name)
if "schema" not in meta:
meta["schema"] = meta.get("schema_type")
return meta
Expand Down Expand Up @@ -534,6 +553,29 @@ def query(
desired_resolution=None,
get_counts=False,
):
"""Query views through the table interface
Parameters
----------
select_columns : list[str], optional
Specification of columns to return, by default None
offset : int, optional
Integer offset from the beginning of the table to return, by default None.
Used when tables are too large to return in one query.
limit : int, optional
Maximum number of rows to return, by default None
split_positions : bool, optional
If true, returns each point coordinate as a separate column, by default False
materialization_version : int, optional
Query a specified materialization version, by default None
metadata : bool, optional
If true includes query and table metadata in the .attrs property of the returned dataframe, by default True
desired_resolution : list[int], optional
Sets the 3d point resolution in nm, by default None.
If default, uses the values in the table directly.
get_counts : bool, optional
Only return number of rows in the query, by default False
"""
logger.warning(
"The `client.materialize.views` interface is experimental and might experience breaking changes before the feature is stabilized."
)
Expand Down Expand Up @@ -601,24 +643,39 @@ def make_query_filter_view(view_name, meta, schema, client):
class TableManager(object):
"""Use schema definitions to generate query filters for each table."""

def __init__(self, client):
def __init__(self, client, metadata=None, schema=None):
self._client = client
self._table_metadata = get_all_table_metadata(self._client)
self._table_metadata = get_all_table_metadata(self._client, meta=metadata)
self._tables = sorted(list(self._table_metadata.keys()))
populate_schema_cache(client, schema_definitions=schema)
populate_table_cache(client, metadata=self._table_metadata)
for tn in self._tables:
setattr(self, tn, make_query_filter(tn, self._table_metadata[tn], client))

def __getitem__(self, key):
return getattr(self, key)

def __contains__(self, key):
return key in self._tables

def __repr__(self):
return str(self._tables)

@property
def table_names(self):
return self._tables

def __len__(self):
return len(self._tables)


class ViewManager(object):
def __init__(self, client):
def __init__(self, client, view_metadata=None, view_schema=None):
self._client = client
self._view_metadata, view_schema = get_all_view_metadata(self._client)
if view_metadata is None or view_schema is None:
view_metadata, view_schema = get_all_view_metadata(self._client)
else:
self._view_metadata = view_metadata
self._views = sorted(list(self._view_metadata.keys()))
for vn in self._views:
setattr(
Expand All @@ -632,5 +689,16 @@ def __init__(self, client):
def __getitem__(self, key):
return getattr(self, key)

def __contains__(self, key):
return key in self._views

def __repr__(self):
return str(self._views)

@property
def table_names(self):
return self._views

def __len__(self):
return len(self._views)

3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,4 @@ cachetools>=4.2.1
ipython
networkx
jsonschema
attrs>=21.3.0
cachetools>=4
attrs>=21.3.0
Loading

0 comments on commit d80dde9

Please sign in to comment.