diff --git a/retailtree/logics/overlap.py b/retailtree/logics/overlap.py index d81ddd7..e96ec0a 100644 --- a/retailtree/logics/overlap.py +++ b/retailtree/logics/overlap.py @@ -1,4 +1,8 @@ +from retailtree.structs.annotation_struct import Annotation + + def calculate_overlap(anno_1, anno_2, axis): + # type:(Annotation, Annotation, str) -> float """ Calculate overlap of two annotations in given axis. Parameters diff --git a/retailtree/logics/right_left.py b/retailtree/logics/right_left.py index 3802abd..b4ab52e 100644 --- a/retailtree/logics/right_left.py +++ b/retailtree/logics/right_left.py @@ -1,8 +1,10 @@ +from retailtree.structs.annotation_struct import Annotation from retailtree.logics.overlap import calculate_overlap # Establishing right-left connections def right_left_connections(i, j, given_overlap_percentage): + # type:(Annotation, Annotation, float) -> None # Checking the candidate annotations is in the range of base annotations corresponds to x axis if (i.x_min < j.x_min and i.x_max + i.length > j.x_min): # Finding overlap percentage @@ -22,6 +24,7 @@ def right_left_connections(i, j, given_overlap_percentage): # Establishing left-right connection def left_right_connections(i, j, given_overlap_percentage): + # type:(Annotation, Annotation, float) -> None if j.right == None: # Checking the candidate annotations is in the range of base annotations corresponds to x axis if (i.x_min >= j.x_max and i.x_min - i.length < j.x_max): diff --git a/retailtree/logics/top_bottom.py b/retailtree/logics/top_bottom.py index 591c61f..b8a9e3a 100644 --- a/retailtree/logics/top_bottom.py +++ b/retailtree/logics/top_bottom.py @@ -1,8 +1,10 @@ +from retailtree.structs.annotation_struct import Annotation from retailtree.logics.overlap import calculate_overlap # Establishing tob-bottom connections def top_bottom_connections(i, j, given_overlap_percentage): + # type:(Annotation, Annotation, float) -> None if (i.y_max > j.y_max and i.y_min - i.width < j.y_max): # Finding overlap percentage overlap_percentage = calculate_overlap(i, j, axis='x')/j.length @@ -24,6 +26,7 @@ def top_bottom_connections(i, j, given_overlap_percentage): # Establishing tob-bottom connections def bottom_top_connections(i, j, given_overlap_percentage): + # type:(Annotation, Annotation, float) -> None if j.top == None: if (i.y_max <= j.y_min and i.y_max + i.width > j.y_min): # Finding overlap percentage diff --git a/retailtree/logics/vp_tree.py b/retailtree/logics/vp_tree.py index 49f68d1..ef70fdb 100644 --- a/retailtree/logics/vp_tree.py +++ b/retailtree/logics/vp_tree.py @@ -1,10 +1,13 @@ -from typing import Callable import math import statistics as stats +from typing import Callable + +from retailtree.structs.annotation_struct import Annotation class VPTree: - def __init__(self, points, dist_fn: Callable[[tuple[float, float], tuple[float, float]], float]): + def __init__(self, points, dist_fn): + # type:(list[Annotation], Callable[[tuple[float, float], tuple[float, float]], float]) -> None self.left = None self.right = None self.left_min = math.inf @@ -56,7 +59,8 @@ def _is_leaf(self): return (self.left is None) and (self.right is None) def get_all_in_range(self, query, max_distance): - neighbors = list() + # type: (tuple[float, float], float) -> list[tuple[float, Annotation]] + neighbors = list() # type: list[tuple[float, Annotation]] nodes_to_visit = [(self, 0)] while len(nodes_to_visit) > 0: diff --git a/retailtree/retailtree.py b/retailtree/retailtree.py index 9543b02..a49b0dc 100644 --- a/retailtree/retailtree.py +++ b/retailtree/retailtree.py @@ -1,16 +1,18 @@ +import math +from random import sample +import numpy as np +from typing import Callable + from retailtree.structs.annotation_struct import Annotation from retailtree.logics.vp_tree import VPTree from retailtree.logics.right_left import right_left_connections, left_right_connections from retailtree.logics.top_bottom import top_bottom_connections, bottom_top_connections from retailtree.utils.dist_func import euclidean -import math -from random import sample -import numpy as np class RetailTree: def __init__(self) -> None: - self.annotations: dict[int, Annotation] = {} + self.annotations = {} # type:dict[int, Annotation] self.tree = None self.__neighbors_radius = None @@ -25,7 +27,8 @@ def add_annotation(self, annotation): """ self.annotations[annotation.id] = annotation - def get(self, id: int): + def get(self, id): + # type:(int) -> Annotation """ Retrieve an annotation by its ID. @@ -38,6 +41,7 @@ def get(self, id: int): return self.annotations[id] def build_tree(self, dist_func=euclidean): + # type:(Callable[[tuple[float, float], tuple[float, float]], float]) -> None """ Builds a Vantage Point Tree (VPTree) using the given distance function. @@ -55,6 +59,7 @@ def build_tree(self, dist_func=euclidean): self.tree = obj def __get_neighbors_radius(self): + # type:() -> float """ Method to get the radius within which neighbors will be searched for. Radius is the max of the diagonals of the annotations considered. @@ -70,14 +75,15 @@ def __get_neighbors_radius(self): self.__neighbors_radius = radius return radius - def __fetch_neighbors(self, id: int, radius=1): - + def __fetch_neighbors(self, id, radius=1): + # type:(int, int) -> list[tuple[float, Annotation]] radius = radius*self.__get_neighbors_radius() neighbors = self.tree.get_all_in_range( (self.annotations[id].x_mid, self.annotations[id].y_mid), radius) return neighbors def __finding_angle(self, origin, neighbor): + # type:(Annotation, tuple[float, Annotation]) -> int translated_point2 = np.array( [neighbor[1].x_mid, neighbor[1].y_mid]) - np.array([origin.x_mid, origin.y_mid]) # print(translated_point2) @@ -106,7 +112,8 @@ def __fetching_ann_in_range(self, result_dict, min_angle, max_angle, result_lst) # if min_angle is None: return result_lst - def neighbors_wa(self, id: int, radius=1, amin=None, amax=None): + def neighbors_wa(self, id, radius=1, amin=None, amax=None): + # type:(int, int, float, float) -> list[dict] """ Retrieves neighboring elements within a specified angle range around a given element. @@ -139,7 +146,8 @@ def neighbors_wa(self, id: int, radius=1, amin=None, amax=None): return result_lst - def neighbors(self, id: int, radius=1): + def neighbors(self, id, radius=1): + # type:(int, int) -> list[dict] """ Finds neighboring annotations within a specified radius of a given annotation.(Radius specified is taken as a square rather than a circle) @@ -177,7 +185,8 @@ def neighbors(self, id: int, radius=1): return result_lst - def TBLR(self, id: int, radius=1, overlap=0.5): + def TBLR(self, id, radius=1, overlap=0.5): + # type:(int, int, float) -> (dict[str, int | bool] | str) """ Computes top, bottom, left, and right connections for a given annotation within a specified radius. @@ -189,8 +198,9 @@ def TBLR(self, id: int, radius=1, overlap=0.5): - overlap (float, optional): The overlap percentage used to compute connections. Defaults to 0.5. Returns: - - dict: A dictionary containing top, bottom, left, and right connections of the given annotation. + - dict OR str: A dictionary containing top, bottom, left, and right connections of the given annotation. Each connection is represented by the ID of the connected annotation, or False if no connection exists. + If the SKU is not present in the bucket, a string with value 'SKU is absent in annotation bucket' is returned. Examples: Example usages of TBLR: diff --git a/retailtree/structs/annotation_struct.py b/retailtree/structs/annotation_struct.py index 77864e7..f4de548 100644 --- a/retailtree/structs/annotation_struct.py +++ b/retailtree/structs/annotation_struct.py @@ -1,5 +1,5 @@ -from typing import Optional, Dict, Any +from typing import Any class Annotation: @@ -25,7 +25,8 @@ class Annotation: """ - def __init__(self, id: int, x_min: float, y_min: float, x_max: float, y_max: float, label: Optional[Any] = None, metadata: Optional[Dict[Any, Any]] = None): + def __init__(self, id, x_min, y_min, x_max, y_max, label=None, metadata=None): + # type:(int, float, float, float, float, Any , dict[Any, Any]) -> None self.__id = int(id) self.__x_min = float(x_min) self.__x_max = float(x_max) @@ -40,10 +41,10 @@ def __init__(self, id: int, x_min: float, y_min: float, x_max: float, y_max: fl self.__width = None # TODO Check compatibility for older versions of python - self.right: "Annotation" = None - self.left: "Annotation" = None - self.top: "Annotation" = None - self.bottom: "Annotation" = None + self.right = None # type: Annotation + self.left = None # type: Annotation + self.top = None # type: Annotation + self.bottom = None # type: Annotation self.overlap_right = 0 self.overlap_top = 0 @@ -128,9 +129,3 @@ def get_coords(self): def __repr__(self) -> str: return str(self.id) - - -# ann = Annotation(10, 1, 1, 1, 1) - - -# print(ann.y_mid) diff --git a/retailtree/utils/dist_func.py b/retailtree/utils/dist_func.py index c6c9c79..d2a2ff1 100644 --- a/retailtree/utils/dist_func.py +++ b/retailtree/utils/dist_func.py @@ -3,9 +3,11 @@ # Function to calculate euclidean distance def euclidean(p1, p2): + # type:(tuple[float, float], tuple[float, float]) -> float return math.sqrt(pow((p2[0]-p1[0]), 2) + pow((p2[1]-p1[1]), 2)) # Function to calculate manhattan distance def manhattan(p1, p2): + # type:(tuple[float, float], tuple[float, float]) -> float return abs(p2[0] - p1[0]) + abs(p2[1] - p1[1]) diff --git a/setup.py b/setup.py index dd131d2..cd927fa 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ def read(fname): setup( name="retailtree", - version="1.3", + version="1.3.1", long_description=read("README.md"), long_description_content_type='text/markdown', packages=find_packages(),