diff --git a/tests/test_crud.py b/tests/test_crud.py index af643c0..74bab44 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -17,6 +17,7 @@ ParentChildRelationship, Severity, ThermalEvent, + StrikeLineDescriptor, ThermalEventCategoryLineOfSight, User, crud, @@ -34,6 +35,7 @@ ) from .test_thermal_event import random_event as _random_event +from .test_strike_line_descriptor import random_strike_line_descriptor # Number of different datasets NB_DATASETS = 3 @@ -154,6 +156,7 @@ def reset_temporary_database(): session.query(ParentChildRelationship).delete() session.query(ThermalEvent).delete() + session.query(StrikeLineDescriptor).delete() session.commit() @@ -204,7 +207,7 @@ def test_thermal_event_update(): thermal_event.method = expected crud.thermal_event.update(thermal_event) - # Retriveve the event and check if its method has been updated + # Retrieve the event and check if its method has been updated thermal_event_db = crud.thermal_event.get(thermal_event.id) assert thermal_event_db.method == expected @@ -456,6 +459,125 @@ def test_thermal_event_change_analysis_status(): assert actual == expected +def test_strike_line_descriptor_create_read(): + # Generate a random strike line descriptor + expected = random_strike_line_descriptor() + + # Send the descriptor to the database + crud.strike_line_descriptor.create(expected) + + # Read the descriptor from the database + actual = crud.strike_line_descriptor.get(expected.id) + + assert actual.instance.bbox_x == expected.instance.bbox_x + assert actual.instance.bbox_y == expected.instance.bbox_y + assert actual.instance.bbox_width == expected.instance.bbox_width + assert actual.instance.bbox_height == expected.instance.bbox_height + assert actual.instance.timestamp_ns == expected.instance.timestamp_ns + + assert actual.segmented_points_as_list == expected.segmented_points_as_list + assert actual.angle == expected.angle + assert actual.curve == expected.curve + assert actual.flag_RT == expected.flag_RT + + +def test_strike_line_descriptor_get_multi(): + # Generate random descriptors + nb = 5 + descriptors = [random_strike_line_descriptor() for _ in range(nb)] + + # Send the descriprots to the database + crud.strike_line_descriptor.create(descriptors) + + # Store their ids + expected = [x.id for x in descriptors] + + # Retrieve the descriptors and their ids + descriptors_read = crud.strike_line_descriptor.get_multi(limit=nb) + actual = [x.id for x in descriptors_read] + + assert actual == expected + + +def test_strike_line_descriptor_update(): + # Generate a random descriptor + descriptor = random_strike_line_descriptor() + + # Send the descriptor to the database + crud.strike_line_descriptor.create(descriptor) + + # Update its angle with a different, randomly chosen one + expected = random.randrange(90) + descriptor.angle = expected + crud.strike_line_descriptor.update(descriptor) + + # Retrieve the descriptor and check if its angle has been updated + descriptor_db = crud.strike_line_descriptor.get(descriptor.id) + + assert descriptor_db.angle == expected + + +def test_strike_line_descriptor_delete(): + # Generate a random descriptor + descriptor = random_strike_line_descriptor() + + # Send the descriptor to the database + crud.strike_line_descriptor.create(descriptor) + + # Delete the descriptor and check it is deleted in the database + crud.strike_line_descriptor.delete(descriptor) + + assert crud.strike_line_descriptor.get(descriptor.id) is None + + +def test_strike_line_descriptor_get_by_columns(): + # Generate a random descriptor + ( + descriptor, + instance, + segmented_points, + angle, + curve, + flag_RT, + ) = random_strike_line_descriptor(return_parameters=True) + + # Send the descriptor to the database + crud.strike_line_descriptor.create(descriptor) + + # Query with the angle + actual = crud.strike_line_descriptor.get_by_columns(angle=angle) + assert len(actual) == 1 + + # Query with the curve + actual = crud.strike_line_descriptor.get_by_columns(curve=curve + 1) + assert len(actual) == 0 + actual = crud.strike_line_descriptor.get_by_columns( + curve=curve, return_columns=["id"] + ) + assert len(actual) == 1 + + # Query with both angle and curve + actual = crud.strike_line_descriptor.get_by_columns(angle=angle, curve=curve) + assert len(actual) == 1 + + +def test_strike_line_descriptor_get_by_flag_RT(): + # Generate a random descriptor + descriptor = random_strike_line_descriptor() + descriptor.flag_RT = True + + # Send the descriptor to the database + crud.strike_line_descriptor.create(descriptor) + + # Query the descriptors with flag_RT == True + actual = crud.strike_line_descriptor.get_by_flag_RT(True) + assert len(actual) == 1 + + # Query the descriptors with flag_RT == False + actual = crud.strike_line_descriptor.get_by_flag_RT(False) + assert len(actual) == 0 + + def test_user(): # Check the users list assert crud.user.list() == sorted([x[0] for x in users], key=lambda s: s.lower()) diff --git a/tests/test_strike_line_descriptor.py b/tests/test_strike_line_descriptor.py index 33d88fc..e219381 100644 --- a/tests/test_strike_line_descriptor.py +++ b/tests/test_strike_line_descriptor.py @@ -1,43 +1,53 @@ -import pytest -from thermal_events import StrikeLineDescriptor, ThermalEventInstance +import random +from tests.test_thermal_event import random_instance +from thermal_events import StrikeLineDescriptor -@pytest.fixture -def rectangle(): - return [100, 200, 25, 50] +def random_strike_line_descriptor(return_parameters=False): + instance = random_instance() -@pytest.fixture -def segmented_points(): - return [[0, 0], [2, 0], [4, 2], [1, 3], [0, 1]] + segmented_points = [ + [random.randint(1, 100), random.randint(1, 100)] + for _ in range(random.randint(5, 10)) + ] + angle = random.randrange(90) + curve = random.randrange(5) + flag_RT = random.choice([True, False]) - -def test_strike_line_descriptor(rectangle, segmented_points): - thermal_event_instance = ThermalEventInstance.from_rectangle( - rectangle, timestamp_ns=100 - ) strike_line_descriptor = StrikeLineDescriptor( - thermal_event_instance, segmented_points, 45, 2, flag_RT=True + instance, segmented_points, angle, curve, flag_RT=flag_RT ) - - assert strike_line_descriptor.instance.bbox_x == rectangle[0] - assert strike_line_descriptor.instance.bbox_y == rectangle[1] - assert strike_line_descriptor.instance.bbox_width == rectangle[2] - assert strike_line_descriptor.instance.bbox_height == rectangle[3] - assert strike_line_descriptor.instance.timestamp_ns == 100 + if return_parameters: + return strike_line_descriptor, instance, segmented_points, angle, curve, flag_RT + return strike_line_descriptor + + +def test_strike_line_descriptor(): + ( + strike_line_descriptor, + instance, + segmented_points, + angle, + curve, + flag_RT, + ) = random_strike_line_descriptor(True) + + assert strike_line_descriptor.instance.bbox_x == instance.bbox_x + assert strike_line_descriptor.instance.bbox_y == instance.bbox_y + assert strike_line_descriptor.instance.bbox_width == instance.bbox_width + assert strike_line_descriptor.instance.bbox_height == instance.bbox_height + assert strike_line_descriptor.instance.timestamp_ns == instance.timestamp_ns assert strike_line_descriptor.segmented_points_as_list == segmented_points - assert strike_line_descriptor.angle == 45 - assert strike_line_descriptor.curve == 2 - assert strike_line_descriptor.flag_RT + assert strike_line_descriptor.angle == angle + assert strike_line_descriptor.curve == curve + assert strike_line_descriptor.flag_RT == flag_RT -def test_segmented_points_as_list(rectangle, segmented_points): - thermal_event_instance = ThermalEventInstance.from_rectangle( - rectangle, timestamp_ns=100 - ) - strike_line_descriptor = StrikeLineDescriptor( - thermal_event_instance, segmented_points, 45, 2, flag_RT=True +def test_segmented_points_as_list(): + (strike_line_descriptor, _, segmented_points, *_) = random_strike_line_descriptor( + True ) assert strike_line_descriptor.return_segmented_points() == segmented_points