Skip to content

Commit

Permalink
Merge pull request #12 from ParallelDots/typehints-fix
Browse files Browse the repository at this point in the history
Fix: fix typehints to use comment typehints.
  • Loading branch information
ahwankumar authored Jun 4, 2024
2 parents 6f9f12b + 02a54aa commit 793ea11
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 27 deletions.
4 changes: 4 additions & 0 deletions retailtree/logics/overlap.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from retailtree.structs.annotation_struct import Annotation


def calculate_overlap(anno_1, anno_2, axis):
# type:(Annotation, Annotation, str) -> float
"""
Calculate overlap of two annotations in given axis.
Parameters
Expand Down
3 changes: 3 additions & 0 deletions retailtree/logics/right_left.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from retailtree.structs.annotation_struct import Annotation
from retailtree.logics.overlap import calculate_overlap


# Establishing right-left connections
def right_left_connections(i, j, given_overlap_percentage):
# type:(Annotation, Annotation, float) -> None
# Checking the candidate annotations is in the range of base annotations corresponds to x axis
if (i.x_min < j.x_min and i.x_max + i.length > j.x_min):
# Finding overlap percentage
Expand All @@ -22,6 +24,7 @@ def right_left_connections(i, j, given_overlap_percentage):

# Establishing left-right connection
def left_right_connections(i, j, given_overlap_percentage):
# type:(Annotation, Annotation, float) -> None
if j.right == None:
# Checking the candidate annotations is in the range of base annotations corresponds to x axis
if (i.x_min >= j.x_max and i.x_min - i.length < j.x_max):
Expand Down
3 changes: 3 additions & 0 deletions retailtree/logics/top_bottom.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from retailtree.structs.annotation_struct import Annotation
from retailtree.logics.overlap import calculate_overlap
# Establishing tob-bottom connections


def top_bottom_connections(i, j, given_overlap_percentage):
# type:(Annotation, Annotation, float) -> None
if (i.y_max > j.y_max and i.y_min - i.width < j.y_max):
# Finding overlap percentage
overlap_percentage = calculate_overlap(i, j, axis='x')/j.length
Expand All @@ -24,6 +26,7 @@ def top_bottom_connections(i, j, given_overlap_percentage):

# Establishing tob-bottom connections
def bottom_top_connections(i, j, given_overlap_percentage):
# type:(Annotation, Annotation, float) -> None
if j.top == None:
if (i.y_max <= j.y_min and i.y_max + i.width > j.y_min):
# Finding overlap percentage
Expand Down
10 changes: 7 additions & 3 deletions retailtree/logics/vp_tree.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Callable
import math
import statistics as stats
from typing import Callable

from retailtree.structs.annotation_struct import Annotation


class VPTree:
def __init__(self, points, dist_fn: Callable[[tuple[float, float], tuple[float, float]], float]):
def __init__(self, points, dist_fn):
# type:(list[Annotation], Callable[[tuple[float, float], tuple[float, float]], float]) -> None
self.left = None
self.right = None
self.left_min = math.inf
Expand Down Expand Up @@ -56,7 +59,8 @@ def _is_leaf(self):
return (self.left is None) and (self.right is None)

def get_all_in_range(self, query, max_distance):
neighbors = list()
# type: (tuple[float, float], float) -> list[tuple[float, Annotation]]
neighbors = list() # type: list[tuple[float, Annotation]]
nodes_to_visit = [(self, 0)]

while len(nodes_to_visit) > 0:
Expand Down
32 changes: 21 additions & 11 deletions retailtree/retailtree.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import math
from random import sample
import numpy as np
from typing import Callable

from retailtree.structs.annotation_struct import Annotation
from retailtree.logics.vp_tree import VPTree
from retailtree.logics.right_left import right_left_connections, left_right_connections
from retailtree.logics.top_bottom import top_bottom_connections, bottom_top_connections
from retailtree.utils.dist_func import euclidean
import math
from random import sample
import numpy as np


class RetailTree:
def __init__(self) -> None:
self.annotations: dict[int, Annotation] = {}
self.annotations = {} # type:dict[int, Annotation]
self.tree = None
self.__neighbors_radius = None

Expand All @@ -25,7 +27,8 @@ def add_annotation(self, annotation):
"""
self.annotations[annotation.id] = annotation

def get(self, id: int):
def get(self, id):
# type:(int) -> Annotation
"""
Retrieve an annotation by its ID.
Expand All @@ -38,6 +41,7 @@ def get(self, id: int):
return self.annotations[id]

def build_tree(self, dist_func=euclidean):
# type:(Callable[[tuple[float, float], tuple[float, float]], float]) -> None
"""
Builds a Vantage Point Tree (VPTree) using the given distance function.
Expand All @@ -55,6 +59,7 @@ def build_tree(self, dist_func=euclidean):
self.tree = obj

def __get_neighbors_radius(self):
# type:() -> float
"""
Method to get the radius within which neighbors will be searched for.
Radius is the max of the diagonals of the annotations considered.
Expand All @@ -70,14 +75,15 @@ def __get_neighbors_radius(self):
self.__neighbors_radius = radius
return radius

def __fetch_neighbors(self, id: int, radius=1):

def __fetch_neighbors(self, id, radius=1):
# type:(int, int) -> list[tuple[float, Annotation]]
radius = radius*self.__get_neighbors_radius()
neighbors = self.tree.get_all_in_range(
(self.annotations[id].x_mid, self.annotations[id].y_mid), radius)
return neighbors

def __finding_angle(self, origin, neighbor):
# type:(Annotation, tuple[float, Annotation]) -> int
translated_point2 = np.array(
[neighbor[1].x_mid, neighbor[1].y_mid]) - np.array([origin.x_mid, origin.y_mid])
# print(translated_point2)
Expand Down Expand Up @@ -106,7 +112,8 @@ def __fetching_ann_in_range(self, result_dict, min_angle, max_angle, result_lst)
# if min_angle is None:
return result_lst

def neighbors_wa(self, id: int, radius=1, amin=None, amax=None):
def neighbors_wa(self, id, radius=1, amin=None, amax=None):
# type:(int, int, float, float) -> list[dict]
"""
Retrieves neighboring elements within a specified angle range around a given element.
Expand Down Expand Up @@ -139,7 +146,8 @@ def neighbors_wa(self, id: int, radius=1, amin=None, amax=None):

return result_lst

def neighbors(self, id: int, radius=1):
def neighbors(self, id, radius=1):
# type:(int, int) -> list[dict]
"""
Finds neighboring annotations within a specified radius of a given annotation.(Radius specified is taken as a square rather than a circle)
Expand Down Expand Up @@ -177,7 +185,8 @@ def neighbors(self, id: int, radius=1):

return result_lst

def TBLR(self, id: int, radius=1, overlap=0.5):
def TBLR(self, id, radius=1, overlap=0.5):
# type:(int, int, float) -> (dict[str, int | bool] | str)
"""
Computes top, bottom, left, and right connections for a given annotation within a specified radius.
Expand All @@ -189,8 +198,9 @@ def TBLR(self, id: int, radius=1, overlap=0.5):
- overlap (float, optional): The overlap percentage used to compute connections. Defaults to 0.5.
Returns:
- dict: A dictionary containing top, bottom, left, and right connections of the given annotation.
- dict OR str: A dictionary containing top, bottom, left, and right connections of the given annotation.
Each connection is represented by the ID of the connected annotation, or False if no connection exists.
If the SKU is not present in the bucket, a string with value 'SKU is absent in annotation bucket' is returned.
Examples:
Example usages of TBLR:
Expand Down
19 changes: 7 additions & 12 deletions retailtree/structs/annotation_struct.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

from typing import Optional, Dict, Any
from typing import Any


class Annotation:
Expand All @@ -25,7 +25,8 @@ class Annotation:
"""

def __init__(self, id: int, x_min: float, y_min: float, x_max: float, y_max: float, label: Optional[Any] = None, metadata: Optional[Dict[Any, Any]] = None):
def __init__(self, id, x_min, y_min, x_max, y_max, label=None, metadata=None):
# type:(int, float, float, float, float, Any , dict[Any, Any]) -> None
self.__id = int(id)
self.__x_min = float(x_min)
self.__x_max = float(x_max)
Expand All @@ -40,10 +41,10 @@ def __init__(self, id: int, x_min: float, y_min: float, x_max: float, y_max: fl
self.__width = None

# TODO Check compatibility for older versions of python
self.right: "Annotation" = None
self.left: "Annotation" = None
self.top: "Annotation" = None
self.bottom: "Annotation" = None
self.right = None # type: Annotation
self.left = None # type: Annotation
self.top = None # type: Annotation
self.bottom = None # type: Annotation

self.overlap_right = 0
self.overlap_top = 0
Expand Down Expand Up @@ -128,9 +129,3 @@ def get_coords(self):

def __repr__(self) -> str:
return str(self.id)


# ann = Annotation(10, 1, 1, 1, 1)


# print(ann.y_mid)
2 changes: 2 additions & 0 deletions retailtree/utils/dist_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

# Function to calculate euclidean distance
def euclidean(p1, p2):
# type:(tuple[float, float], tuple[float, float]) -> float
return math.sqrt(pow((p2[0]-p1[0]), 2) + pow((p2[1]-p1[1]), 2))


# Function to calculate manhattan distance
def manhattan(p1, p2):
# type:(tuple[float, float], tuple[float, float]) -> float
return abs(p2[0] - p1[0]) + abs(p2[1] - p1[1])
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def read(fname):

setup(
name="retailtree",
version="1.3",
version="1.3.1",
long_description=read("README.md"),
long_description_content_type='text/markdown',
packages=find_packages(),
Expand Down

0 comments on commit 793ea11

Please sign in to comment.