From 1ff4c4897bb8aa0a4e95848d03d6ff4eb290df29 Mon Sep 17 00:00:00 2001 From: Samuel Jones Date: Mon, 27 Jan 2025 15:42:32 +0000 Subject: [PATCH] Sans2d script transform (#439) * Add SANS2D transforms * Ruff linting fix * Formatting and linting commit * Update fia_api/scripts/transforms/sans_transform.py * Update fia_api/scripts/transforms/sans_transform.py Co-authored-by: keiranjprice101 <44777678+keiranjprice101@users.noreply.github.com> --------- Co-authored-by: github-actions Co-authored-by: keiranjprice101 <44777678+keiranjprice101@users.noreply.github.com> --- fia_api/scripts/transforms/factory.py | 6 +- .../{loq_transform.py => sans_transform.py} | 12 +- test/scripts/transforms/test_factory.py | 3 + .../scripts/transforms/test_loq_transforms.py | 12 +- .../transforms/test_sans2d_transforms.py | 170 ++++++++++++++++++ 5 files changed, 190 insertions(+), 13 deletions(-) rename fia_api/scripts/transforms/{loq_transform.py => sans_transform.py} (87%) create mode 100644 test/scripts/transforms/test_sans2d_transforms.py diff --git a/fia_api/scripts/transforms/factory.py b/fia_api/scripts/transforms/factory.py index dae94354..d9943676 100644 --- a/fia_api/scripts/transforms/factory.py +++ b/fia_api/scripts/transforms/factory.py @@ -5,9 +5,9 @@ import logging from fia_api.scripts.transforms.iris_transform import IrisTransform -from fia_api.scripts.transforms.loq_transform import LoqTransform from fia_api.scripts.transforms.mari_transforms import MariTransform from fia_api.scripts.transforms.osiris_transform import OsirisTransform +from fia_api.scripts.transforms.sans_transform import SansTransform from fia_api.scripts.transforms.test_transforms import TestTransform from fia_api.scripts.transforms.tosca_transform import ToscaTransform from fia_api.scripts.transforms.transform import MissingTransformError, Transform @@ -29,8 +29,8 @@ def get_transform_for_instrument(instrument: str) -> Transform: return ToscaTransform() case "osiris": return OsirisTransform() - case "loq": - return LoqTransform() + case "loq" | "sans2d": + return SansTransform() case "iris": return IrisTransform() case "test": diff --git a/fia_api/scripts/transforms/loq_transform.py b/fia_api/scripts/transforms/sans_transform.py similarity index 87% rename from fia_api/scripts/transforms/loq_transform.py rename to fia_api/scripts/transforms/sans_transform.py index ffeed6d5..de3d17e7 100644 --- a/fia_api/scripts/transforms/loq_transform.py +++ b/fia_api/scripts/transforms/sans_transform.py @@ -16,20 +16,22 @@ # mypy: disable-error-code="operator, index" -class LoqTransform(Transform): +class SansTransform(Transform): """ - LoqTransform applies modifications to LOQ instrument scripts based on reduction input parameters in a Job + SansTransform applies modifications to SANS instrument scripts based on reduction input parameters in a Job entity. """ def apply(self, script: PreScript, job: Job) -> None: # noqa: C901, PLR0912 - logger.info("Beginning LOQ transform for job %s...", job.id) + logger.info("Beginning %s transform for job %s...", job.instrument, job.id) lines = script.value.splitlines() # MyPY does not believe ColumnElement[JSONB] is indexable, despite JSONB implementing the Indexable mixin # If you get here in the future, try removing the type ignore and see if it passes with newer mypy for index, line in enumerate(lines): - if "/extras/loq/MaskFile.toml" in line and "user_file" in job.inputs: - lines[index] = line.replace("/extras/loq/MaskFile.toml", job.inputs["user_file"]) + if f"/extras/{job.instrument.instrument_name.lower()}/MaskFile.toml" in line and "user_file" in job.inputs: + lines[index] = line.replace( + f"/extras/{job.instrument.instrument_name.lower()}/MaskFile.toml", job.inputs["user_file"] + ) continue if "run_number" in job.inputs and self._replace_input( line, lines, index, "sample_scatter", job.inputs["run_number"] diff --git a/test/scripts/transforms/test_factory.py b/test/scripts/transforms/test_factory.py index 161c359c..1f500b38 100644 --- a/test/scripts/transforms/test_factory.py +++ b/test/scripts/transforms/test_factory.py @@ -8,6 +8,7 @@ from fia_api.scripts.transforms.iris_transform import IrisTransform from fia_api.scripts.transforms.mari_transforms import MariTransform from fia_api.scripts.transforms.osiris_transform import OsirisTransform +from fia_api.scripts.transforms.sans_transform import SansTransform from fia_api.scripts.transforms.test_transforms import TestTransform from fia_api.scripts.transforms.tosca_transform import ToscaTransform from fia_api.scripts.transforms.transform import MissingTransformError @@ -21,6 +22,8 @@ ("test", TestTransform), ("osiris", OsirisTransform), ("iris", IrisTransform), + ("loq", SansTransform), + ("sans2d", SansTransform), ], ) def test_transform_factory(name, expected_transform): diff --git a/test/scripts/transforms/test_loq_transforms.py b/test/scripts/transforms/test_loq_transforms.py index 16705234..3690efb6 100644 --- a/test/scripts/transforms/test_loq_transforms.py +++ b/test/scripts/transforms/test_loq_transforms.py @@ -7,7 +7,7 @@ import pytest from fia_api.scripts.pre_script import PreScript -from fia_api.scripts.transforms.loq_transform import LoqTransform +from fia_api.scripts.transforms.sans_transform import SansTransform @pytest.fixture() @@ -64,8 +64,9 @@ def reduction_1(): "sample_height": 8.0, "sample_width": 8.0, "slice_wavs": "[1.0, 2.0, 3.0]", - "phi_limits_list": "[(-20, 20), (30, 160)]", + "phi_limits": "[(-20, 20), (30, 160)]", } + mock.instrument.instrument_name = "LOQ" return mock @@ -84,8 +85,9 @@ def reduction_2(): "sample_height": 8.0, "sample_width": 8.0, "slice_wavs": "[1.0, 2.0, 3.0]", - "phi_limits_list": "[(-20, 20), (30, 160)]", + "phi_limits": "[(-20, 20), (30, 160)]", } + mock.instrument.instrument_name = "LOQ" return mock @@ -96,7 +98,7 @@ def test_loq_transform_apply(script, reduction_1): :param reduction_1: The reduction fixture :return: None """ - transform = LoqTransform() + transform = SansTransform() original_lines = script.value.splitlines() transform.apply(script, reduction_1) @@ -134,7 +136,7 @@ def test_loq_transform_apply_with_optionals(script, reduction_2): :param reduction_2: The reduction fixture :return: None """ - transform = LoqTransform() + transform = SansTransform() original_lines = script.value.splitlines() transform.apply(script, reduction_2) diff --git a/test/scripts/transforms/test_sans2d_transforms.py b/test/scripts/transforms/test_sans2d_transforms.py new file mode 100644 index 00000000..f0a640cb --- /dev/null +++ b/test/scripts/transforms/test_sans2d_transforms.py @@ -0,0 +1,170 @@ +""" +Test cases for LoqTransform +""" + +from unittest.mock import Mock + +import pytest + +from fia_api.scripts.pre_script import PreScript +from fia_api.scripts.transforms.sans_transform import SansTransform + + +@pytest.fixture() +def script(): + """ + LoqTransform PreScript fixture + :return: + """ + return PreScript( + value=""" +import math +import numpy +import csv +import datetime + +from mantid.kernel import ConfigService +from mantid.simpleapi import RenameWorkspace, SaveRKH, SaveNXcanSAS, GroupWorkspaces, mtd, ConjoinWorkspaces +from mantid import config +from sans.user_file.toml_parsers.toml_reader import TomlReader +import sans.command_interface.ISISCommandInterface as ici + +# Setup by rundetection +user_file = "/extras/sans2d/MaskFile.toml" +sample_scatter = 110754 +sample_transmission = None +sample_direct = None +can_scatter = None +can_transmission = None +can_direct = None +sample_thickness = 1.0 +sample_geometry = "Disc" +sample_height = 8.0 +sample_width = 8.0 +slice_wavs = [1.75, 2.75, 3.75, 4.75, 5.75, 6.75, 8.75, 10.75, 12.5] +phi_limits_list = [(-30, 30), (60, 120)] +""" + ) + + +@pytest.fixture() +def reduction_1(): + """ + Reduction fixture + :return: + """ + mock = Mock() + mock.inputs = { + "user_file": "/extras/sans2d/BestMaskFile.toml", + "run_number": 10, + "scatter_transmission": 9, + "scatter_direct": 3, + "can_scatter": 5, + "can_transmission": 4, + "can_direct": 3, + "sample_thickness": 2.0, + "sample_geometry": "Disc", + "sample_height": 8.0, + "sample_width": 8.0, + "slice_wavs": "[1.0, 2.0, 3.0]", + "phi_limits": "[(-20, 20), (30, 160)]", + } + mock.instrument.instrument_name = "SANS2D" + return mock + + +@pytest.fixture() +def reduction_2(): + """ + Reduction fixture + :return: + """ + mock = Mock() + mock.inputs = { + "user_file": "/extras/sans2d/BestMaskFile.toml", + "run_number": 5, + "sample_thickness": 2.0, + "sample_geometry": "Disc", + "sample_height": 8.0, + "sample_width": 8.0, + "slice_wavs": "[1.0, 2.0, 3.0]", + "phi_limits": "[(-20, 20), (30, 160)]", + } + mock.instrument.instrument_name = "SANS2D" + return mock + + +def test_sans2d_transform_apply(script, reduction_1): + """ + Test loq transform applies correct updates to script + :param script: The script fixture + :param reduction_1: The reduction fixture + :return: None + """ + transform = SansTransform() + + original_lines = script.value.splitlines() + transform.apply(script, reduction_1) + updated_lines = script.value.splitlines() + assert len(original_lines) == len(updated_lines) + replacements = { + "user_file": 'user_file = "/extras/sans2d/BestMaskFile.toml"', + "sample_scatter": "sample_scatter = 10", + "sample_transmission": "sample_transmission = 9", + "sample_direct": "sample_direct = 3", + "can_scatter": "can_scatter = 5", + "can_transmission": "can_transmission = 4", + "can_direct": "can_direct = 3", + "sample_thickness": "sample_thickness = 2.0", + "sample_geometry": 'sample_geometry = "Disc"', + "sample_height": "sample_height = 8.0", + "sample_width": "sample_width = 8.0", + "slice_wavs": "slice_wavs = [1.0, 2.0, 3.0]", + "phi_limits": "phi_limits_list = [(-20, 20), (30, 160)]", + } + + for index, line in enumerate(updated_lines): + for key, expected_line in replacements.items(): + if line.startswith(key): + assert line == expected_line + break + else: + assert line == original_lines[index] + + +def test_sans2d_transform_apply_with_optionals(script, reduction_2): + """ + Test loq transform applies correct updates to script + :param script: The script fixture + :param reduction_2: The reduction fixture + :return: None + """ + transform = SansTransform() + + original_lines = script.value.splitlines() + transform.apply(script, reduction_2) + updated_lines = script.value.splitlines() + assert len(original_lines) == len(updated_lines) + replacements = { + "user_file": 'user_file = "/extras/sans2d/BestMaskFile.toml"', + "sample_scatter": "sample_scatter = 5", + "sample_transmission": "sample_transmission = None", + "sample_direct": "sample_direct = None", + "can_scatter": "can_scatter = None", + "can_transmission": "can_transmission = None", + "can_direct": "can_direct = None", + "sample_thickness": "sample_thickness = 2.0", + "sample_geometry": 'sample_geometry = "Disc"', + "sample_height": "sample_height = 8.0", + "sample_width": "sample_width = 8.0", + "slice_wavs": "slice_wavs = [1.0, 2.0, 3.0]", + "phi_limits": "phi_limits_list = [(-20, 20), (30, 160)]", + } + + for index, line in enumerate(updated_lines): + for key, expected_line in replacements.items(): + if line.startswith(key): + assert line == expected_line + break + else: + assert line == original_lines[index]