Skip to content

Commit caa08b5

Browse files
committed
Apply refactoring to unify trace implementations more
Signed-off-by: Pierre R. Mai <pmai@pmsf.de>
1 parent 16abb2c commit caa08b5

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

osi3trace/osi_trace.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,9 @@ def _init_reader(self, path, type_name, cache_messages, topic):
106106
raise FileNotFoundError("File not found")
107107

108108
if path.suffix.lower() == ".mcap":
109-
reader = OSITraceMulti(path, topic)
110-
if reader.get_message_type() != type_name:
111-
raise ValueError(f"Channel message type '{reader.get_message_type()}' does not match expected type '{type_name}'")
112-
return reader
109+
return OSITraceMulti(path, type_name, topic)
113110
elif path.suffix.lower() in [".osi", ".lzma", ".xz"]:
114-
return OSITraceSingle(str(path), type_name, cache_messages)
111+
return OSITraceSingle(path, type_name, cache_messages)
115112
else:
116113
raise ValueError(f"Unsupported file format: '{path.suffix}'")
117114

@@ -212,7 +209,7 @@ def from_file(self, path, type_name="SensorView", cache_messages=False):
212209
"""Import a trace from a file"""
213210
self.type = OSITrace.map_message_type(type_name)
214211

215-
if path.lower().endswith((".lzma", ".xz")):
212+
if path.suffix.lower() in [".lzma", ".xz"]:
216213
self.file = lzma.open(path, "rb")
217214
else:
218215
self.file = open(path, "rb")
@@ -344,16 +341,16 @@ def close(self):
344341
class OSITraceMulti(ReaderBase):
345342
"""OSI multi-channel trace reader"""
346343

347-
def __init__(self, path, topic):
344+
def __init__(self, path, type_name, topic):
348345
self._file = open(path, "rb")
349346
self._mcap_reader = make_reader(self._file, decoder_factories=[DecoderFactory()])
350347
self._iter = None
351348
self._summary = self._mcap_reader.get_summary()
352-
available_topics = self.get_available_topics()
349+
available_topics = self.get_available_topics(type_name)
353350
if topic == None:
354-
topic = available_topics[0]
351+
topic = next(iter(available_topics), None)
355352
if topic not in available_topics:
356-
raise ValueError(f"The requested topic '{topic}' is not present in the trace file.")
353+
raise ValueError(f"The requested topic '{topic}' is not present in the trace file or is not of type '{type_name}'.")
357354
self.topic = topic
358355

359356
def restart(self, index=None):
@@ -376,8 +373,8 @@ def close(self):
376373
self._summary = None
377374
self._iter = None
378375

379-
def get_available_topics(self):
380-
return [channel.topic for id, channel in self._summary.channels.items()]
376+
def get_available_topics(self, type_name = None):
377+
return [channel.topic for channel in self._summary.channels.values() if _channel_is_of_type(channel, type_name)]
381378

382379
def get_file_metadata(self):
383380
metadata = []
@@ -392,11 +389,15 @@ def get_channel_metadata(self):
392389
return None
393390

394391
def get_message_type(self):
395-
for channel_id, channel in self._summary.channels.items():
392+
for channel in self._summary.channels.values():
396393
if channel.topic == self.topic:
397394
schema = self._summary.schemas[channel.schema_id]
398395
if schema.name.startswith("osi3."):
399396
return schema.name[len("osi3.") :]
400397
else:
401398
raise ValueError(f"Schema '{schema.name}' is not an 'osi3.' schema.")
402399
return None
400+
401+
def _channel_is_of_type(self, channel, type_name):
402+
schema = self._summary.schemas[channel.schema_id]
403+
return type_name is None or schema.name == f"osi3.{type_name}"

0 commit comments

Comments
 (0)