-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcalc_distance_cvact.py
64 lines (44 loc) · 1.75 KB
/
calc_distance_cvact.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import numpy as np
from sklearn.metrics import DistanceMetric
from auxgeo.dataset.cvact import CVACTDatasetTrain
import scipy.io as sio
import pickle
import torch
# Pre-computes, for every CVACT training sample, the TOP_K training samples
# whose UTM coordinates are closest (Euclidean distance), and pickles the
# result as {numeric id: [numeric ids of its neighbours, nearest first]}.

# Upper bound on the number of neighbours stored per training sample.
TOP_K = 128

# Dataset root, defined once so the three paths derived from it cannot drift.
DATA_FOLDER = r"/media/xiapanwang/主数据盘/xiapanwang/Codes/python/New_Geolocalization/0_Datasets/CVACT"

dataset = CVACTDatasetTrain(data_folder=DATA_FOLDER)

anuData = sio.loadmat(DATA_FOLDER + "/ACT_data.mat")

# utm[i] is read below as the coordinate pair belonging to ids[i].
utm = anuData["utm"]
ids = anuData['panoIds']

idx2numidx = dataset.idx2numidx
train_ids_set = set(dataset.train_ids)

# Parallel lists: position p in both refers to the same training sample.
train_idsnum_list = []
utm_coords_list = []

for i, idx in enumerate(ids):
    # NOTE(review): assumes str(idx) yields exactly the id format used in
    # dataset.train_ids / idx2numidx -- confirm against the dataset class.
    idx = str(idx)
    if idx in train_ids_set:
        utm_coords_list.append((float(utm[i][0]), float(utm[i][1])))
        train_idsnum_list.append(idx2numidx[idx])

print("Length Train Ids:", len(utm_coords_list))

# Array form of the numeric ids so a whole row of neighbour positions can be
# translated to numeric ids with a single fancy-indexing operation.
train_idsnum_lookup = np.array(train_idsnum_list)

print("Length of gps coords : " + str(len(utm_coords_list)))

if not utm_coords_list:
    # Fail with a clear message instead of an opaque error further down.
    raise ValueError("No panoId from ACT_data.mat matched dataset.train_ids")

print("Calculation...")

# Dense n x n matrix of pairwise distances between all training samples.
dist = DistanceMetric.get_metric("euclidean")
dm = dist.pairwise(utm_coords_list, utm_coords_list)

print("Distance Matrix:", dm.shape)

# torch.from_numpy shares memory with dm, so the in-place fill below also
# changes dm; dm is not used again afterwards.
dm_torch = torch.from_numpy(dm)

# Mask self-distances with +inf so a sample is never its own neighbour
# (a finite sentinel such as the matrix maximum ties with real distances).
dm_torch.fill_diagonal_(float("inf"))

# topk raises if k exceeds the row length; n - 1 is the number of real
# neighbours available once the sample itself is excluded.
k = min(TOP_K, len(utm_coords_list) - 1)

# largest=False -> the k smallest distances per row, in ascending order.
ids_near_numpy = torch.topk(dm_torch, k=k, dim=1, largest=False).indices.numpy()

# Translate row positions back to the dataset's numeric ids.
near_neighbors = {
    idnum: train_idsnum_lookup[neighbor_rows].tolist()
    for idnum, neighbor_rows in zip(train_idsnum_list, ids_near_numpy)
}

print("Saving...")

with open(DATA_FOLDER + "/gps_dict.pkl", "wb") as f:
    pickle.dump(near_neighbors, f)