Skip to content

Commit

Permalink
test: ✅ Add tests for StrikeLineDescriptor
Browse files Browse the repository at this point in the history
  • Loading branch information
egrelier committed Dec 8, 2023
1 parent 0568b9b commit 9b8f408
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 30 deletions.
124 changes: 123 additions & 1 deletion tests/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ParentChildRelationship,
Severity,
ThermalEvent,
StrikeLineDescriptor,
ThermalEventCategoryLineOfSight,
User,
crud,
Expand All @@ -34,6 +35,7 @@
)

from .test_thermal_event import random_event as _random_event
from .test_strike_line_descriptor import random_strike_line_descriptor

# Number of different datasets
NB_DATASETS = 3
Expand Down Expand Up @@ -154,6 +156,7 @@ def reset_temporary_database():

session.query(ParentChildRelationship).delete()
session.query(ThermalEvent).delete()
session.query(StrikeLineDescriptor).delete()
session.commit()


Expand Down Expand Up @@ -204,7 +207,7 @@ def test_thermal_event_update():
thermal_event.method = expected
crud.thermal_event.update(thermal_event)

# Retriveve the event and check if its method has been updated
# Retrieve the event and check if its method has been updated
thermal_event_db = crud.thermal_event.get(thermal_event.id)

assert thermal_event_db.method == expected
Expand Down Expand Up @@ -456,6 +459,125 @@ def test_thermal_event_change_analysis_status():
assert actual == expected


def test_strike_line_descriptor_create_read():
# Generate a random strike line descriptor
expected = random_strike_line_descriptor()

# Send the descriptor to the database
crud.strike_line_descriptor.create(expected)

# Read the descriptor from the database
actual = crud.strike_line_descriptor.get(expected.id)

assert actual.instance.bbox_x == expected.instance.bbox_x
assert actual.instance.bbox_y == expected.instance.bbox_y
assert actual.instance.bbox_width == expected.instance.bbox_width
assert actual.instance.bbox_height == expected.instance.bbox_height
assert actual.instance.timestamp_ns == expected.instance.timestamp_ns

assert actual.segmented_points_as_list == expected.segmented_points_as_list
assert actual.angle == expected.angle
assert actual.curve == expected.curve
assert actual.flag_RT == expected.flag_RT


def test_strike_line_descriptor_get_multi():
# Generate random descriptors
nb = 5
descriptors = [random_strike_line_descriptor() for _ in range(nb)]

# Send the descriprots to the database
crud.strike_line_descriptor.create(descriptors)

# Store their ids
expected = [x.id for x in descriptors]

# Retrieve the descriptors and their ids
descriptors_read = crud.strike_line_descriptor.get_multi(limit=nb)
actual = [x.id for x in descriptors_read]

assert actual == expected


def test_strike_line_descriptor_update():
# Generate a random descriptor
descriptor = random_strike_line_descriptor()

# Send the descriptor to the database
crud.strike_line_descriptor.create(descriptor)

# Update its angle with a different, randomly chosen one
expected = random.randrange(90)
descriptor.angle = expected
crud.strike_line_descriptor.update(descriptor)

# Retrieve the descriptor and check if its angle has been updated
descriptor_db = crud.strike_line_descriptor.get(descriptor.id)

assert descriptor_db.angle == expected


def test_strike_line_descriptor_delete():
# Generate a random descriptor
descriptor = random_strike_line_descriptor()

# Send the descriptor to the database
crud.strike_line_descriptor.create(descriptor)

# Delete the descriptor and check it is deleted in the database
crud.strike_line_descriptor.delete(descriptor)

assert crud.strike_line_descriptor.get(descriptor.id) is None


def test_strike_line_descriptor_get_by_columns():
# Generate a random descriptor
(
descriptor,
instance,
segmented_points,
angle,
curve,
flag_RT,
) = random_strike_line_descriptor(return_parameters=True)

# Send the descriptor to the database
crud.strike_line_descriptor.create(descriptor)

# Query with the angle
actual = crud.strike_line_descriptor.get_by_columns(angle=angle)
assert len(actual) == 1

# Query with the curve
actual = crud.strike_line_descriptor.get_by_columns(curve=curve + 1)
assert len(actual) == 0
actual = crud.strike_line_descriptor.get_by_columns(
curve=curve, return_columns=["id"]
)
assert len(actual) == 1

# Query with both angle and curve
actual = crud.strike_line_descriptor.get_by_columns(angle=angle, curve=curve)
assert len(actual) == 1


def test_strike_line_descriptor_get_by_flag_RT():
# Generate a random descriptor
descriptor = random_strike_line_descriptor()
descriptor.flag_RT = True

# Send the descriptor to the database
crud.strike_line_descriptor.create(descriptor)

# Query the descriptors with flag_RT == True
actual = crud.strike_line_descriptor.get_by_flag_RT(True)
assert len(actual) == 1

# Query the descriptors with flag_RT == False
actual = crud.strike_line_descriptor.get_by_flag_RT(False)
assert len(actual) == 0


def test_user():
# Check the users list
assert crud.user.list() == sorted([x[0] for x in users], key=lambda s: s.lower())
Expand Down
68 changes: 39 additions & 29 deletions tests/test_strike_line_descriptor.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,53 @@
import pytest
from thermal_events import StrikeLineDescriptor, ThermalEventInstance
import random

from tests.test_thermal_event import random_instance
from thermal_events import StrikeLineDescriptor

@pytest.fixture
def rectangle():
return [100, 200, 25, 50]

def random_strike_line_descriptor(return_parameters=False):
instance = random_instance()

@pytest.fixture
def segmented_points():
return [[0, 0], [2, 0], [4, 2], [1, 3], [0, 1]]
segmented_points = [
[random.randint(1, 100), random.randint(1, 100)]
for _ in range(random.randint(5, 10))
]
angle = random.randrange(90)
curve = random.randrange(5)
flag_RT = random.choice([True, False])


def test_strike_line_descriptor(rectangle, segmented_points):
thermal_event_instance = ThermalEventInstance.from_rectangle(
rectangle, timestamp_ns=100
)
strike_line_descriptor = StrikeLineDescriptor(
thermal_event_instance, segmented_points, 45, 2, flag_RT=True
instance, segmented_points, angle, curve, flag_RT=flag_RT
)

assert strike_line_descriptor.instance.bbox_x == rectangle[0]
assert strike_line_descriptor.instance.bbox_y == rectangle[1]
assert strike_line_descriptor.instance.bbox_width == rectangle[2]
assert strike_line_descriptor.instance.bbox_height == rectangle[3]
assert strike_line_descriptor.instance.timestamp_ns == 100
if return_parameters:
return strike_line_descriptor, instance, segmented_points, angle, curve, flag_RT
return strike_line_descriptor


def test_strike_line_descriptor():
(
strike_line_descriptor,
instance,
segmented_points,
angle,
curve,
flag_RT,
) = random_strike_line_descriptor(True)

assert strike_line_descriptor.instance.bbox_x == instance.bbox_x
assert strike_line_descriptor.instance.bbox_y == instance.bbox_y
assert strike_line_descriptor.instance.bbox_width == instance.bbox_width
assert strike_line_descriptor.instance.bbox_height == instance.bbox_height
assert strike_line_descriptor.instance.timestamp_ns == instance.timestamp_ns

assert strike_line_descriptor.segmented_points_as_list == segmented_points
assert strike_line_descriptor.angle == 45
assert strike_line_descriptor.curve == 2
assert strike_line_descriptor.flag_RT
assert strike_line_descriptor.angle == angle
assert strike_line_descriptor.curve == curve
assert strike_line_descriptor.flag_RT == flag_RT


def test_segmented_points_as_list(rectangle, segmented_points):
thermal_event_instance = ThermalEventInstance.from_rectangle(
rectangle, timestamp_ns=100
)
strike_line_descriptor = StrikeLineDescriptor(
thermal_event_instance, segmented_points, 45, 2, flag_RT=True
def test_segmented_points_as_list():
(strike_line_descriptor, _, segmented_points, *_) = random_strike_line_descriptor(
True
)

assert strike_line_descriptor.return_segmented_points() == segmented_points

0 comments on commit 9b8f408

Please sign in to comment.