Skip to content

Commit

Permalink
Get minimal covering nodes (#274)
Browse files Browse the repository at this point in the history
* fixing minimal covering nodes

* adding version constraint

* adding version constraint

* add version 2 range constraint

* adding test for new feature

* fixing whitespace

* ignoring notebooks in formatting

* fixing formatting

* updating uv.lock

* including tests and docs in linting
  • Loading branch information
fcollman authored Dec 5, 2024
1 parent c335026 commit bcc1c5b
Show file tree
Hide file tree
Showing 7 changed files with 143 additions and 3 deletions.
28 changes: 28 additions & 0 deletions caveclient/chunkedgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,34 @@ def get_root_id(self, supervoxel_id, timestamp=None, level2=False) -> np.int64:
response = self.session.get(url, params=query_d)
return np.int64(handle_response(response, as_json=True)["root_id"])

@_check_version_compatibility(method_constraint="<3,>=2.18.0")
def get_minimal_covering_nodes(self, node_ids: Iterable[np.int64 or int]) -> dict:
"""Get the minimal covering nodes for a list of root IDs.
Parameters
----------
nodes_ids : Iterable of int or np.int64
List of root IDs to query.
Returns
-------
np.array of np.int64:
List of PCG node_ids that minimally and exactly cover the input nodes
"""

endpoint_mapping = self.default_url_mapping
url = self._endpoints["minimal_covering_nodes"].format_map(endpoint_mapping)
query_d = {}
query_d["as_array"] = True
data = json.dumps({"node_ids": node_ids}, cls=BaseEncoder)
response = self.session.post(
url,
data=data,
params=query_d,
headers={"Content-Type": "application/json"},
)
return np.frombuffer(response.content, dtype=np.uint64)

def get_merge_log(self, root_id) -> list:
"""Get the merge log (splits and merges) for an object.
Expand Down
1 change: 1 addition & 0 deletions caveclient/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@
"do_merge": pcg_v1 + "/table/{table_id}/merge",
"get_roots": pcg_v1 + "/table/{table_id}/roots_binary",
"leaves_many": pcg_v1 + "/table/{table_id}/node/leaves_many",
"minimal_covering_nodes": pcg_v1 + "/table/{table_id}/minimal_covering_nodes",
"merge_log": pcg_v1 + "/table/{table_id}/root/{root_id}/merge_log",
"change_log": pcg_v1 + "/table/{table_id}/root/{root_id}/change_log",
"tabular_change_log": pcg_v1 + "/table/{table_id}/tabular_change_log_many",
Expand Down
2 changes: 1 addition & 1 deletion caveclient/tools/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
warnings.warn("Must install responses to use CAVEclientMock for testing")
imports_worked = False

DEFAULT_CHUNKEDGRAPH_SERVER_VERSION = "2.15.0"
DEFAULT_CHUNKEDGRAPH_SERVER_VERSION = "2.18.0"
DEFAULT_MATERIALIZATION_SERVER_VERSON = "4.30.1"
DEFAULT_SKELETON_SERVICE_SERVER_VERSION = "0.3.8"
DEFAULT_JSON_SERVICE_SERVER_VERSION = "0.7.0"
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ tag_name = "v{new_version}"
[[tool.bumpversion.files]]
filename = "pyproject.toml"

[tool.ruff.lint]
exclude = ['*.ipynb']

[tool.poe.tasks]
checks = ['doc-build', 'lint', 'test']
doc-build = "mkdocs build"
Expand Down
108 changes: 108 additions & 0 deletions tests/test_chunkedgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,114 @@ def test_get_leaves_many(self, myclient):
for k, v in svids_ret.items():
assert np.all(v == sv_dict[str(k)])

@responses.activate
def test_minimal_covering_nodes(self, myclient):
endpoint_mapping = self._default_endpoint_map
l2ids = [
151733086944494680,
151803455688671629,
151803524408148361,
151803524408148371,
151873824432849265,
151873893152326692,
151873961871802557,
151944261829395442,
151944261896503344,
151944330615980096,
152014630573572920,
152014630573573159,
152014699360157777,
152084999317750839,
152155368061928130,
152225736806105726,
152296105550283453,
152296174269760177,
152366543013938289,
152436911758115473,
152436980477592301,
152507349221769815,
152577717965948027,
152577786685424795,
152577786685424797,
152648155429602343,
152648155429602345,
152648224149078995,
152718524173779595,
152718592893256250,
152788961637433848,
152859330381611287,
152859399101088529,
152929767845266045,
152929767845266047,
152929836564742225,
153000136589443132,
153000205308919861,
153070574053098110,
153140942797275294,
153141011516751998,
153141080236228767,
153211311541453183,
153211380260929846,
153211448980406620,
153211448980406624,
153211517699883483,
153281817724584437,
153281886444061198,
153352186468762089,
153352255188238828,
153422623932416499,
153422692651893340,
153492992676594177,
153493061396070924,
153563430140248766,
153563498859725522,
153633798884426257,
153633867603903006,
153704236348080465,
153774605092258013,
153774673811734916,
153844973836435679,
153845042555912428,
153915342580613364,
153915411300090031,
153985711324791351,
154056080068969264,
154126448813146756,
154126448813146804,
154126517532623397,
154196817557324412,
154196886276800996,
154267186301502007,
154267255020978536,
154337623765156171,
154337692484633043,
]

covering_nodes = [
226343819258366071,
298438259283864547,
441652105523081273,
441652105523071767,
441652105523184668,
441634513337085631,
650770146159222932,
]

url = chunkedgraph_endpoints_v1["minimal_covering_nodes"].format_map(
endpoint_mapping
)
query_d = {"as_array": True}
urlq = url + "?" + urlencode(query_d)
responses.add(
responses.POST,
body=np.array(covering_nodes).tobytes(),
url=urlq,
match=[json_params_matcher({"node_ids": l2ids})],
)

node_ids = myclient.chunkedgraph.get_minimal_covering_nodes(l2ids)
assert np.all(node_ids == np.array(covering_nodes))

@responses.activate
def test_get_root(self, myclient):
endpoint_mapping = self._default_endpoint_map
Expand Down
2 changes: 1 addition & 1 deletion tests/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_create_versioned_client(self):
cg_version_url = endpoints.chunkedgraph_endpoints_common[
"get_version"
].format_map(endpoint_mapping)
responses.add(responses.GET, cg_version_url, json="2.15.0", status=200)
responses.add(responses.GET, cg_version_url, json="2.18.0", status=200)

mat_version_url = endpoints.materialization_common["get_version"].format_map(
endpoint_mapping
Expand Down
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit bcc1c5b

Please sign in to comment.