diff --git a/dataloader/instance_augmentation.py b/dataloader/instance_augmentation.py index c470840..77c9a1c 100644 --- a/dataloader/instance_augmentation.py +++ b/dataloader/instance_augmentation.py @@ -159,7 +159,10 @@ def instance_flip(self, points,axis,center,flip_type = 1): def check_occlusion(self,points,center,min_dist=2): 'check if close to a point' - dist = np.linalg.norm(points-center,axis=0) + if points.ndim == 1: + dist = np.linalg.norm(points[np.newaxis,:]-center,axis=1) + else: + dist = np.linalg.norm(points-center,axis=1) return np.all(dist>min_dist) def rotate_origin(self,xyz,radians):