diff --git a/src/diffpy/labpdfproc/labpdfprocapp.py b/src/diffpy/labpdfproc/labpdfprocapp.py
index c0ed5e3..6dca178 100644
--- a/src/diffpy/labpdfproc/labpdfprocapp.py
+++ b/src/diffpy/labpdfproc/labpdfprocapp.py
@@ -1,9 +1,15 @@
 import sys
 from argparse import ArgumentParser
-from pathlib import Path
 
 from diffpy.labpdfproc.functions import apply_corr, compute_cve
-from diffpy.labpdfproc.tools import known_sources, load_user_metadata, set_output_directory, set_wavelength
+from diffpy.labpdfproc.tools import (
+    expand_list_file,
+    known_sources,
+    load_user_metadata,
+    set_input_lists,
+    set_output_directory,
+    set_wavelength,
+)
 from diffpy.utils.parsers.loaddata import loadData
 from diffpy.utils.scattering_objects.diffraction_objects import XQUANTITIES, Diffraction_object
 
@@ -21,7 +27,7 @@ def get_args(override_cli_inputs=None):
         "data-files in that directory will be processed. Examples of valid "
         "inputs are 'file.xy', 'data/file.xy', 'file.xy, data/file.xy', "
         "'.' (load everything in the current directory), 'data' (load"
-        "everything in the folder ./data', 'data/file_list.txt' (load"
+        " everything in the folder ./data), 'data/file_list.txt' (load"
         " the list of files contained in the text-file called "
         "file_list.txt that can be found in the folder ./data).",
     )
@@ -89,45 +95,47 @@ def get_args(override_cli_inputs=None):
 
 def main():
     args = get_args()
+    args = expand_list_file(args)
+    args = set_input_lists(args)
     args.output_directory = set_output_directory(args)
     args.wavelength = set_wavelength(args)
     args = load_user_metadata(args)
 
-    filepath = Path(args.input_file)
-    outfilestem = filepath.stem + "_corrected"
-    corrfilestem = filepath.stem + "_cve"
-    outfile = args.output_directory / (outfilestem + ".chi")
-    corrfile = args.output_directory / (corrfilestem + ".chi")
+    for filepath in args.input_paths:
+        outfilestem = filepath.stem + "_corrected"
+        corrfilestem = filepath.stem + "_cve"
+        outfile = args.output_directory / (outfilestem + ".chi")
+        corrfile = args.output_directory / (corrfilestem + ".chi")
 
-    if outfile.exists() and not args.force_overwrite:
-        sys.exit(
-            f"Output file {str(outfile)} already exists. Please rerun "
-            f"specifying -f if you want to overwrite it."
-        )
-    if corrfile.exists() and args.output_correction and not args.force_overwrite:
-        sys.exit(
-            f"Corrections file {str(corrfile)} was requested and already "
-            f"exists. Please rerun specifying -f if you want to overwrite it."
-        )
+        if outfile.exists() and not args.force_overwrite:
+            sys.exit(
+                f"Output file {str(outfile)} already exists. Please rerun "
+                f"specifying -f if you want to overwrite it."
+            )
+        if corrfile.exists() and args.output_correction and not args.force_overwrite:
+            sys.exit(
+                f"Corrections file {str(corrfile)} was requested and already "
+                f"exists. Please rerun specifying -f if you want to overwrite it."
+            )
 
-    input_pattern = Diffraction_object(wavelength=args.wavelength)
-    xarray, yarray = loadData(args.input_file, unpack=True)
-    input_pattern.insert_scattering_quantity(
-        xarray,
-        yarray,
-        "tth",
-        scat_quantity="x-ray",
-        name=str(args.input_file),
-        metadata={"muD": args.mud, "anode_type": args.anode_type},
-    )
+        input_pattern = Diffraction_object(wavelength=args.wavelength)
+        xarray, yarray = loadData(filepath, unpack=True)
+        input_pattern.insert_scattering_quantity(
+            xarray,
+            yarray,
+            "tth",
+            scat_quantity="x-ray",
+            name=str(filepath),
+            metadata={"muD": args.mud, "anode_type": args.anode_type},
+        )
 
-    absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
-    corrected_data = apply_corr(input_pattern, absorption_correction)
-    corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
-    corrected_data.dump(f"{outfile}", xtype="tth")
+        absorption_correction = compute_cve(input_pattern, args.mud, args.wavelength)
+        corrected_data = apply_corr(input_pattern, absorption_correction)
+        corrected_data.name = f"Absorption corrected input_data: {input_pattern.name}"
+        corrected_data.dump(f"{outfile}", xtype="tth")
 
-    if args.output_correction:
-        absorption_correction.dump(f"{corrfile}", xtype="tth")
+        if args.output_correction:
+            absorption_correction.dump(f"{corrfile}", xtype="tth")
 
 
 if __name__ == "__main__":
diff --git a/src/diffpy/labpdfproc/tests/test_tools.py b/src/diffpy/labpdfproc/tests/test_tools.py
index db75965..23e56bb 100644
--- a/src/diffpy/labpdfproc/tests/test_tools.py
+++ b/src/diffpy/labpdfproc/tests/test_tools.py
@@ -6,6 +6,7 @@
 
 from diffpy.labpdfproc.labpdfprocapp import get_args
 from diffpy.labpdfproc.tools import (
+    expand_list_file,
     known_sources,
     load_user_metadata,
     set_input_lists,
@@ -49,10 +50,6 @@
             "input_dir/binary.pkl",
         ],
     ),
-    (  # file_list.txt list of files provided
-        ["input_dir/file_list.txt"],
-        ["good_data.chi", "good_data.xy", "good_data.txt"],
-    ),
     (  # file_list_example2.txt list of files provided in different directories
         ["input_dir/file_list_example2.txt"],
         ["input_dir/good_data.chi", "good_data.xy", "input_dir/good_data.txt"],
@@ -68,8 +65,9 @@ def test_set_input_lists(inputs, expected, user_filesystem):
 
     cli_inputs = ["2.5"] + inputs
     actual_args = get_args(cli_inputs)
+    actual_args = expand_list_file(actual_args)
     actual_args = set_input_lists(actual_args)
-    assert list(actual_args.input_paths).sort() == expected_paths.sort()
+    assert sorted(actual_args.input_paths) == sorted(expected_paths)
 
 
 # This test covers non-existing single input file or directory, in this case we raise an error with message
@@ -87,6 +85,10 @@ def test_set_input_lists(inputs, expected, user_filesystem):
         ["good_data.chi", "good_data.xy", "unreadable_file.txt", "missing_file.txt"],
         "Cannot find missing_file.txt. Please specify valid input file(s) or directories.",
     ),
+    (  # file_list.txt list of files provided (with missing files)
+        ["input_dir/file_list.txt"],
+        "Cannot find missing_file.txt. Please specify valid input file(s) or directories.",
+    ),
 ]
 
 
@@ -96,6 +98,7 @@ def test_set_input_files_bad(inputs, msg, user_filesystem):
     os.chdir(base_dir)
     cli_inputs = ["2.5"] + inputs
     actual_args = get_args(cli_inputs)
+    actual_args = expand_list_file(actual_args)
     with pytest.raises(FileNotFoundError, match=msg[0]):
         actual_args = set_input_lists(actual_args)
 
diff --git a/src/diffpy/labpdfproc/tools.py b/src/diffpy/labpdfproc/tools.py
index 2bf2379..a95bf92 100644
--- a/src/diffpy/labpdfproc/tools.py
+++ b/src/diffpy/labpdfproc/tools.py
@@ -28,6 +28,33 @@ def set_output_directory(args):
     return output_dir
 
 
+def expand_list_file(args):
+    """
+    Expand the list of inputs by adding files from file lists and removing the file lists.
+
+    An input is treated as a file list when its file name contains "file_list".
+    Each non-blank line of a file list is added to the inputs as one entry.
+
+    Parameters
+    ----------
+    args argparse.Namespace
+        the arguments from the parser
+
+    Returns
+    -------
+    the arguments with the modified input list
+
+    """
+    file_list_inputs = [input_name for input_name in args.input if "file_list" in Path(input_name).name]
+    for file_list_input in file_list_inputs:
+        with open(file_list_input, "r") as f:
+            # skip blank lines: an empty entry would resolve to the current directory
+            file_inputs = [line.strip() for line in f if line.strip()]
+        args.input.extend(file_inputs)
+        args.input.remove(file_list_input)
+    return args
+
+
 def set_input_lists(args):
     """
     Set input directory and files.
@@ -47,20 +74,25 @@ def set_input_lists(args):
     """
 
     input_paths = []
-    for input in args.input:
-        input_path = Path(input).resolve()
+    for input_name in args.input:
+        input_path = Path(input_name).resolve()
         if input_path.exists():
             if input_path.is_file():
                 input_paths.append(input_path)
             elif input_path.is_dir():
                 input_files = input_path.glob("*")
-                input_files = [file.resolve() for file in input_files if file.is_file()]
+                input_files = [
+                    file.resolve() for file in input_files if file.is_file() and "file_list" not in file.name
+                ]
                 input_paths.extend(input_files)
             else:
-                raise FileNotFoundError(f"Cannot find {input}. Please specify valid input file(s) or directories.")
+                raise FileNotFoundError(
+                    f"Cannot find {input_name}. Please specify valid input file(s) or directories."
+                )
         else:
-            raise FileNotFoundError(f"Cannot find {input}")
-    setattr(args, "input_paths", input_paths)
+            raise FileNotFoundError(f"Cannot find {input_name}")
+    # de-duplicate while keeping first-seen order so files are processed deterministically
+    setattr(args, "input_paths", list(dict.fromkeys(input_paths)))
     return args
 
 