Skip to content

Commit

Permalink
🐛fix:utm座標系に変換して計算
Browse files Browse the repository at this point in the history
  • Loading branch information
hcmos committed Oct 29, 2024
1 parent 5a82c2c commit 9cba5c2
Showing 1 changed file with 71 additions and 90 deletions.
161 changes: 71 additions & 90 deletions scripts/path_error_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,34 @@
import math
import csv
import matplotlib.pyplot as plt
from pyproj import Transformer

class GnssDataProcessor:
def __init__(self, bag_file_path, target_csv_path):
self.target_path = self.load_target_path(target_csv_path) # CSVファイル名を指定

self.target_path = self.load_target_path(target_csv_path)
self.tracked_path, self.total_duration, self.average_frequency = self.read_rosbag_data(bag_file_path)

self.fig, self.ax = plt.subplots()
self.plot_data() # プロット
self.transformer = None
if self.tracked_path:
lat, lon = self.tracked_path[0]
self.set_utm_transformer(lat, lon)

self.errors = []

# 経路グラフを描画
self.plot_path()
self.plot_error()

def set_utm_transformer(self, lat, lon):
utm_zone = int((lon + 180) / 6) + 1
self.transformer = Transformer.from_crs("EPSG:4326", f"EPSG:326{utm_zone}", always_xy=True)

def load_target_path(self, file_path):
path = []

try:
with open(file_path, 'r') as file:
reader = csv.reader(file)
next(reader) # ヘッダー行をスキップ
next(reader)
for row in reader:
lat = float(row[0])
lon = float(row[1])
Expand All @@ -38,7 +49,6 @@ def read_rosbag_data(self, bag_file_path):
last_timestamp = None
intervals = []

# SequentialReaderを作成
reader = rosbag2_py.SequentialReader()
storage_options = rosbag2_py.StorageOptions(uri=bag_file_path, storage_id='sqlite3')
converter_options = rosbag2_py.ConverterOptions(
Expand All @@ -47,12 +57,10 @@ def read_rosbag_data(self, bag_file_path):
)
reader.open(storage_options, converter_options)

# トピック情報を取得
topic_types = reader.get_all_topics_and_types()
type_map = {topic.name: topic.type for topic in topic_types}

# /gps/fixトピックをフィルタリング
topic_name = '/vectornav/gnss' # GNSSデータのトピック名に合わせて変更
topic_name = '/vectornav/gnss' # gnssトピックの指定
if topic_name not in type_map:
print(f"Topic '{topic_name}' not found in the bag file.")
return tracked_path, 0.0, 0.0
Expand All @@ -62,90 +70,53 @@ def read_rosbag_data(self, bag_file_path):
while reader.has_next():
(topic, data, t) = reader.read_next()
if topic == topic_name:
# 最初と最後のタイムスタンプを記録
if first_timestamp is None:
first_timestamp = t
last_timestamp = t

# NavSatFixメッセージのデシリアライズ
msg = deserialize_message(data, NavSatFix)
tracked_path.append((msg.latitude, msg.longitude))

# メッセージ間隔を計算
if previous_timestamp is not None:
interval = (t - previous_timestamp) * 1e-9 # 秒単位
interval = (t - previous_timestamp) * 1e-9
intervals.append(interval)

previous_timestamp = t

# 累積時間を計算 (秒単位)
total_duration = (last_timestamp - first_timestamp) * 1e-9 if first_timestamp and last_timestamp else 0.0

# 平均周波数を計算 (Hz)
if intervals:
average_interval = sum(intervals) / len(intervals)
average_frequency = 1.0 / average_interval if average_interval > 0 else 0.0
else:
average_frequency = 0.0
average_frequency = 1.0 / (sum(intervals) / len(intervals)) if intervals else 0.0

return tracked_path, total_duration, average_frequency

def calculate_min_distance_to_path(self, lat, lon):
min_distance = float('inf')

# 経路の各セグメントを調べ、最短距離を見つける
for i in range(len(self.target_path) - 1):
start = self.target_path[i]
end = self.target_path[i + 1]

# 現在の位置からセグメントへの最短距離を計算
distance = self.calculate_distance_to_segment(lat, lon, start, end)
min_distance = min(min_distance, distance)

return min_distance

def calculate_distance_to_segment(self, lat, lon, start, end):
# セグメントの端点の座標
lat1, lon1 = start
lat2, lon2 = end

# 現在の位置とセグメントの端点の距離を計算
dist_start_to_point = self.calculate_distance(lat1, lon1, lat, lon)
dist_end_to_point = self.calculate_distance(lat2, lon2, lat, lon)
dist_start_to_end = self.calculate_distance(lat1, lon1, lat2, lon2)

# 内積を使用して点がセグメント上にあるかを判定
if dist_start_to_end == 0:
# セグメントの始点と終点が同じ場合
return dist_start_to_point

# 点がセグメントの外側にある場合
t = ((lat - lat1) * (lat2 - lat1) + (lon - lon1) * (lon2 - lon1)) / (dist_start_to_end ** 2)
if t < 0.0:
return dist_start_to_point
elif t > 1.0:
return dist_end_to_point

# 点がセグメントの内側にある場合、垂線の距離を計算
proj_lat = lat1 + t * (lat2 - lat1)
proj_lon = lon1 + t * (lon2 - lon1)
return self.calculate_distance(lat, lon, proj_lat, proj_lon)

def calculate_distance(self, lat1, lon1, lat2, lon2):
R = 6371000 # 地球の半径 (メートル)

lat1_rad = math.radians(lat1)
lat2_rad = math.radians(lat2)
delta_lat = math.radians(lat2 - lat1)
delta_lon = math.radians(lon2 - lon1)

a = math.sin(delta_lat / 2) ** 2 + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(delta_lon / 2) ** 2
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

distance = R * c
return distance

def plot_data(self):
x1, y1 = self.transformer.transform(start[1], start[0])
x2, y2 = self.transformer.transform(end[1], end[0])
x0, y0 = self.transformer.transform(lon, lat)

dx, dy = x2 - x1, y2 - y1
t = ((x0 - x1) * dx + (y0 - y1) * dy) / (dx ** 2 + dy ** 2)

if t < 0:
closest_x, closest_y = x1, y1
elif t > 1:
closest_x, closest_y = x2, y2
else:
closest_x, closest_y = x1 + t * dx, y1 + t * dy

return math.sqrt((x0 - closest_x) ** 2 + (y0 - closest_y) ** 2)

def plot_path(self):
target_lats = [lat for lat, lon in self.target_path]
target_lons = [lon for lat, lon in self.target_path]
tracked_lats = [lat for lat, lon in self.tracked_path]
Expand All @@ -156,35 +127,45 @@ def plot_data(self):
max_error_coord = ()
datasize = 0

# GNSSデータからの誤差を計算
for i, (lat, lon) in enumerate(self.tracked_path):
min_distance = self.calculate_min_distance_to_path(lat, lon)
accumulated_error += min_distance
if min_distance > max_error:
max_error_coord = (lon, lat)
max_error = min_distance

datasize = datasize + 1

# グラフ表示
self.ax.plot(target_lons, target_lats, 'bo-', label='Target Path')
self.ax.plot(tracked_lons, tracked_lats, 'ro-', label='Tracked Path')
self.ax.plot(max_error_coord[0], max_error_coord[1], 'gx', ms=10 , label='Max Error')
self.ax.plot(tracked_lons[0], tracked_lats[0], 'ko', ms=5 , label='Start')
self.ax.plot(tracked_lons[-1], tracked_lats[-1], 'kx', ms=5 , label='Goal')

self.ax.set_title('GNSS Tracking & Target Path')
self.ax.set_xlabel('Longitude')
self.ax.set_ylabel('Latitude')
self.ax.legend(loc='upper right')

# 累積誤差、累積時間、平均周波数を表示
self.ax.text(0.02, 0.95, f'Accumulated error: {accumulated_error:.2f} m', transform=self.ax.transAxes)
self.ax.text(0.02, 0.90, f'Total duration: {self.total_duration:.2f} sec', transform=self.ax.transAxes)
self.ax.text(0.02, 0.85, f'Average frequency: {self.average_frequency:.2f} Hz', transform=self.ax.transAxes)
self.ax.text(0.02, 0.80, f'Max error: {max_error:.2f}m ', transform=self.ax.transAxes)

print(f'1データごとの平均誤差: {accumulated_error/datasize:.2f} m')
datasize += 1
self.errors.append(min_distance)

plt.figure()
plt.plot(target_lons, target_lats, 'bo-', label='Target Path')
plt.plot(tracked_lons, tracked_lats, 'ro-', label='Tracked Path')
plt.plot(max_error_coord[0], max_error_coord[1], 'gx', ms=10, label='Max Error')
plt.plot(tracked_lons[0], tracked_lats[0], 'ko', ms=5, label='Start')
plt.plot(tracked_lons[-1], tracked_lats[-1], 'kx', ms=5, label='Goal')

plt.title('GNSS Tracking & Target Path')
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.legend(loc='upper right')

# axで相対位置に表示させる
ax = plt.gca()
ax.text(0.02, 0.95, f'Accumulated error: {accumulated_error:.2f} m', transform=ax.transAxes, fontsize=10)
ax.text(0.02, 0.90, f'Total duration: {self.total_duration:.2f} sec', transform=ax.transAxes, fontsize=10)
ax.text(0.02, 0.85, f'Average frequency: {self.average_frequency:.2f} Hz', transform=ax.transAxes, fontsize=10)
ax.text(0.02, 0.80, f'Max error: {max_error:.2f} m', transform=ax.transAxes, fontsize=10)

def plot_error(self):
plt.figure()
errors = self.errors

plt.plot(errors, 'r-', label='Error Distance (m)')
plt.title('Errors')
plt.xlabel('Data Point Index')
plt.ylabel('Error Distance (m)')
plt.legend(loc='upper right')

print(f'1データごとの平均誤差: {sum(errors) / len(errors):.2f} m')

plt.show()

Expand All @@ -199,7 +180,7 @@ def main():
'config',
'course_data',
'gazebo_shihou_course.csv'
) # 目標経路
) # 目標csv

GnssDataProcessor(bag_file_path, target_csv_path)

Expand Down

0 comments on commit 9cba5c2

Please sign in to comment.