Skip to content

Commit

Permalink
raise if not input files found
Browse files Browse the repository at this point in the history
  • Loading branch information
dermen committed Aug 28, 2024
1 parent 22cbd40 commit 3d6f605
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions rsbooster/io/dials2mtz.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,23 @@ def print_refl():
print_refl_info(args.reflfile)


def _write(ds, mtzname):
def _write(ds, mtzname, verbose=False):
"""write the RS dataset to mtz file"""
if verbose:
print(f"Writing MTZ {mtzname} ...")
ds.write_mtz(mtzname)
if verbose:
print("Done writing MTZ.")


def get_fnames(dirnames, verbose=False, tag=None, ext="integrated.refl"):
def get_fnames(dirnames, verbose=False, optional_tag=None, ext="integrated.refl"):
"""
Parameters
----------
dirnames: list of str, folders to search for files
verbose: bool, whether to print stdout
tag: str, only select files whose names contain this string
optional_tag: str, only select files whose names contain this string
ext: str, only select files ending with this string
Returns
Expand All @@ -62,10 +66,12 @@ def get_fnames(dirnames, verbose=False, tag=None, ext="integrated.refl"):
fnames += glob.glob(dirname + f"/*{ext}")
if verbose:
print(f"Found {len(fnames)} files")
if tag is not None:
fnames = [f for f in fnames if tag in f]
if optional_tag is not None:
fnames = [f for f in fnames if optional_tag in f]
if verbose:
print(f"Selected {len(fnames)} files with {tag} in the name.")
print(f"Selected {len(fnames)} files with {optional_tag} in the name.")
if not fnames:
raise IOError(f"No filenames were found for loading with dirnames={dirnames}, optional_tag={optional_tag}, and ext={ext}")
return fnames


Expand All @@ -76,10 +82,10 @@ def ray_main():
assert args.ucell is not None
assert args.symbol is not None

fnames = get_fnames(args.dirnames, args.verbose, tag=args.tag, ext=args.ext)
fnames = get_fnames(args.dirnames, args.verbose, optional_tag=args.tag, ext=args.ext)
ds = read_dials_stills(fnames, unitcell=args.ucell, spacegroup=args.symbol, numjobs=args.numjobs,
parallel_backend="ray", extra_cols=args.extra_cols, verbose=args.verbose)
_write(ds, args.mtz)
_write(ds, args.mtz, args.verbose)


def mpi_main():
Expand All @@ -89,11 +95,11 @@ def mpi_main():
assert args.symbol is not None
from mpi4py import MPI
COMM = MPI.COMM_WORLD
fnames = get_fnames(args.dirnames, args.verbose, tag=args.tag, ext=args.ext)
fnames = get_fnames(args.dirnames, args.verbose, optional_tag=args.tag, ext=args.ext)
ds = read_dials_stills(fnames, unitcell=args.ucell, spacegroup=args.symbol, parallel_backend="mpi",
extra_cols=args.extra_cols, verbose=args.verbose)
if COMM.rank == 0:
_write(ds, args.mtz)
_write(ds, args.mtz, args.verbose)


if __name__ == "__main__":
Expand Down

0 comments on commit 3d6f605

Please sign in to comment.