11import unittest
22import torch
3- from torch_points_kernels import ball_query
43import numpy .testing as npt
54import numpy as np
65from sklearn .neighbors import KDTree
6+ import os
7+ import sys
8+
9+ ROOT = os .path .join (os .path .dirname (os .path .realpath (__file__ )), ".." )
10+ sys .path .insert (0 , ROOT )
711
8- from . import run_if_cuda
12+ from test import run_if_cuda
13+ from torch_points_kernels import ball_query
914
1015
1116class TestBall (unittest .TestCase ):
@@ -76,10 +81,10 @@ def test_simple_gpu(self):
7681 npt .assert_array_almost_equal (dist2 , dist2_answer )
7782
7883 def test_simple_cpu (self ):
79- x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [0 .1 , 0 , 0 ]]).to (torch .float )
84+ x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [10 .1 , 0 , 0 ]]).to (torch .float )
8085 y = torch .tensor ([[0 , 0 , 0 ]]).to (torch .float )
8186
82- batch_x = torch .from_numpy (np .asarray ([0 , 0 , 1 , 1 ])).long ()
87+ batch_x = torch .from_numpy (np .asarray ([0 , 0 , 0 , 0 ])).long ()
8388 batch_y = torch .from_numpy (np .asarray ([0 ])).long ()
8489
8590 idx , dist2 = ball_query (1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y )
@@ -93,6 +98,17 @@ def test_simple_cpu(self):
9398 npt .assert_array_almost_equal (idx , idx_answer )
9499 npt .assert_array_almost_equal (dist2 , dist2_answer )
95100
101+
102+ def test_breaks (self ):
103+ x = torch .tensor ([[10 , 0 , 0 ], [0.1 , 0 , 0 ], [10 , 0 , 0 ], [10.1 , 0 , 0 ]]).to (torch .float )
104+ y = torch .tensor ([[0 , 0 , 0 ]]).to (torch .float )
105+
106+ batch_x = torch .from_numpy (np .asarray ([0 , 0 , 1 , 1 ])).long ()
107+ batch_y = torch .from_numpy (np .asarray ([0 ])).long ()
108+
109+ with self .assertRaises (RuntimeError ):
110+ idx , dist2 = ball_query (1.0 , 2 , x , y , mode = "PARTIAL_DENSE" , batch_x = batch_x , batch_y = batch_y )
111+
96112 def test_random_cpu (self ):
97113 a = torch .randn (100 , 3 ).to (torch .float )
98114 b = torch .randn (50 , 3 ).to (torch .float )
0 commit comments