diff --git a/caveclient/chunkedgraph.py b/caveclient/chunkedgraph.py index 74dd1150..a7f3680e 100644 --- a/caveclient/chunkedgraph.py +++ b/caveclient/chunkedgraph.py @@ -774,7 +774,7 @@ def get_subgraph( rd = handle_response(response) return np.int64(rd["nodes"]), np.double(rd["affinities"]), np.int32(rd["areas"]) - def level2_chunk_graph(self, root_id) -> list: + def level2_chunk_graph(self, root_id, bounds=None) -> list: """ Get graph of level 2 chunks, the smallest agglomeration level above supervoxels. @@ -783,6 +783,14 @@ def level2_chunk_graph(self, root_id) -> list: ---------- root_id : int Root id of object + bounds : np.array + 3x2 bounding box (x,y,z) x (min,max) in chunked graph coordinates (use + `client.chunkedgraph.base_resolution` to view this default resolution for + your chunkedgraph client). Note that the result will include any level 2 + nodes which have chunk boundaries within some part of this bounding box, + meaning that the representative point for a given level 2 node could still + be slightly outside of these bounds. If None, returns all level 2 chunks + for the root ID. Returns ------- @@ -792,8 +800,27 @@ def level2_chunk_graph(self, root_id) -> list: """ endpoint_mapping = self.default_url_mapping endpoint_mapping["root_id"] = root_id + + query_d = {} + if bounds is not None: + query_d["bounds"] = package_bounds(bounds) + url = self._endpoints["lvl2_graph"].format_map(endpoint_mapping) - r = handle_response(self.session.get(url)) + response = self.session.get(url, params=query_d) + + used_bounds = response.headers.get("Used-Bounds") + used_bounds = used_bounds == "true" or used_bounds == "True" + if bounds is not None and not used_bounds: + warning = ( + "Bounds were not used for this query, even though it was requested. " + "This is likely because your system is running a version of the " + "chunkedgraph that does not support this feature. Please contact " + "your system administrator to update the chunkedgraph." + ) + raise ValueError(warning) + + r = handle_response(response) + return r["edge_graph"] def remesh_level2_chunks(self, chunk_ids) -> None: diff --git a/tests/test_chunkedgraph.py b/tests/test_chunkedgraph.py index d607f998..7bee5cf8 100644 --- a/tests/test_chunkedgraph.py +++ b/tests/test_chunkedgraph.py @@ -1,18 +1,19 @@ -from re import A, match -from .conftest import test_info, TEST_LOCAL_SERVER, TEST_DATASTACK +import datetime +import json +from urllib.parse import urlencode + +import numpy as np import pytest +import pytz import responses from responses.matchers import json_params_matcher -import pytz -import numpy as np + from caveclient.endpoints import ( - chunkedgraph_endpoints_v1, chunkedgraph_endpoints_common, + chunkedgraph_endpoints_v1, ) -import datetime -import time -import json -from urllib.parse import urlencode + +from .conftest import TEST_LOCAL_SERVER, test_info def binary_body_match(body): @@ -365,6 +366,60 @@ def test_get_lvl2subgraph(self, myclient): qlvl2_graph = myclient.chunkedgraph.level2_chunk_graph(root_id) assert np.all(qlvl2_graph == lvl2_graph) + @responses.activate + def test_get_lvl2subgraph_bounds(self, myclient): + endpoint_mapping = self._default_endpoint_map + root_id = 864691136812623475 + endpoint_mapping["root_id"] = root_id + url = chunkedgraph_endpoints_v1["lvl2_graph"].format_map(endpoint_mapping) + + lvl2_graph_list = [ + [160032475051983415, 160032543771460210], + [160032475051983415, 160102843796161019], + [160032543771460210, 160032612490936816], + [160032543771460210, 160102912515637813], + [160032612490936816, 160032681210413593], + [160032612490936816, 160102981235115106], + [160032681210413593, 160032749929890185], + [160032681210413593, 160032749929890340], + [160032681210413593, 160103049954591386], + [160032749929890185, 160103118674068005], + [160032818649367090, 160103187393544707], + [160102843796161019, 160102912515637813], + [160102912515637813, 160102981235115106], + [160102981235115106, 160103049954591364], + [160102981235115106, 160103049954591386], + [160103049954591386, 160103118674068005], + [160103118674068005, 160103187393544707], + [160103187393544707, 160173556137722487], + ] + + lvl2_graph = np.array(lvl2_graph_list, dtype=np.int64) + + responses.add( + responses.GET, + json={"edge_graph": lvl2_graph_list}, + url=url, + headers={"Used-Bounds": "True"}, + ) + + bounds = np.array([[83875, 85125], [82429, 83679], [20634, 20884]]) + qlvl2_graph = myclient.chunkedgraph.level2_chunk_graph(root_id, bounds=bounds) + assert np.all(qlvl2_graph == lvl2_graph) + + # should fail when bounds are not used, but bounds were passed in + responses.add( + responses.GET, + json={"edge_graph": lvl2_graph_list}, + url=url, + headers={"Used-Bounds": "False"}, + ) + + with pytest.raises(ValueError): + qlvl2_graph = myclient.chunkedgraph.level2_chunk_graph( + root_id, bounds=bounds + ) + @responses.activate def test_get_remeshing(self, myclient): endpoint_mapping = self._default_endpoint_map @@ -801,7 +856,6 @@ def test_get_info(self, myclient): @responses.activate def test_is_valid_nodes(self, myclient): - endpoint_mapping = self._default_endpoint_map url = chunkedgraph_endpoints_v1["valid_nodes"].format_map(endpoint_mapping) query_nodes = [91070075234304972, 91070075234296549]