diff --git a/fafbseg/flywire/annotations.py b/fafbseg/flywire/annotations.py index d12d4f7..a9cdac7 100644 --- a/fafbseg/flywire/annotations.py +++ b/fafbseg/flywire/annotations.py @@ -136,17 +136,35 @@ def inner(*args, **kwargs): @inject_dataset(disallowed=['flat_630', 'flat_571']) -def is_proofread(x, materialization='auto', cache=True, validate=True, *, - dataset=None): +def is_proofread(x, table=("proofreading_status_public_v1", "proofread_neurons"), + materialization='auto', cache=True, validate=True, *, dataset=None): """Test if neuron has been set to `proofread`. Parameters ---------- x : int | list of int Root IDs to check. + table : str | tuple + Which CAVE table(s) to use. There are currently two tables: + - "proofreading_status_public_v1" contains everything that + has been set to proofread by the community and includes + some things that aren't actual neurons; this table is + automatically updated and should be up-to-date for the + production dataset + - "proofread_neurons" is a sanitized version of the above + and should contain only neurons; note though that this + table is not automatically updated and will lag behind + the "proofreading_status_public_v1" table + Unfortunately, the "proofreading_status_public_v1" table + is currently only available in the production but not the + public dataset - which makes sense since the public dataset + is a static snapshot. + To make this function robust, we default to using + "proofreading_status_public_v1" and fall back to + "proofread_neurons" if the former is not available. materialization : "latest" | "live" | "auto" | int Which materialization to check. If "latest" will use the - latest available one in the cave client. + latest available one in the CAVE client. validate : bool Whether to validate IDs. 
cache : bool @@ -183,6 +201,20 @@ def is_proofread(x, materialization='auto', cache=True, validate=True, *, # Get available materialization versions client = get_cave_client(dataset=dataset) + available_tables = client.materialize.get_tables() + if isinstance(table, str): + if table not in available_tables: + raise ValueError(f'Table "{table}" not available in dataset "{dataset}"') + elif isinstance(table, (tuple, list)): + for t in table: + if t in available_tables: + table = t + break + if not isinstance(table, str): + raise ValueError(f'None of the tables "{table}" are available in dataset "{dataset}"') + else: + raise TypeError(f'`table` must be str or tuple/list of str, got {type(table)}') + if materialization == 'latest': mat_versions = client.materialize.get_versions() materialization = max(mat_versions) @@ -191,23 +223,23 @@ def is_proofread(x, materialization='auto', cache=True, validate=True, *, if materialization == 'live': # For live materialization only do on-the-run queries - table = client.materialize.live_query(table='proofreading_status_public_v1', - timestamp=dt.datetime.utcnow(), - filter_in_dict=dict(pt_root_id=x)) + pr_table = client.materialize.live_query(table=table, + timestamp=dt.datetime.utcnow(), + filter_in_dict=dict(pt_root_id=x)) elif isinstance(materialization, int): if cache: - if materialization in PR_TABLE: - table = PR_TABLE[materialization] + if (table, materialization) in PR_TABLE: + pr_table = PR_TABLE[(table, materialization)] else: - table = client.materialize.query_table(table='proofreading_status_public_v1', - materialization_version=materialization) - PR_TABLE[materialization] = table + pr_table = client.materialize.query_table(table=table, + materialization_version=materialization) + PR_TABLE[(table, materialization)] = pr_table else: - table = client.materialize.query_table(table='proofreading_status_public_v1', - filter_in_dict=dict(pt_root_id=x), - materialization_version=materialization) + pr_table = 
client.materialize.query_table(table=table, + filter_in_dict=dict(pt_root_id=x), + materialization_version=materialization) - return np.isin(x, table.pt_root_id.values) + return np.isin(x, pr_table.pt_root_id.values) @inject_dataset(disallowed=['flat_630', 'flat_571'])