Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor cli to typer #15

Open
wants to merge 18 commits into
base: add_csv_reader
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"numpy",
"pandas",
"pandas-stubs",
"typer",
] # Add project dependencies here, e.g. ["click", "numpy"]
dynamic = ["version"]
license.file = "LICENSE"
Expand All @@ -36,7 +37,7 @@ dev = [
]

[project.scripts]
bimorph-mirror-analysis = "bimorph_mirror_analysis.__main__:main"
bimorph-mirror-analysis = "bimorph_mirror_analysis.__main__:app"

[project.urls]
GitHub = "https://github.com//bimorph-mirror-analysis"
Expand Down
68 changes: 37 additions & 31 deletions src/bimorph_mirror_analysis/__main__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Interface for ``python -m bimorph_mirror_analysis``."""

import datetime
from argparse import ArgumentParser
from collections.abc import Sequence

import numpy as np
import typer

from bimorph_mirror_analysis.maths import find_voltages
from bimorph_mirror_analysis.read_file import read_bluesky_plan_output
Expand All @@ -13,37 +12,25 @@

__all__ = ["main"]

app = typer.Typer()

def main(args: Sequence[str] | None = None) -> None:
"""Argument parser for the CLI."""
parser = ArgumentParser()
parser.add_argument(
"-v",
"--version",
action="version",
version=__version__,
)
parser.add_argument(
"file_path",
type=str,
help="Path to the file containing the output of the Bluesky plan.",
)
parser.add_argument(
"-o",
"--output",
type=str,
help="Path to save the optimal voltages.",
)
a = parser.parse_args(args)
file_path = a.file_path
file_type = a.file_path.split(".")[-1]

@app.command(name=None)
def calculate_voltages(
file_path: str = typer.Argument(help="The path to the csv file to be read."),
output_path: str | None = typer.Option(
None,
help="The path to save the output\
optimal voltages to, optional.",
),
):
file_type = file_path.split(".")[-1]
optimal_voltages = calculate_optimal_voltages(file_path)
optimal_voltages = np.round(optimal_voltages, 2)
LukeFiddy marked this conversation as resolved.
Show resolved Hide resolved
date = datetime.datetime.now().date()
if a.output:
output_path = a.output
else:
output_path = f"{a.file_path.replace(f'.{file_type}', '')}\

if output_path is None:
LukeFiddy marked this conversation as resolved.
Show resolved Hide resolved
output_path = f"{file_path.replace(f'.{file_type}', '')}\
_optimal_voltages_{date}.csv"

np.savetxt(
Expand All @@ -57,7 +44,26 @@ def main(args: Sequence[str] | None = None) -> None:
)


# implement this into main
def version_callback(value: bool):
if value:
typer.echo(f"Version: {__version__}")
raise typer.Exit()


@app.callback()
def main(
version: bool = typer.Option(
None,
"--version",
"-v",
callback=version_callback,
is_eager=True,
help="Show the application's version and exit",
),
):
pass


def calculate_optimal_voltages(file_path: str) -> np.typing.NDArray[np.float64]:
pivoted, initial_voltages, increment = read_bluesky_plan_output(file_path)
# numpy array of pencil beam scans
Expand All @@ -69,4 +75,4 @@ def calculate_optimal_voltages(file_path: str) -> np.typing.NDArray[np.float64]:


if __name__ == "__main__":
main()
app()
128 changes: 128 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from typing import Any

import pandas as pd
import pytest

# Prevent pytest from catching exceptions when debugging in vscode so that break on
Expand All @@ -19,3 +20,130 @@ def pytest_exception_interact(call: pytest.CallInfo[Any]):
@pytest.hookimpl(tryfirst=True)
def pytest_internalerror(excinfo: pytest.ExceptionInfo[Any]):
raise excinfo.value


@pytest.fixture
def raw_data() -> pd.DataFrame:
data = """voltage_channel_1,voltage_channel_2,voltage_channel_3,slit_position_x,\
slit_width_x,slit_position_y,slit_width_y,centroid_position_x,centroid_position_y,pencil_beam_scan_number
0,0,0,0.0,1,0,1,0.902155003650946,0.9059778816847021,0
0,0,0,0.5,1,0,1,0.974760013427511,0.048632530864663126,0
0,0,0,1.0,1,0,1,0.628935425613454,0.11203132765051338,0
0,0,0,1.5,1,0,1,0.7762649318706349,0.2719190769515324,0
0,0,0,2.0,1,0,1,0.5071122837043811,0.611609088123056,0
0,0,0,2.5,1,0,1,0.3474659308224035,0.03510213867906076,0
0,0,0,3.0,1,0,1,0.3284992604576076,0.99208831627924,0
0,0,0,3.5,1,0,1,0.7295178626505394,0.5103013553152209,0
0,0,0,4.0,1,0,1,0.14850099035491915,0.07969729764666311,0
0,0,0,4.5,1,0,1,0.022176039254447888,0.049619315174349476,0
0,0,0,5.0,1,0,1,0.35295771305012513,0.9557024629963558,0
0,0,0,5.5,1,0,1,0.7179675309214435,0.2055059586891408,0
0,0,0,6.0,1,0,1,0.6499835221868125,0.606285488948628,0
0,0,0,6.5,1,0,1,0.47779199219005,0.6546623731441139,0
0,0,0,7.0,1,0,1,0.9336634732337328,0.3136740595705516,0
0,0,0,7.5,1,0,1,0.009520569061363338,0.8063235605316648,0
0,0,0,8.0,1,0,1,0.7158395576862129,0.5122984261799449,0
0,0,0,8.5,1,0,1,0.2716444289557093,0.5507234769303935,0
0,0,0,9.0,1,0,1,0.9930168604136324,0.8152365237183438,0
0,0,0,9.5,1,0,1,0.4129681812988907,0.8023526931635533,0
0,0,0,10.0,1,0,1,0.6674668944557347,0.4174468364062732,0
100,0,0,0.0,1,0,1,0.5546590484621201,0.8304502838039677,1
100,0,0,0.5,1,0,1,0.13906662116959356,0.17335719779661474,1
100,0,0,1.0,1,0,1,0.8798442985528538,0.6082544403694868,1
100,0,0,1.5,1,0,1,0.18419363414361012,0.8916318274563698,1
100,0,0,2.0,1,0,1,0.8711036413523057,0.7004663717770345,1
100,0,0,2.5,1,0,1,0.9099533074121638,0.5582690692187812,1
100,0,0,3.0,1,0,1,0.7802219284321528,0.5448810133045838,1
100,0,0,3.5,1,0,1,0.9206330646867245,0.9704364812040612,1
100,0,0,4.0,1,0,1,0.40940536712510833,0.5267501604276668,1
100,0,0,4.5,1,0,1,0.0896507221276781,0.9367451946136546,1
100,0,0,5.0,1,0,1,0.2712519699376731,0.8318670004689493,1
100,0,0,5.5,1,0,1,0.536044487634938,0.8376715817809228,1
100,0,0,6.0,1,0,1,0.43025265873767304,0.3787406941706395,1
100,0,0,6.5,1,0,1,0.3847349424152309,0.10274579669497463,1
100,0,0,7.0,1,0,1,0.15897565790767731,0.20601860566024477,1
100,0,0,7.5,1,0,1,0.33067777617929583,0.6736142304595877,1
100,0,0,8.0,1,0,1,0.11858344701673051,0.6782032580859674,1
100,0,0,8.5,1,0,1,0.4286248262035558,0.9895136712839148,1
100,0,0,9.0,1,0,1,0.7180884861749929,0.056036359595403096,1
100,0,0,9.5,1,0,1,0.26499046575096896,0.39912112255467724,1
100,0,0,10.0,1,0,1,0.14379259711280845,0.14106802929088103,1
100,100,0,0.0,1,0,1,0.7233918949216122,0.982168455516603,2
100,100,0,0.5,1,0,1,0.290742792709767,0.12776420966496438,2
100,100,0,1.0,1,0,1,0.13799656543141725,0.6898436988997922,2
100,100,0,1.5,1,0,1,0.7966501007242092,0.10121665939974611,2
100,100,0,2.0,1,0,1,0.6035548318312184,0.5266494061788435,2
100,100,0,2.5,1,0,1,0.8745250527670112,0.2783673425712273,2
100,100,0,3.0,1,0,1,0.2199662395829104,0.8628343609007396,2
100,100,0,3.5,1,0,1,0.21397922232247635,0.14932168483955655,2
100,100,0,4.0,1,0,1,0.6491136779321542,0.9677149734414517,2
100,100,0,4.5,1,0,1,0.521585333219087,0.3555822165846366,2
100,100,0,5.0,1,0,1,0.2878937119272532,0.9483878657266134,2
100,100,0,5.5,1,0,1,0.5528976167527797,0.2989377930464553,2
100,100,0,6.0,1,0,1,0.2348080763889996,0.27979916967784824,2
100,100,0,6.5,1,0,1,0.22744100740205286,0.8218958168136187,2
100,100,0,7.0,1,0,1,0.18745517586190308,0.6605989632183804,2
100,100,0,7.5,1,0,1,0.6481645591784894,0.6263681865745402,2
100,100,0,8.0,1,0,1,0.16260780954429876,0.27381083790754324,2
100,100,0,8.5,1,0,1,0.1579859256512237,0.9062488979864969,2
100,100,0,9.0,1,0,1,0.37772126256203997,0.7955668100253056,2
100,100,0,9.5,1,0,1,0.2295136220510665,0.8019070476612383,2
100,100,0,10.0,1,0,1,0.8681643230023325,0.5923396832700695,2
100,100,100,0.0,1,0,1,0.232850273843354,0.3203071762815356,3
100,100,100,0.5,1,0,1,0.39422707347590746,0.9169444346778486,3
100,100,100,1.0,1,0,1,0.5195777869079362,0.933578081913559,3
100,100,100,1.5,1,0,1,0.3055357289931444,0.13087372577687528,3
100,100,100,2.0,1,0,1,0.7483418094894336,0.07346239072990834,3
100,100,100,2.5,1,0,1,0.696346546725722,0.7225571582190516,3
100,100,100,3.0,1,0,1,0.9724618758170118,0.9471159228590665,3
100,100,100,3.5,1,0,1,0.21177070577236934,0.6092750600252596,3
100,100,100,4.0,1,0,1,0.9916829027154486,0.9626893747544218,3
100,100,100,4.5,1,0,1,0.474706589914094,0.8168984615795551,3
100,100,100,5.0,1,0,1,0.28338124640130735,0.9619912615534008,3
100,100,100,5.5,1,0,1,0.5541329846866376,0.6162100171902787,3
100,100,100,6.0,1,0,1,0.4041186292722393,0.7502255535642713,3
100,100,100,6.5,1,0,1,0.9755649555219787,0.6290253229311905,3
100,100,100,7.0,1,0,1,0.7228182241678169,0.2973982599315711,3
100,100,100,7.5,1,0,1,0.29495423356560835,0.005345110272705567,3
100,100,100,8.0,1,0,1,0.2785374723069234,0.2728282979215154,3
100,100,100,8.5,1,0,1,0.004064948618126452,0.658081995540037,3
100,100,100,9.0,1,0,1,0.6907922378507042,0.17092071374050632,3
100,100,100,9.5,1,0,1,0.3965148737759935,0.04638829999092675,3
100,100,100,10.0,1,0,1,0.8706751462033909,0.7206080303974453,3"""
df = pd.DataFrame(
[row.split(",") for row in data.split("\n")[1:]],
columns=data.split("\n")[0].split(","),
) # type: ignore
return df.apply(pd.to_numeric, errors="coerce") # type: ignore


@pytest.fixture
def raw_data_pivoted() -> pd.DataFrame:
data = """slit_position_x,pencil_beam_scan_0,pencil_beam_scan_1,pencil_beam_scan_2\
,pencil_beam_scan_3
0.0,0.902155003650946,0.5546590484621201,0.7233918949216122,0.232850273843354
0.5,0.974760013427511,0.1390666211695935,0.290742792709767,0.3942270734759074
1.0,0.628935425613454,0.8798442985528538,0.1379965654314172,0.5195777869079362
1.5,0.7762649318706349,0.1841936341436101,0.7966501007242092,0.3055357289931444
2.0,0.5071122837043811,0.8711036413523057,0.6035548318312184,0.7483418094894336
2.5,0.3474659308224035,0.9099533074121638,0.8745250527670112,0.696346546725722
3.0,0.3284992604576076,0.7802219284321528,0.2199662395829104,0.9724618758170118
3.5,0.7295178626505394,0.9206330646867243,0.2139792223224763,0.2117707057723693
4.0,0.1485009903549191,0.4094053671251083,0.6491136779321542,0.9916829027154486
4.5,0.0221760392544478,0.0896507221276781,0.521585333219087,0.474706589914094
5.0,0.3529577130501251,0.2712519699376731,0.2878937119272532,0.2833812464013073
5.5,0.7179675309214435,0.536044487634938,0.5528976167527797,0.5541329846866376
6.0,0.6499835221868125,0.430252658737673,0.2348080763889996,0.4041186292722393
6.5,0.47779199219005,0.3847349424152309,0.2274410074020528,0.9755649555219787
7.0,0.9336634732337328,0.1589756579076773,0.187455175861903,0.7228182241678169
7.5,0.0095205690613633,0.3306777761792958,0.6481645591784894,0.2949542335656083
8.0,0.7158395576862129,0.1185834470167305,0.1626078095442987,0.2785374723069234
8.5,0.2716444289557093,0.4286248262035558,0.1579859256512237,0.0040649486181264
9.0,0.9930168604136324,0.7180884861749929,0.3777212625620399,0.6907922378507042
9.5,0.4129681812988907,0.2649904657509689,0.2295136220510665,0.3965148737759935
10.0,0.6674668944557347,0.1437925971128084,0.8681643230023325,0.8706751462033909"""
df = pd.DataFrame(
[row.split(",") for row in data.split("\n")[1:]],
columns=data.split("\n")[0].split(","),
) # type: ignore
return df.apply(pd.to_numeric, errors="coerce") # type: ignore
3 changes: 3 additions & 0 deletions tests/data/raw_data_optimal_voltages_2024-12-19.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
72.14
50.98
18.59
41 changes: 30 additions & 11 deletions tests/program_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,35 @@
from unittest.mock import patch

import numpy as np
import pytest
import pandas as pd

from bimorph_mirror_analysis.__main__ import calculate_optimal_voltages
from bimorph_mirror_analysis.maths import find_voltages


def test_calculate_optimal_voltages_mocked(raw_data_pivoted: pd.DataFrame):
with (
patch(
"bimorph_mirror_analysis.__main__.read_bluesky_plan_output"
) as mock_read_bluesky_plan_output,
patch("bimorph_mirror_analysis.__main__.find_voltages") as mock_find_voltages,
):
# set the mock return values
mock_read_bluesky_plan_output.return_value = (
raw_data_pivoted,
np.array([0.0, 0.0, 0.0]),
100,
)
mock_find_voltages.side_effect = find_voltages
voltages = calculate_optimal_voltages("input_file")
voltages = np.round(voltages, 2)
# assert correct voltages calculated
np.testing.assert_almost_equal(voltages, np.array([72.14, 50.98, 18.59]))

@pytest.mark.parametrize(
["input_file", "expected_voltages"],
[["tests/data/raw_data.csv", np.array([72.14, 50.98, 18.59])]],
)
def test_calculate_optimal_voltages(
input_file: str, expected_voltages: np.typing.NDArray[np.float64]
):
voltages = calculate_optimal_voltages(input_file)
voltages = np.round(voltages, 2)
np.testing.assert_almost_equal(voltages, expected_voltages)
# assert mock was called
mock_read_bluesky_plan_output.assert_called()
mock_read_bluesky_plan_output.assert_called_with("input_file")
mock_find_voltages.assert_called()
expected_data = raw_data_pivoted[raw_data_pivoted.columns[1:]].to_numpy() # type: ignore
np.testing.assert_array_equal(mock_find_voltages.call_args[0][0], expected_data) # type: ignore
np.testing.assert_almost_equal(mock_find_voltages.call_args[0][1], 100)
45 changes: 25 additions & 20 deletions tests/read_file_test.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,37 @@
from unittest.mock import patch

import numpy as np
import pandas as pd
import pytest

from bimorph_mirror_analysis.read_file import read_bluesky_plan_output


@pytest.mark.parametrize(
["input_path", "output_path"],
[
["tests/data/raw_data.csv", "tests/data/raw_data_pivoted.csv"],
],
)
def test_read_raw_data(input_path: str, output_path: str):
pivoted, initial_voltages, increment = read_bluesky_plan_output(input_path)
expected_output = pd.read_csv(output_path) # type: ignore
pd.testing.assert_frame_equal(pivoted, expected_output)
np.testing.assert_array_equal(initial_voltages, np.array([0.0, 0.0, 0.0]))
np.testing.assert_equal(increment, np.float64(100.0))
def test_read_raw_data(raw_data: pd.DataFrame, raw_data_pivoted: pd.DataFrame):
with patch("bimorph_mirror_analysis.read_file.pd.read_csv") as mock_read_csv:
mock_read_csv.return_value = raw_data
pivoted, initial_voltages, increment = read_bluesky_plan_output("input_path")
# expected_output = pd.read_csv(output_path) # type: ignore
pd.testing.assert_frame_equal(pivoted, raw_data_pivoted)
np.testing.assert_array_equal(initial_voltages, np.array([0.0, 0.0, 0.0]))
np.testing.assert_equal(increment, np.float64(100.0))
mock_read_csv.assert_called()


@pytest.mark.xfail(
reason="This test is expected to fail, the incrememnt should be 100, not 101"
)
def test_read_raw_data_fail():
pivoted, initial_voltages, increment = read_bluesky_plan_output(
"tests/data/raw_data.csv"
)
expected_output = pd.read_csv("tests/data/raw_data_pivoted.csv") # type: ignore
pd.testing.assert_frame_equal(pivoted, expected_output)
np.testing.assert_array_equal(initial_voltages, np.array([0.0, 0.0, 0.0]))
np.testing.assert_equal(increment, np.float64(101.0))
def test_read_raw_data_fail(raw_data_pivoted: pd.DataFrame):
with patch(
"bimorph_mirror_analysis.read_file.read_bluesky_plan_output"
) as mock_read_bluesky_plan_output:
mock_read_bluesky_plan_output.return_value = (
raw_data_pivoted,
np.array([0.0, 0.0, 0.0]),
np.float64(101.0),
)
pivoted, initial_voltages, increment = mock_read_bluesky_plan_output()
expected_output = pd.read_csv("tests/data/raw_data_pivoted.csv") # type: ignore
pd.testing.assert_frame_equal(pivoted, expected_output)
np.testing.assert_array_equal(initial_voltages, np.array([0.0, 0.0, 0.0]))
np.testing.assert_equal(increment, np.float64(100.0))
Loading
Loading