Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/diffpy/labpdfproc/labpdfprocapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object


def get_args():
def get_args(override_cli_inputs=None):
p = ArgumentParser()
p.add_argument("mud", help="Value of mu*D for your " "sample. Required.", type=float)
p.add_argument("-i", "--input-file", help="The filename of the " "datafile to load.")
Expand Down Expand Up @@ -58,7 +58,7 @@ def get_args():
action="store_true",
help="Outputs will not overwrite existing file unless --force is specified.",
)
args = p.parse_args()
args = p.parse_args(override_cli_inputs)
return args


Expand Down
65 changes: 32 additions & 33 deletions src/diffpy/labpdfproc/tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,77 @@
import argparse
import os
import re
from pathlib import Path

import pytest

from diffpy.labpdfproc.labpdfprocapp import get_args
from diffpy.labpdfproc.tools import known_sources, set_output_directory, set_wavelength

params1 = [
([None], ["."]),
(["."], ["."]),
(["new_dir"], ["new_dir"]),
(["existing_dir"], ["existing_dir"]),
([], ["."]),
(["--output-directory", "."], ["."]),
(["--output-directory", "new_dir"], ["new_dir"]),
(["--output-directory", "input_dir"], ["input_dir"]),
]


@pytest.mark.parametrize("inputs, expected", params1)
def test_set_output_directory(inputs, expected, tmp_path):
directory = Path(tmp_path)
os.chdir(directory)
def test_set_output_directory(inputs, expected, user_filesystem):
tmp_dir = user_filesystem
expected_output_directory = tmp_dir / expected[0]

existing_dir = Path(tmp_path).resolve() / "existing_dir"
existing_dir.mkdir(parents=True, exist_ok=True)

expected_output_directory = Path(tmp_path).resolve() / expected[0]
actual_args = argparse.Namespace(output_directory=inputs[0])
cli_inputs = ["2.5"] + inputs
actual_args = get_args(cli_inputs)
actual_args.output_directory = set_output_directory(actual_args)
assert actual_args.output_directory == expected_output_directory
assert Path(actual_args.output_directory).exists()
assert Path(actual_args.output_directory).is_dir()


def test_set_output_directory_bad(tmp_path):
directory = Path(tmp_path)
os.chdir(directory)

existing_file = Path(tmp_path).resolve() / "existing_file.py"
existing_file.touch()

actual_args = argparse.Namespace(output_directory="existing_file.py")
def test_set_output_directory_bad(user_filesystem):
cli_inputs = ["2.5", "--output-directory", "good_data.chi"]
actual_args = get_args(cli_inputs)
with pytest.raises(FileExistsError):
actual_args.output_directory = set_output_directory(actual_args)
assert Path(actual_args.output_directory).exists()
assert not Path(actual_args.output_directory).is_dir()


params2 = [
([None, None], [0.71]),
([None, "Ag"], [0.59]),
([0.25, "Ag"], [0.25]),
([0.25, None], [0.25]),
([], [0.71]),
(["--anode-type", "Ag"], [0.59]),
(["--wavelength", "0.25"], [0.25]),
(["--wavelength", "0.25", "--anode-type", "Ag"], [0.25]),
]


@pytest.mark.parametrize("inputs, expected", params2)
def test_set_wavelength(inputs, expected):
expected_wavelength = expected[0]
actual_args = argparse.Namespace(wavelength=inputs[0], anode_type=inputs[1])
actual_wavelength = set_wavelength(actual_args)
assert actual_wavelength == expected_wavelength
cli_inputs = ["2.5"] + inputs
actual_args = get_args(cli_inputs)
actual_args.wavelength = set_wavelength(actual_args)
assert actual_args.wavelength == expected_wavelength


params3 = [
(
[None, "invalid"],
["--anode-type", "invalid"],
[f"Anode type not recognized. Please rerun specifying an anode_type from {*known_sources, }."],
),
([0, None], ["No valid wavelength. Please rerun specifying a known anode_type or a positive wavelength."]),
([-1, "Mo"], ["No valid wavelength. Please rerun specifying a known anode_type or a positive wavelength."]),
(
["--wavelength", "0"],
["No valid wavelength. Please rerun specifying a known anode_type or a positive wavelength."],
),
(
["--wavelength", "-1", "--anode-type", "Mo"],
["No valid wavelength. Please rerun specifying a known anode_type or a positive wavelength."],
),
]


@pytest.mark.parametrize("inputs, msg", params3)
def test_set_wavelength_bad(inputs, msg):
actual_args = argparse.Namespace(wavelength=inputs[0], anode_type=inputs[1])
cli_inputs = ["2.5"] + inputs
actual_args = get_args(cli_inputs)
with pytest.raises(ValueError, match=re.escape(msg[0])):
actual_args.wavelength = set_wavelength(actual_args)