Skip to content

Commit

Permalink
Add bounding box argument to level2_chunk_graph (#154)
Browse files Browse the repository at this point in the history
* add bounds

* add test

* amend docs

* add a warning

* add error

* clarify description of resolution

* does this have headers?

* fix response order

* typo

* fix test and header encoding
  • Loading branch information
bdpedigo authored Mar 11, 2024
1 parent 2aa4efb commit 2489515
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 12 deletions.
31 changes: 29 additions & 2 deletions caveclient/chunkedgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ def get_subgraph(
rd = handle_response(response)
return np.int64(rd["nodes"]), np.double(rd["affinities"]), np.int32(rd["areas"])

def level2_chunk_graph(self, root_id) -> list:
def level2_chunk_graph(self, root_id, bounds=None) -> list:
"""
Get graph of level 2 chunks, the smallest agglomeration level above supervoxels.
Expand All @@ -783,6 +783,14 @@ def level2_chunk_graph(self, root_id) -> list:
----------
root_id : int
Root id of object
bounds : np.array
3x2 bounding box (x,y,z) x (min,max) in chunked graph coordinates (use
`client.chunkedgraph.base_resolution` to view this default resolution for
your chunkedgraph client). Note that the result will include any level 2
nodes which have chunk boundaries within some part of this bounding box,
meaning that the representative point for a given level 2 node could still
be slightly outside of these bounds. If None, returns all level 2 chunks
for the root ID.
Returns
-------
Expand All @@ -792,8 +800,27 @@ def level2_chunk_graph(self, root_id) -> list:
"""
endpoint_mapping = self.default_url_mapping
endpoint_mapping["root_id"] = root_id

query_d = {}
if bounds is not None:
query_d["bounds"] = package_bounds(bounds)

url = self._endpoints["lvl2_graph"].format_map(endpoint_mapping)
r = handle_response(self.session.get(url))
response = self.session.get(url, params=query_d)

used_bounds = response.headers.get("Used-Bounds")
used_bounds = used_bounds == "true" or used_bounds == "True"
if bounds is not None and not used_bounds:
warning = (
"Bounds were not used for this query, even though it was requested. "
"This is likely because your system is running a version of the "
"chunkedgraph that does not support this feature. Please contact "
"your system administrator to update the chunkedgraph."
)
raise ValueError(warning)

r = handle_response(response)

return r["edge_graph"]

def remesh_level2_chunks(self, chunk_ids) -> None:
Expand Down
74 changes: 64 additions & 10 deletions tests/test_chunkedgraph.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from re import A, match
from .conftest import test_info, TEST_LOCAL_SERVER, TEST_DATASTACK
import datetime
import json
from urllib.parse import urlencode

import numpy as np
import pytest
import pytz
import responses
from responses.matchers import json_params_matcher
import pytz
import numpy as np

from caveclient.endpoints import (
chunkedgraph_endpoints_v1,
chunkedgraph_endpoints_common,
chunkedgraph_endpoints_v1,
)
import datetime
import time
import json
from urllib.parse import urlencode

from .conftest import TEST_LOCAL_SERVER, test_info


def binary_body_match(body):
Expand Down Expand Up @@ -365,6 +366,60 @@ def test_get_lvl2subgraph(self, myclient):
qlvl2_graph = myclient.chunkedgraph.level2_chunk_graph(root_id)
assert np.all(qlvl2_graph == lvl2_graph)

@responses.activate
def test_get_lvl2subgraph_bounds(self, myclient):
endpoint_mapping = self._default_endpoint_map
root_id = 864691136812623475
endpoint_mapping["root_id"] = root_id
url = chunkedgraph_endpoints_v1["lvl2_graph"].format_map(endpoint_mapping)

lvl2_graph_list = [
[160032475051983415, 160032543771460210],
[160032475051983415, 160102843796161019],
[160032543771460210, 160032612490936816],
[160032543771460210, 160102912515637813],
[160032612490936816, 160032681210413593],
[160032612490936816, 160102981235115106],
[160032681210413593, 160032749929890185],
[160032681210413593, 160032749929890340],
[160032681210413593, 160103049954591386],
[160032749929890185, 160103118674068005],
[160032818649367090, 160103187393544707],
[160102843796161019, 160102912515637813],
[160102912515637813, 160102981235115106],
[160102981235115106, 160103049954591364],
[160102981235115106, 160103049954591386],
[160103049954591386, 160103118674068005],
[160103118674068005, 160103187393544707],
[160103187393544707, 160173556137722487],
]

lvl2_graph = np.array(lvl2_graph_list, dtype=np.int64)

responses.add(
responses.GET,
json={"edge_graph": lvl2_graph_list},
url=url,
headers={"Used-Bounds": "True"},
)

bounds = np.array([[83875, 85125], [82429, 83679], [20634, 20884]])
qlvl2_graph = myclient.chunkedgraph.level2_chunk_graph(root_id, bounds=bounds)
assert np.all(qlvl2_graph == lvl2_graph)

# should fail when bounds are not used, but bounds were passed in
responses.add(
responses.GET,
json={"edge_graph": lvl2_graph_list},
url=url,
headers={"Used-Bounds": "False"},
)

with pytest.raises(ValueError):
qlvl2_graph = myclient.chunkedgraph.level2_chunk_graph(
root_id, bounds=bounds
)

@responses.activate
def test_get_remeshing(self, myclient):
endpoint_mapping = self._default_endpoint_map
Expand Down Expand Up @@ -801,7 +856,6 @@ def test_get_info(self, myclient):

@responses.activate
def test_is_valid_nodes(self, myclient):

endpoint_mapping = self._default_endpoint_map
url = chunkedgraph_endpoints_v1["valid_nodes"].format_map(endpoint_mapping)
query_nodes = [91070075234304972, 91070075234296549]
Expand Down

0 comments on commit 2489515

Please sign in to comment.