Skip to content

Commit

Permalink
Introduce retrieve_nodes() and retrieve_edges()
Browse files Browse the repository at this point in the history
Generic parameters and defaults don't mix well (see python/mypy#3737) so we need to use overloads instead, or InstancesResult is always InstancesResult[Any, Any]. Overloads don't work on retrieve() due to defaults and ordering of params, so we create two new methods that can be safely typed.
  • Loading branch information
erlendvollset committed Jul 2, 2024
1 parent daf4403 commit b4b09c5
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 23 deletions.
79 changes: 73 additions & 6 deletions cognite/client/_api/data_modeling/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,27 +271,81 @@ def __iter__(self) -> Iterator[Node]:
"""
return self(None, "node")

@overload
def retrieve_edges(
self,
edges: EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]],
*,
edge_cls: type[T_Edge],
) -> EdgeList[T_Edge]: ...

@overload
def retrieve_edges(
self,
edges: EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]],
*,
sources: Source | Sequence[Source] | None = None,
include_typing: bool = False,
) -> EdgeList[Edge]: ...

def retrieve_edges(
self,
edges: EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]],
edge_cls: type[T_Edge] = Edge, # type: ignore
sources: Source | Sequence[Source] | None = None,
include_typing: bool = False,
) -> EdgeList[T_Edge]:
res = self._retrieve_typed(
nodes=None, edges=edges, node_cls=Node, edge_cls=edge_cls, sources=sources, include_typing=include_typing
)
return res.edges

@overload
def retrieve_nodes(
self,
nodes: NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]],
*,
node_cls: type[T_Node],
) -> NodeList[T_Node]: ...

@overload
def retrieve_nodes(
self,
nodes: NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]],
*,
sources: Source | Sequence[Source] | None = None,
include_typing: bool = False,
) -> NodeList[Node]: ...

def retrieve_nodes(
self,
nodes: NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]],
node_cls: type[T_Node] = Node, # type: ignore
sources: Source | Sequence[Source] | None = None,
include_typing: bool = False,
) -> NodeList[T_Node]:
res = self._retrieve_typed(
nodes=nodes, edges=None, node_cls=node_cls, edge_cls=Edge, sources=sources, include_typing=include_typing
)
return res.nodes

def retrieve(
self,
nodes: NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]] | None = None,
edges: EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]] | None = None,
sources: Source | Sequence[Source] | None = None,
include_typing: bool = False,
node_cls: type[T_Node] = Node, # type: ignore[assignment]
edge_cls: type[T_Edge] = Edge, # type: ignore[assignment]
) -> InstancesResult[T_Node, T_Edge]:
) -> InstancesResult[Node, Edge]:
"""`Retrieve one or more instance by id(s). <https://developer.cognite.com/api#tag/Instances/operation/byExternalIdsInstances>`_
Args:
nodes (NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]] | None): Node ids
edges (EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]] | None): Edge ids
sources (Source | Sequence[Source] | None): Retrieve properties from the listed - by reference - views.
include_typing (bool): Whether to return property type information as part of the result.
node_cls (type[T_Node]): Node class to use when returning nodes.
edge_cls (type[T_Edge]): Edge class to use when returning edges.
Returns:
InstancesResult[T_Node, T_Edge]: Requested instances.
InstancesResult[Node, Edge]: Requested instances.
Examples:
Expand Down Expand Up @@ -324,6 +378,19 @@ def retrieve(
... EdgeId("mySpace", "myEdge"),
... sources=("myspace", "myView"))
"""
return self._retrieve_typed(
nodes=nodes, edges=edges, sources=sources, include_typing=include_typing, node_cls=Node, edge_cls=Edge
)

def _retrieve_typed(
self,
nodes: NodeId | Sequence[NodeId] | tuple[str, str] | Sequence[tuple[str, str]] | None,
edges: EdgeId | Sequence[EdgeId] | tuple[str, str] | Sequence[tuple[str, str]] | None,
sources: Source | Sequence[Source] | None,
include_typing: bool,
node_cls: type[T_Node],
edge_cls: type[T_Edge],
) -> InstancesResult[T_Node, T_Edge]:
identifiers = self._load_node_and_edge_ids(nodes, edges)

sources = self._to_sources(sources, node_cls, edge_cls)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def test_apply_retrieve_and_delete(self, cognite_client: CogniteClient, person_v

try:
created = cognite_client.data_modeling.instances.apply(new_node, replace=True)
retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(new_node.as_id())
retrieved = cognite_client.data_modeling.instances.retrieve(new_node.as_id())

assert len(created.nodes) == 1
assert created.nodes[0].created_time
Expand All @@ -503,9 +503,7 @@ def test_apply_retrieve_and_delete(self, cognite_client: CogniteClient, person_v
assert retrieved.nodes[0].as_id() == new_node.as_id()

deleted_result = cognite_client.data_modeling.instances.delete(new_node.as_id())
retrieved_deleted: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(
new_node.as_id()
)
retrieved_deleted = cognite_client.data_modeling.instances.retrieve(new_node.as_id())

assert len(deleted_result.nodes) == 1
assert deleted_result.nodes[0] == new_node.as_id()
Expand Down Expand Up @@ -581,7 +579,7 @@ def test_apply_auto_create_nodes(self, cognite_client: CogniteClient, person_vie
created_edges = cognite_client.data_modeling.instances.apply(
edges=person_to_actor, auto_create_start_nodes=True, auto_create_end_nodes=True, replace=True
)
created_nodes: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(node_pair)
created_nodes = cognite_client.data_modeling.instances.retrieve(node_pair)

assert len(created_edges.edges) == 1
assert created_edges.edges[0].created_time
Expand All @@ -600,13 +598,13 @@ def test_delete_non_existent(self, cognite_client: CogniteClient, integration_te
assert res.edges == []

def test_retrieve_multiple(self, cognite_client: CogniteClient, movie_nodes: NodeList) -> None:
retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(movie_nodes.as_ids())
retrieved = cognite_client.data_modeling.instances.retrieve(movie_nodes.as_ids())
assert len(retrieved.nodes) == len(movie_nodes)

def test_retrieve_nodes_and_edges_using_id_tuples(
self, cognite_client: CogniteClient, movie_nodes: NodeList, movie_edges: EdgeList
) -> None:
retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(
retrieved = cognite_client.data_modeling.instances.retrieve(
nodes=[(id.space, id.external_id) for id in movie_nodes.as_ids()],
edges=[(id.space, id.external_id) for id in movie_edges.as_ids()],
)
Expand All @@ -616,17 +614,25 @@ def test_retrieve_nodes_and_edges_using_id_tuples(
def test_retrieve_nodes_and_edges(
self, cognite_client: CogniteClient, movie_nodes: NodeList, movie_edges: EdgeList
) -> None:
retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(
retrieved = cognite_client.data_modeling.instances.retrieve(
nodes=movie_nodes.as_ids(), edges=movie_edges.as_ids()
)
assert set(retrieved.nodes.as_ids()) == set(movie_nodes.as_ids())
assert set(retrieved.edges.as_ids()) == set(movie_edges.as_ids())

def test_retrieve_nodes(self, cognite_client: CogniteClient, movie_nodes: NodeList) -> None:
retrieved = cognite_client.data_modeling.instances.retrieve_nodes(movie_nodes.as_ids())
assert set(retrieved.as_ids()) == set(movie_nodes.as_ids())

def test_retrieve_edges(self, cognite_client: CogniteClient, movie_edges: EdgeList) -> None:
retrieved = cognite_client.data_modeling.instances.retrieve_edges(movie_edges.as_ids())
assert set(retrieved.as_ids()) == set(movie_edges.as_ids())

def test_retrieve_multiple_with_missing(self, cognite_client: CogniteClient, movie_nodes: NodeList) -> None:
ids_without_missing = movie_nodes.as_ids()
ids_with_missing = [*ids_without_missing, NodeId("myNonExistingSpace", "myImaginaryContainer")]

retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(ids_with_missing)
retrieved = cognite_client.data_modeling.instances.retrieve(ids_with_missing)
assert retrieved.nodes.as_ids() == ids_without_missing

def test_retrieve_non_existent(self, cognite_client: CogniteClient) -> None:
Expand Down Expand Up @@ -877,9 +883,7 @@ def test_retrieve_in_units(
node = node_with_1_1_pressure_in_bar
source = SourceSelector(unit_view.as_id(), target_units=[TargetUnit("pressure", UnitReference("pressure:pa"))])

retrieved: InstancesResult[Node, Edge] = cognite_client.data_modeling.instances.retrieve(
node.as_id(), sources=[source]
)
retrieved = cognite_client.data_modeling.instances.retrieve(node.as_id(), sources=[source])
assert retrieved.nodes
assert math.isclose(cast(float, retrieved.nodes[0]["pressure"]), 1.1 * 1e5)

Expand Down Expand Up @@ -965,9 +969,9 @@ def test_write_typed_node(self, cognite_client: CogniteClient, integration_test_
assert len(created.nodes) == 1
assert created.nodes[0].external_id == external_id

retrieved = cognite_client.data_modeling.instances.retrieve(
retrieved = cognite_client.data_modeling.instances.retrieve_nodes(
primitive.as_id(), node_cls=PrimitiveNullableRead
).nodes
)
assert len(retrieved) == 1
assert isinstance(retrieved[0], PrimitiveNullableRead)
assert retrieved[0].text == "text"
Expand Down Expand Up @@ -1000,9 +1004,9 @@ def test_write_typed_node_listed_properties(
assert len(created.nodes) == 1
assert created.nodes[0].external_id == external_id

retrieved = cognite_client.data_modeling.instances.retrieve(
retrieved = cognite_client.data_modeling.instances.retrieve_nodes(
primitive_listed.as_id(), node_cls=PrimitiveListedRead
).nodes
)
assert len(retrieved) == 1
assert isinstance(retrieved[0], PrimitiveListedRead)
assert retrieved[0].text == ["text"]
Expand All @@ -1022,7 +1026,7 @@ def test_write_type_node_instance_property_descriptor(
assert len(created.nodes) == 1
assert created.nodes[0].external_id == external_id

retrieved = cognite_client.data_modeling.instances.retrieve(person.as_id(), node_cls=PersonRead).nodes
retrieved = cognite_client.data_modeling.instances.retrieve_nodes(person.as_id(), node_cls=PersonRead)
assert len(retrieved) == 1
assert isinstance(retrieved[0], PersonRead)
assert retrieved[0].name == "John Doe"
Expand Down

0 comments on commit b4b09c5

Please sign in to comment.