Skip to content

Commit

Permalink
Improve filter api (#142)
Browse files Browse the repository at this point in the history
* feat: improves API of FilterChain

* chore: updates docs for updated API of FilterChain

* chore: bumps versions & updates changelog
  • Loading branch information
M0r13n authored Jun 26, 2024
1 parent b973000 commit 65acde3
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 56 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
====================
pyais CHANGELOG
====================
-------------------------------------------------------------------------------
Version 2.6.6 26 Jun 2024
-------------------------------------------------------------------------------
* improves the API of `FilterChain`
* `FilterChain.filter(stream)` now accepts a stream instance
* this stream MUST implement the `Stream` interface defined in `pyais.stream`
* individual messages can be filtered using `IterMessages(...)`
-------------------------------------------------------------------------------
Version 2.6.5 10 May 2024
-------------------------------------------------------------------------------
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ This is useful for debugging or for getting used to pyais.
It is also possible to encode messages.

| :exclamation: Every message needs at least a single keyword argument: `mmsi`. All other fields have most likely default values. |
|----------------------------------------------------------------------------------------------------------------------------------|
| -------------------------------------------------------------------------------------------------------------------------------- |

### Encode data using a dictionary

Expand Down Expand Up @@ -385,7 +385,7 @@ The filtering system is built around a series of filter classes, each designed t
## Example Usage

```python
from pyais import decode
from pyais import decode, TCPConnection
# ... (importing necessary classes)

# Define and initialize filters
Expand All @@ -405,8 +405,8 @@ chain = FilterChain([
])

# Decode AIS data and filter
data = [decode(b"!AIVDM..."), ...]
filtered_data = list(chain.filter(data))
stream = TCPConnection(...)
filtered_data = list(chain.filter(stream))

for msg in filtered_data:
print(msg.lat, msg.lon)
Expand Down
6 changes: 3 additions & 3 deletions docs/filters.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Example Usage

.. code-block:: python
from pyais import decode
from pyais import decode, TCPConnection
# ... (importing necessary classes)
# Define and initialize filters
Expand All @@ -82,8 +82,8 @@ Example Usage
])
# Decode AIS data and filter
data = [decode(b"!AIVDM..."), ...]
filtered_data = list(chain.filter(data))
stream = TCPConnection(...)
filtered_data = chain.filter(stream)
for msg in filtered_data:
print(msg.lat, msg.lon)
Expand Down
25 changes: 6 additions & 19 deletions examples/filters.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from pyais import decode
from pyais.filter import (
AttributeFilter,
DistanceFilter,
Expand All @@ -7,6 +6,7 @@
MessageTypeFilter,
NoneFilter
)
from pyais.stream import TCPConnection

# Define the filter chain with various criteria
chain = FilterChain([
Expand All @@ -26,21 +26,8 @@
GridFilter(lat_min=50, lon_min=0, lat_max=52, lon_max=5),
])

# Example AIS data to filter
data = [
decode(b"!AIVDM,1,1,,B,15NG6V0P01G?cFhE`R2IU?wn28R>,0*05"),
decode(b"!AIVDM,1,1,,A,13HOI:0P0000VOHLCnHQKwvL05Ip,0*23"),
decode(b"!AIVDM,1,1,,B,100h00PP0@PHFV`Mg5gTH?vNPUIp,0*3B"),
decode(b"!AIVDM,1,1,,A,133sVfPP00PD>hRMDH@jNOvN20S8,0*7F"),
decode(b"!AIVDM,1,1,,B,13eaJF0P00Qd388Eew6aagvH85Ip,0*45"),
decode(b"!AIVDM,1,1,,A,14eGrSPP00ncMJTO5C6aBwvP2D0?,0*7A"),
decode(b"!AIVDM,1,1,,A,15MrVH0000KH<:V:NtBLoqFP2H9:,0*2F"),
decode(b"!AIVDM,1,1,,A,702R5`hwCjq8,0*6B"),
]

# Filter the data using the defined chain
filtered_data = list(chain.filter(data))

# Print the latitude and longitude of each message that passed the filters
for msg in filtered_data:
print(msg.lat, msg.lon)
# Create a stream of AIS messages
with TCPConnection('153.44.253.27', port=5631) as ais_stream:
for ais_msg in chain.filter(ais_stream):
# Only messages that pass this filter chain are printed
print(ais_msg)
5 changes: 3 additions & 2 deletions pyais/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from pyais.messages import NMEAMessage, ANY_MESSAGE, AISSentence
from pyais.stream import TCPConnection, FileReaderStream, IterMessages
from pyais.stream import TCPConnection, FileReaderStream, IterMessages, Stream
from pyais.encode import encode_dict, encode_msg, ais_to_nmea_0183
from pyais.decode import decode
from pyais.tracker import AISTracker, AISTrack

__license__ = 'MIT'
__version__ = '2.6.5'
__version__ = '2.6.6'
__author__ = 'Leon Morten Richter'

__all__ = (
Expand All @@ -18,6 +18,7 @@
'TCPConnection',
'IterMessages',
'FileReaderStream',
'Stream',
'decode',
'AISTracker',
'AISTrack',
Expand Down
57 changes: 30 additions & 27 deletions pyais/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
"""

import math
import socket
import typing
import pyais

# Type Aliases for readability
AIS_STREAM = typing.Generator[pyais.AISSentence, None, None]
FILTER_FUNCTION = typing.Callable[[pyais.AISSentence], bool]
F = typing.TypeVar("F", typing.BinaryIO, socket.socket, None)
AIS_STREAM = pyais.Stream[F]
MESSAGE_STREAM = typing.Generator[pyais.ANY_MESSAGE, None, None]
FILTER_FUNCTION = typing.Callable[[pyais.ANY_MESSAGE], bool]
LAT_LON = typing.Tuple[float, float] # Tuple type for latitude and longitude


Expand Down Expand Up @@ -66,30 +69,30 @@ def set_next(self, filter: 'Filter') -> None:
"""
self.next_filter = filter

def filter(self, data: AIS_STREAM) -> AIS_STREAM:
def filter(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Apply the filter to the data and then pass it to the next filter.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Returns:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
data = self.filter_data(data)
if self.next_filter:
return self.next_filter.filter(data)
return data

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Abstract method to filter data. Should be implemented by subclasses.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Returns:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
raise NotImplementedError("This method should be overridden by subclasses.")

Expand All @@ -109,15 +112,15 @@ def __init__(self, ff: FILTER_FUNCTION) -> None:
super().__init__()
self.ff = ff

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Filter the data based on the user-defined function.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
yield from filter(self.ff, data)

Expand All @@ -137,15 +140,15 @@ def __init__(self, *attrs: str) -> None:
super().__init__()
self.attrs = attrs

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Filter the data, allowing only messages where specified attributes are not None.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
for msg in data:
if all(getattr(msg, attr, None) is not None for attr in self.attrs):
Expand All @@ -167,18 +170,18 @@ def __init__(self, *types: int) -> None:
super().__init__()
self.types = types

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Filter the data, allowing only messages of specified types.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
for msg in data:
if msg.msg_type not in self.types: # type: ignore
if msg.msg_type not in self.types:
continue
yield msg

Expand All @@ -200,15 +203,15 @@ def __init__(self, ref_lat_lon: LAT_LON, distance_km: float) -> None:
self.ref_lat_lon = ref_lat_lon
self.distance_km = distance_km

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Filter the data based on distance from a reference point.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
for msg in data:
if hasattr(msg, 'lat'):
Expand All @@ -235,15 +238,15 @@ def __init__(self, lat_min: float, lon_min: float, lat_max: float, lon_max: floa
self.lat_max = lat_max
self.lon_max = lon_max

def filter_data(self, data: AIS_STREAM) -> AIS_STREAM:
def filter_data(self, data: MESSAGE_STREAM) -> MESSAGE_STREAM:
"""
Filter the data based on whether it falls within a specified grid.
Parameters:
data (AIS_STREAM): The stream of data to filter.
data (MESSAGE_STREAM): The stream of data to filter.
Yields:
AIS_STREAM: The filtered data stream.
MESSAGE_STREAM: The filtered data stream.
"""
for msg in data:
if hasattr(msg, 'lat'):
Expand Down Expand Up @@ -274,14 +277,14 @@ def __init__(self, filters: typing.List[Filter]) -> None:
self.filters = filters
self.start = filters[0]

def filter(self, data: AIS_STREAM) -> AIS_STREAM:
def filter(self, stream: AIS_STREAM[F]) -> MESSAGE_STREAM:
"""
Apply the chain of filters to the data.
Parameters:
data (AIS_STREAM): The stream of data to filter.
stream (AIS_STREAM): The stream of data to filter.
Yields:
MESSAGE_STREAM: The filtered data stream.
"""
yield from self.start.filter(data)
yield from self.start.filter(x.decode() for x in stream)
10 changes: 9 additions & 1 deletion tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@
import subprocess
import unittest

KEYWORDS_TO_IGNORE = (
'tcp',
'udp',
'live',
'tracking',
'filters',
)


class TestExamples(unittest.TestCase):
"""
Expand All @@ -14,7 +22,7 @@ def test_run_every_file(self):
i = -1
exe = sys.executable
for i, file in enumerate(pathlib.Path(__file__).parent.parent.joinpath('examples').glob('*.py')):
if 'tcp' not in str(file) and 'udp' not in str(file) and 'live' not in str(file) and 'tracking' not in str(file):
if all(kw not in str(file) for kw in KEYWORDS_TO_IGNORE):
env = os.environ
env['PYTHONPATH'] = f':{pathlib.Path(__file__).parent.parent.absolute()}'
assert subprocess.check_call(f'{exe} {file}'.split(), env=env, shell=False) == 0
Expand Down
22 changes: 22 additions & 0 deletions tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pathlib
import unittest
from pyais.filter import AttributeFilter, DistanceFilter, FilterChain, GridFilter, MessageTypeFilter, NoneFilter, haversine
from pyais.stream import FileReaderStream


class MockAISMessage:
Expand All @@ -9,6 +11,9 @@ def __init__(self, msg_type=None, lat=None, lon=None, other_attr=None):
self.lon = lon
self.other_attr = other_attr

def decode(self):
return self


class TestNoneFilter(unittest.TestCase):
def test_filtering_none_attributes(self):
Expand Down Expand Up @@ -165,6 +170,7 @@ def test_filter_chain(self):
filter1 = NoneFilter('lat', 'lon')
filter2 = MessageTypeFilter(1, 2)
chain = FilterChain([filter1, filter2])

mock_data = [MockAISMessage(lat=1, lon=1, msg_type=1), MockAISMessage(lat=None, lon=1, msg_type=2)]

# Execute
Expand Down Expand Up @@ -200,6 +206,22 @@ def test_complex_filter_chain(self):
self.assertEqual(filtered_data[0].lon, -73.965)
self.assertEqual(filtered_data[0].msg_type, 1)

def test_filter_chain_with_file_stream(self):
# Setup: Define the filters and chain
chain = FilterChain([AttributeFilter(lambda x: x.mmsi == 445451000)])

# Setup: Define sample file
file = pathlib.Path(__file__).parent.joinpath('messages.ais')

with FileReaderStream(file) as ais_stream:
total = len(list(ais_stream))

with FileReaderStream(file) as ais_stream:
filtered = list(chain.filter(ais_stream))

self.assertEqual(len(filtered), 2)
self.assertEqual(total, 6)


class TestAttributeFilter(unittest.TestCase):
def test_attribute_filtering(self):
Expand Down

0 comments on commit 65acde3

Please sign in to comment.