diff --git a/.gitignore b/.gitignore
index 6de761fe..122afe0e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -141,3 +141,6 @@ conda-lock.yml
 poetry.lock
 setup.py.bak
 *.bin
+
+# Test data
+test/data/testdb/*
\ No newline at end of file
diff --git a/.gitmodules b/.gitmodules
index 9228a61a..57435169 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -7,3 +7,6 @@
 [submodule "PolSpectra"]
 	path = PolSpectra
 	url = https://github.com/AlecThomson/PolSpectra
+[submodule "RMTable"]
+	path = RMTable
+	url = https://github.com/CIRADA-Tools/RMTable.git
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 78df759a..c602680f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,12 +11,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - This changelog!
 - `scripts/tar_cubelets.py` and CLI hook
+- `makecat.py`: Added `flag_blended_components` to identify and flag blended components. Adds `is_blended_flag`, `N_blended`, `blend_ratio` to the catalogue.
+- Proper logging module
 
 ### Fixed
 
-- `columns_possum.py`: Add new Stokes I fit flags and UCDs (plus others)
+- `columns_possum.py`: Add new Stokes I fit flags and UCDs (plus others) and descriptions
 - `scripts/casda_prepare.py`: Refactor to make considated products and make CASDA happy
 - `scripts/fix_dr1_cat.py`: Added extra columns that needed to be fixed in DR1 e.g. sbid, start_time
+- Typing in various places
 
 ### Changed
 
@@ -24,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `makecat.py`: Added `compute_local_rm_flag` function
 - `rmsynth_oncuts.py` Added new Stokes I fit flags
 - `utils.py`: Refactored Stokes I fitting to use dicts to track values
+- Use local installs of customised packages
 
 ### Removed
 
diff --git a/README.md b/README.md
index 6d4fefa3..80aead00 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@ Scripts for processing polarized RACS data products.
 
 ## Documentation
 
-The documentation is available at [spice-racs.readthedocs.io](https://spice-racs.readthedocs.io).
+The documentation is available at [spice-racs.readthedocs.io](https://spiceracs.readthedocs.io).
 
 ## Acknowledging
 
@@ -14,7 +14,7 @@ If you use SPICE-RACS in your research, please cite [Thomson et al. (in prep)](h
 
 ### 3rd party software
 
-Please also consider acknowledging the following software packages outlines in [docs](https://spice-racs.readthedocs.io/acknowledging.html).
+Please also consider acknowledging the following software packages outlines in [docs](https://spiceracs.readthedocs.io/acknowledging.html).
## Contibuting diff --git a/RMTable b/RMTable new file mode 160000 index 00000000..b8015511 --- /dev/null +++ b/RMTable @@ -0,0 +1 @@ +Subproject commit b801551102723bf7ab181a7d6e956d149b127b52 diff --git a/pyproject.toml b/pyproject.toml index 9e1c800a..cafc3641 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ include = [ [tool.poetry.dependencies] python = "^3.8" -rm-tools = {git = "https://github.com/AlecThomson/RM-Tools.git@spiceracs_dev"} +rm-tools = {path = "./RM-Tools"} astropy = "^5" bilby = "*" casatasks = "*" @@ -63,15 +63,15 @@ vorbin = "*" graphviz = "*" bokeh = "*" prefect = "<2" -RMTable = { git = "https://github.com/CIRADA-Tools/RMTable" } -PolSpectra = { git = "https://github.com/AlecThomson/PolSpectra.git@spiceracs"} +RMTable = { path = "./RMTable" } +PolSpectra = { path = "./PolSpectra"} setuptools = "*" [tool.poetry.dev-dependencies] -black = "^22.10" +black = "^23" flake8 = "^5" isort = "^5" -mypy = "^0.991" +mypy = "^1" [tool.poetry.extras] docs = [ @@ -83,12 +83,6 @@ docs = [ "numpydoc", ] -dev = [ - "black>=22.10", - "flake8>=5", - "isort>=5", -] - [build-system] requires = ["poetry-core>=1.2"] build-backend = "poetry.core.masonry.api" diff --git a/scripts/casda_prepare.py b/scripts/casda_prepare.py index 23d5c253..42b07b41 100755 --- a/scripts/casda_prepare.py +++ b/scripts/casda_prepare.py @@ -2,7 +2,7 @@ """Prepare files for CASDA upload""" import argparse import hashlib -import logging as log +import logging import os import pickle import subprocess as sp @@ -46,6 +46,7 @@ from spectral_cube.cube_utils import convert_bunit from tqdm.auto import tqdm, trange +from spiceracs.logger import logger from spiceracs.makecat import write_votable from spiceracs.utils import chunk_dask, tqdm_dask, try_mkdir, try_symlink, zip_equal @@ -82,7 +83,7 @@ def make_thumbnail(cube_f: str, cube_dir: str): ax.set_ylabel("Dec") fig.colorbar(im, ax=ax, label=f"{convert_bunit(head['BUNIT']):latex_inline}") outf = os.path.join(cube_dir, os.path.basename(cube_f).replace(".fits", ".png")) - log.info(f"Saving thumbnail to {outf}") + logger.info(f"Saving thumbnail to {outf}") fig.savefig(outf, dpi=300) plt.close(fig) @@ -97,9 +98,9 @@ def find_spectra(data_dir: str = ".") -> list: list: List of spectra in ascii format """ cut_dir = os.path.join(data_dir, "cutouts") - log.info(f"Globbing for spectra in {cut_dir}") + logger.info(f"Globbing for spectra in {cut_dir}") spectra = glob(os.path.join(os.path.join(cut_dir, "*"), "*[0-9].dat")) - log.info(f"Found {len(spectra)} spectra (in frequency space)") + logger.info(f"Found {len(spectra)} spectra (in frequency space)") return spectra @@ -255,7 +256,7 @@ def convert_spectra( outf = os.path.join( spec_dir, os.path.basename(spectrum).replace(".dat", f"_polspec.fits") ) - log.info(f"Saving to {outf}") + logger.info(f"Saving to {outf}") spectrum_table.write_FITS(outf, overwrite=True) # Add object to header # Hard code the pixel size for now @@ -339,7 +340,7 @@ def update_cube(cube: str, cube_dir: str) -> None: "image.restored.i.", f"{imtype}.{''.join(stokes)}." 
), ).replace("RACS_test4_1.05_", "RACS_") - log.info(f"Writing {outf} cubelet") + logger.info(f"Writing {outf} cubelet") fits.writeto(outf, data, header, overwrite=True) # Move cube to cubelets directory @@ -356,11 +357,11 @@ def find_cubes(data_dir: str = ".") -> list: list: List of cubelets """ cut_dir = os.path.join(data_dir, "cutouts") - log.info(f"Globbing for cubes in {cut_dir}") + logger.info(f"Globbing for cubes in {cut_dir}") cubes = glob( os.path.join(os.path.join(cut_dir, "*"), "*.image.restored.i.*.linmos.fits") ) - log.info(f"Found {len(cubes)} Stokes I image cubes") + logger.info(f"Found {len(cubes)} Stokes I image cubes") return cubes @@ -374,11 +375,11 @@ def init_polspec( outdir = casda_dir polspec_0 = polspectra.from_FITS(spectrum_table_0) out_fits = os.path.join(os.path.abspath(outdir), "spice_racs_dr1_polspec.fits") - log.info(f"Saving to {out_fits}") + logger.info(f"Saving to {out_fits}") polspec_0.write_FITS(out_fits, overwrite=True) out_hdf = os.path.join(os.path.abspath(outdir), "spice_racs_dr1_polspec.hdf5") - # log.info(f"Saving to {out_hdf}") + # logger.info(f"Saving to {out_hdf}") # polspec_0.write_HDF5(out_hdf, overwrite=True, compress=True) return out_fits, out_hdf @@ -434,10 +435,10 @@ def convert_pdf(pdf_file: str, plots_dir: str, spec_dir: str) -> None: """ png_file = pdf_file.replace(".pdf", "") png_file = os.path.join(plots_dir, os.path.basename(png_file)) - log.info(f"Converting {pdf_file} to {png_file}") + logger.info(f"Converting {pdf_file} to {png_file}") cmd = f"pdftoppm {pdf_file} {png_file} -png" - log.info(cmd) + logger.info(cmd) sp.run(cmd.split(), check=True) # Grr, pdftoppm doesn't preserve the file name actual_name = f"{png_file}-1.png" @@ -478,9 +479,9 @@ def find_plots(data_dir: str = ".") -> list: list: List of plots """ cut_dir = os.path.join(data_dir, "cutouts") - log.info(f"Globbing for plots in {cut_dir}") + logger.info(f"Globbing for plots in {cut_dir}") plots = glob(os.path.join(os.path.join(cut_dir, "RACS_*"), "*.pdf")) - log.info(f"Found {len(plots)} plots") + logger.info(f"Found {len(plots)} plots") return plots @@ -500,9 +501,9 @@ def main( # Re-register astropy units for unit in (u.deg, u.hour, u.hourangle, u.Jy, u.arcsec, u.arcmin, u.beam): get_current_unit_registry().add_enabled_units([unit]) - log.info("Starting") - log.info(f"Dask client: {client}") - log.info(f"Reading {polcatf}") + logger.info("Starting") + logger.info(f"Dask client: {client}") + logger.info(f"Reading {polcatf}") polcat = Table.read(polcatf) df = polcat.to_pandas() @@ -512,7 +513,7 @@ def main( test = prep_type == "test" - log.info(f"Preparing data for {prep_type} CASDA upload") + logger.info(f"Preparing data for {prep_type} CASDA upload") if prep_type == "full": pass @@ -545,7 +546,7 @@ def main( cube_outputs = [] if do_update_cubes: - log.info("Updating cubelets") + logger.info("Updating cubelets") cube_dir = os.path.join(casda_dir, "cubelets") try_mkdir(cube_dir) @@ -556,14 +557,14 @@ def main( set(polcat["source_id"]) ), "Number of cubes does not match number of sources" except AssertionError: - log.warning( + logger.warning( f"Found {len(cubes)} cubes, expected {len(set(polcat['source_id']))}" ) if len(cubes) < len(set(polcat["source_id"])): - log.critical("Some cubes are missing on disk!") + logger.critical("Some cubes are missing on disk!") raise else: - log.warning("Need to exclude some cubes") + logger.warning("Need to exclude some cubes") source_ids = [] for i, cube in enumerate(cubes): basename = os.path.basename(cube) @@ -572,13 +573,13 @@ def 
main( source_ids.append(source_id) in_idx = np.isin(source_ids, polcat["source_id"]) cubes = list(np.array(cubes)[in_idx]) - log.warning( + logger.warning( f"I had to exclude {np.sum(~in_idx)} sources that were not in the catalogue" ) # Write missing source IDs to disk rem_ids = list(set(np.array(source_ids)[~in_idx])) outf = os.path.join(casda_dir, "excluded_sources.txt") - log.info(f"Writing excluded source IDs to {outf}") + logger.info(f"Writing excluded source IDs to {outf}") with open(outf, "w") as f: for rid in rem_ids: f.write(f"{rid}\n") @@ -611,7 +612,7 @@ def my_sorter(x, lookup=lookup, pbar=pbar): spectra_outputs = [] if do_convert_spectra: - log.info("Converting spectra") + logger.info("Converting spectra") spec_dir = os.path.join(casda_dir, "spectra") try_mkdir(spec_dir) spectra = find_spectra(data_dir=data_dir) @@ -672,7 +673,7 @@ def my_sorter(x, lookup=lookup, pbar=pbar): plot_outputs = [] if do_convert_plots: - log.info("Converting plots") + logger.info("Converting plots") plots_dir = os.path.join(casda_dir, "plots") try_mkdir(plots_dir) spec_dir = os.path.join(casda_dir, "spectra") @@ -717,7 +718,7 @@ def my_sorter(x, lookup=lookup, pbar=pbar): for name, outputs in zip( ("cubes", "spectra", "plots"), (cube_outputs, spectra_outputs, plot_outputs) ): - log.info(f"Starting work on {len(outputs)} {name}") + logger.info(f"Starting work on {len(outputs)} {name}") futures = chunk_dask( outputs=outputs, @@ -732,7 +733,7 @@ def my_sorter(x, lookup=lookup, pbar=pbar): spectrum_tables = client.gather(client.compute(futures)) # Add all spectrum_tables to a tar ball tarball = os.path.join(casda_dir, f"spice_racs_dr1_polspec_{prep_type}.tar") - log.info(f"Adding spectra to tarball {tarball}") + logger.info(f"Adding spectra to tarball {tarball}") with tarfile.open(tarball, "w") as tar: for spectrum_table in tqdm( spectrum_tables, "Adding spectra to tarball" @@ -742,7 +743,7 @@ def my_sorter(x, lookup=lookup, pbar=pbar): if do_convert_spectra: os.remove(fname_polcat_hash) - log.info("Done") + logger.info("Done") def cli(): @@ -816,22 +817,9 @@ def cli(): ) args = parser.parse_args() if args.verbose: - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) + logger.setLevel(logging.INFO) elif args.debug: - log.basicConfig( - level=log.DEBUG, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - else: - log.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) + logger.setLevel(logging.DEBUG) if args.mpi: initialize( @@ -850,7 +838,7 @@ def cli(): with Client( cluster, ) as client: - log.debug(f"{client=}") + logger.debug(f"{client=}") main( polcatf=args.polcat, client=client, diff --git a/scripts/compare_leakage.py b/scripts/compare_leakage.py index f736f623..08b9ec1d 100755 --- a/scripts/compare_leakage.py +++ b/scripts/compare_leakage.py @@ -33,6 +33,7 @@ from IPython.core.pylabtools import figsize from spiceracs.linmos import gen_seps +from spiceracs.logger import logger, logging from spiceracs.utils import ( chunk_dask, coord_to_string, @@ -44,7 +45,6 @@ def make_plot(data, comp, imfile): - fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(10, 10)) fig.suptitle(f"{comp['Gaussian_ID']} leakage") for i, s in enumerate(["q", "u"]): @@ -93,7 +93,7 @@ def interpolate(field, comp, beams, cutdir, septab, holofile, 
verbose=True): os.path.join(cutdir, f"{comp['Source_ID']}*beam{bm:02d}.conv.fits") )[0] except: - print(f"No image file for source {comp['Source_ID']} beam {bm}") + logger.critical(f"No image file for source {comp['Source_ID']} beam {bm}") return freq = getfreq(imfile) @@ -135,7 +135,7 @@ def interpolate(field, comp, beams, cutdir, septab, holofile, verbose=True): # plotdir = os.path.join(os.path.join(cutdir, 'plots'), os.path.basename(outname)) # copyfile(outname, plotdir) except Exception as e: - print(f"No plot made : {e}") + logger.warning(f"No plot made : {e}") return @@ -216,7 +216,7 @@ def main( verbose=verbose, ) - print("Comparing leakge done!") + logger.info("Comparing leakge done!") def cli(): @@ -298,6 +298,9 @@ def cli(): ) client = Client(cluster) + if args.verbose: + logger.setLevel(logging.INFO) + main( field=args.field, datadir=args.datadir, diff --git a/scripts/compute_leakage.py b/scripts/compute_leakage.py index 2dd7311a..03af4323 100755 --- a/scripts/compute_leakage.py +++ b/scripts/compute_leakage.py @@ -11,6 +11,7 @@ from astropy.wcs import WCS from tqdm.auto import tqdm, trange +from spiceracs.logger import logger, logging from spiceracs.utils import MyEncoder, get_db, getdata @@ -29,7 +30,6 @@ def makesurf(start, stop, field, datadir, save_plots=True, data=None): ras, decs, freqs, stokeis, stokeqs, stokeus = [], [], [], [], [], [] specs = [] for i, comp in enumerate(tqdm(components)): - iname = comp["Source_ID"] cname = comp["Gaussian_ID"] spectra = f"{datadir}/cutouts/{iname}/{cname}.dat" @@ -38,7 +38,7 @@ def makesurf(start, stop, field, datadir, save_plots=True, data=None): freq, iarr, qarr, uarr, rmsi, rmsq, rmsu = np.loadtxt(spectra).T specs.append([freq, iarr, qarr, uarr, rmsi, rmsq, rmsu]) except Exception as e: - print(f"Could not find '{spectra}': {e}") + logger.warning(f"Could not find '{spectra}': {e}") continue else: try: @@ -58,7 +58,7 @@ def makesurf(start, stop, field, datadir, save_plots=True, data=None): stokeqs = np.array(stokeqs) stokeus = np.array(stokeus) freqs = np.nanmean(np.array(freqs)) - print("freq is ", freqs) + logger.debug("freq is ", freqs) coords = SkyCoord(ras * units.deg, decs * units.deg) wcs = WCS( f"/group/askap/athomson/projects/RACS/CI0_mosaic_1.0/RACS_test4_1.05_{field}.fits" @@ -109,14 +109,14 @@ def trim_mean(x): # Positions of grid points to derive leakage estimates at xnew = np.arange(np.min(x), np.max(x) + grid_point_sep_deg, grid_point_sep_deg) ynew = np.arange(np.min(y), np.max(y) + grid_point_sep_deg, grid_point_sep_deg) - print(len(xnew), len(ynew)) + logger.debug(len(xnew), len(ynew)) xxnew, yynew = np.meshgrid(xnew, ynew) pos_estimator_grid = np.array([[a, b] for a in xnew for b in ynew]) # Calculate pair-wise distances between the two sets of coordinate pairs - print("\nDeriving pair-wise distance matrix...") + logger.info("\nDeriving pair-wise distance matrix...") pair_dist = distance_matrix(pos_estimator_grid, pos_measurements) - print("Done.\n") + logger.info("Done.\n") # Collect leakage values nearby each grid point q_estimates = [] @@ -124,12 +124,11 @@ def trim_mean(x): p_estimates = [] num_points_in_aperture_list = [] # Init collectors - print("\nDeriving robust leakage estimates for interpolation grid...") + logger.info("\nDeriving robust leakage estimates for interpolation grid...") for row_idx, row in enumerate(tqdm(pair_dist)): - # Guide to where we're at # if row_idx%100==0: - # print('Processing row %d of %d'%(row_idx,len(pair_dist))) + # logger.info('Processing row %d of %d'%(row_idx,len(pair_dist))) 
# idxs of poitns within d degs idxs_of_points_in_aperture = np.argwhere(row < d) @@ -159,7 +158,7 @@ def trim_mean(x): u_estimates_arr = np.array(u_estimates) p_estimates_arr = np.array(p_estimates) - print( + logger.info( "\nThe mean number of points in each aperture of %.2f degs was %d\n" % (d, np.nanmean(num_points_in_aperture_list)) ) diff --git a/scripts/copy_cutouts.py b/scripts/copy_cutouts.py index 24a92e81..bea19d4b 100755 --- a/scripts/copy_cutouts.py +++ b/scripts/copy_cutouts.py @@ -10,8 +10,11 @@ import spica from astropy.table import Table +from spiceracs.logger import logger, logging from spiceracs.utils import try_mkdir +logger.setLevel(logging.INFO) + racs_area = os.path.abspath("/askapbuffer/payne/mcc381/RACS") # spice_area = os.path.abspath('/group/askap/athomson/projects/spiceracs/spica') spice_area = os.path.abspath("/scratch/ja3/athomson/spica") @@ -29,7 +32,7 @@ def main(field, dry_run=False, ncores=10): if not test_cut: raise FileNotFoundError(cut_dir) else: - print(f"Copying '{cut_dir}'") + logger.info(f"Copying '{cut_dir}'") store_dir = os.path.join( group_area, f"{row['CAL SBID']}", f"RACS_test4_1.05_{field}", "cutouts" @@ -39,7 +42,7 @@ def main(field, dry_run=False, ncores=10): if not test_cut: raise FileNotFoundError(test_store) else: - print(f"Storing in '{store_dir}'") + logger.info(f"Storing in '{store_dir}'") if not dry_run: copy_data.prsync(f"{cut_dir}/*", store_dir, ncores=ncores) diff --git a/scripts/copy_cutouts_askap.py b/scripts/copy_cutouts_askap.py index cff1f49e..24405540 100755 --- a/scripts/copy_cutouts_askap.py +++ b/scripts/copy_cutouts_askap.py @@ -10,8 +10,11 @@ import spica from astropy.table import Table +from spiceracs.logger import logger, logging from spiceracs.utils import try_mkdir +logger.setLevel(logging.INFO) + # racs_area = os.path.abspath('/askapbuffer/processing/len067/spiceracs') # spice_area = os.path.abspath('/group/askap/athomson/projects/spiceracs/spica') spice_area = os.path.abspath("/askapbuffer/processing/len067/spiceracs") @@ -30,7 +33,7 @@ def main(field, dry_run=False, ncores=10): if not test_cut: raise FileNotFoundError(cut_dir) else: - print(f"Copying '{cut_dir}'") + logger.info(f"Copying '{cut_dir}'") field_dir = os.path.join( group_area, f"{row['CAL SBID']}", f"RACS_test4_1.05_{field}" @@ -48,7 +51,7 @@ def main(field, dry_run=False, ncores=10): if not test_store: raise FileNotFoundError(store_dir) else: - print(f"Storing in '{store_dir}'") + logger.info(f"Storing in '{store_dir}'") if not dry_run: copy_data.prsync(f"{cut_dir}/*", store_dir, ncores=ncores) diff --git a/scripts/copy_data.py b/scripts/copy_data.py index bba9c3fd..fc22bf3c 100755 --- a/scripts/copy_data.py +++ b/scripts/copy_data.py @@ -6,6 +6,7 @@ from astropy.table import Table +from spiceracs.logger import logger from spiceracs.utils import try_mkdir @@ -85,7 +86,7 @@ def main( copyfile(abspath, newpath) except SameFileError: pass - print(os.path.basename(newpath)) + logger.debug(os.path.basename(newpath)) if clean: if yes: diff --git a/scripts/find_row.py b/scripts/find_row.py index c265063d..601d0ab7 100755 --- a/scripts/find_row.py +++ b/scripts/find_row.py @@ -4,6 +4,10 @@ from astropy.table import Table +from spiceracs.logger import logger, logging + +logger.setLevel(logging.INFO) + def main(name: str, sbid: int): scriptdir = os.path.dirname(os.path.realpath(__file__)) @@ -13,7 +17,7 @@ def main(name: str, sbid: int): tab.add_index("CAL_SBID") row = tab.loc["FIELD_NAME", f"RACS_{name}"].loc["CAL_SBID", sbid]["INDEX"] - print(f"Row in RACS 
database is {row}") + logger.info(f"Row in RACS database is {row}") def cli(): diff --git a/scripts/find_sbid.py b/scripts/find_sbid.py index 6ec695d1..102f501e 100755 --- a/scripts/find_sbid.py +++ b/scripts/find_sbid.py @@ -4,6 +4,10 @@ from astropy.table import Row, Table +from spiceracs.logger import logger, logging + +logger.setLevel(logging.INFO) + sorted_sbids = [ 8570, 8574, @@ -95,18 +99,18 @@ def main(name: str, cal=False, science=False, weight=False): sub_tab = Table(sel_tab.loc["FIELD_NAME", f"RACS_{name}"]) space = " " if cal: - print(int(sub_tab["CAL_SBID"])) + logger.info(int(sub_tab["CAL_SBID"])) if science: - print(int(sub_tab["SBID"])) + logger.info(int(sub_tab["SBID"])) if weight: sbid = int(sub_tab["SBID"]) - print(int(sorted_weights[sorted_sbids.index(sbid)])) + logger.info(int(sorted_weights[sorted_sbids.index(sbid)])) if not cal and not science and not weight: - print(f"DB info for RACS_{name}:\n") + logger.info(f"DB info for RACS_{name}:\n") for i, row in enumerate(sub_tab): - print(f"{space}CAL SBID {i+1}: {row['CAL_SBID']}") - print(f"{space}Science SBID {i+1}: {row['SBID']}\n") + logger.info(f"{space}CAL SBID {i+1}: {row['CAL_SBID']}") + logger.info(f"{space}Science SBID {i+1}: {row['SBID']}\n") def cli(): diff --git a/scripts/fix_dr1_cat.py b/scripts/fix_dr1_cat.py index 50d0d2d7..43ea0096 100755 --- a/scripts/fix_dr1_cat.py +++ b/scripts/fix_dr1_cat.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """Post process DR1 catalog""" -import logging as log +import logging import os import pickle @@ -10,10 +10,12 @@ from astropy.coordinates import SkyCoord from astropy.table import Column, Table from astropy.time import Time +from astropy.units import cds from IPython import embed from rmtable import RMTable from spica import SPICA +from spiceracs.logger import logger from spiceracs.makecat import ( compute_local_rm_flag, get_fit_func, @@ -34,18 +36,24 @@ def fix_fields(tab: Table) -> Table: # Compare the fields we have to those we want fields_in_cat = list(set(tab["tile_id"])) fields_in_spica = [f"RACS_{name}" for name in SPICA] - log.debug(f"Fields in catalogue: {fields_in_cat}") - log.debug(f"Fields in spica: {fields_in_spica}") + logger.debug(f"Fields in catalogue: {fields_in_cat}") + logger.debug(f"Fields in spica: {fields_in_spica}") fields_not_in_spica = [f for f in fields_in_cat if f not in fields_in_spica] spica_field = field.loc[fields_in_spica] spica_field_coords = SkyCoord( spica_field["RA_DEG"], spica_field["DEC_DEG"], unit=(u.deg, u.deg), frame="icrs" ) start_times = Time(spica_field["SCAN_START"] * u.second, format="mjd") - spica_field["start_time"] = start_times + spica_field.add_column( + Column( + start_times.to_value("mjd"), + name="start_time", + unit=cds.MJD, + ), + ) # These are the sources to update sources_to_fix = tab.loc[fields_not_in_spica] - log.info(f"Found {len(sources_to_fix)} sources to fix") + logger.info(f"Found {len(sources_to_fix)} sources to fix") source_coords = SkyCoord(sources_to_fix["ra"], sources_to_fix["dec"]) @@ -67,7 +75,13 @@ def fix_fields(tab: Table) -> Table: all_fields = new_tab["tile_id"].value all_fields[idx] = closest_fields - new_tab["tile_id"] = all_fields + new_tab.replace_column( + "tile_id", + Column( + all_fields, + name="tile_id", + ), + ) all_seps = ( new_tab["separation_tile_centre"].value * new_tab["separation_tile_centre"].unit @@ -77,50 +91,99 @@ def fix_fields(tab: Table) -> Table: all_sbids = new_tab["sbid"].value all_sbids[idx] = spica_field["SBID"][min_idx].value - all_start_times = 
new_tab["start_time"].value - all_start_times[idx] = spica_field["start_time"][min_idx].value + all_start_times = new_tab["start_time"] + all_start_times[idx] = spica_field["start_time"][min_idx] # Update the columns - new_tab["separation_tile_centre"] = Column( - data=all_seps, + new_tab.replace_column( + "separation_tile_centre", + Column( + data=all_seps, + name="separation_tile_centre", + unit=all_seps.unit, + ), ) - new_tab["beamdist"] = Column( - data=all_seps, + new_tab.replace_column( + "beamdist", + Column( + data=all_seps, + name="beamdist", + unit=all_seps.unit, + ), ) - new_tab["sbid"] = Column( - data=all_sbids, + new_tab.replace_column( + "sbid", + Column( + data=all_sbids, + name="sbid", + ), ) - new_tab["start_time"] = Column( - data=all_start_times, + new_tab.replace_column( + "start_time", + Column( + data=all_start_times, + name="start_time", + unit=all_start_times.unit, + ), ) # Fix the units - Why does VOTable do this?? Thanks I hate it dumb_units = { "Jy.beam-1": u.Jy / u.beam, "mJy.beam-1": u.mJy / u.beam, + "day": u.d, } for col in new_tab.colnames: if str(new_tab[col].unit) in dumb_units.keys(): - new_tab[col].unit = dumb_units[str(new_tab[col].unit)] + new_unit = dumb_units[str(new_tab[col].unit)] + logger.debug(f"Fixing {col} unit from {new_tab[col].unit} to {new_unit}") + new_tab[col].unit = new_unit + new_tab.units[col] = new_unit + + # Convert all mJy to Jy + for col in new_tab.colnames: + if new_tab[col].unit == u.mJy: + logger.debug(f"Converting {col} unit from {new_tab[col].unit} to {u.Jy}") + new_tab[col] = new_tab[col].to(u.Jy) + new_tab.units[col] = u.Jy + if new_tab[col].unit == u.mJy / u.beam: + logger.debug( + f"Converting {col} unit from {new_tab[col].unit} to {u.Jy / u.beam}" + ) + new_tab[col] = new_tab[col].to(u.Jy / u.beam) + new_tab.units[col] = u.Jy / u.beam return new_tab def main(cat: str): - log.debug(f"Reading {cat}") + logger.info(f"Reading {cat}") tab = RMTable.read(cat) - log.debug(f"Fixing {cat}") + logger.info(f"Fixing {cat}") + fix_tab = fix_fields(tab) fit, fig = get_fit_func(fix_tab, do_plot=True, nbins=16, degree=4) fig.savefig("leakage_fit_dr1_fix.pdf") leakage_flag = is_leakage( fix_tab["fracpol"].value, fix_tab["beamdist"].to(u.deg).value, fit ) - fix_tab["leakage_flag"] = leakage_flag + fix_tab.replace_column( + "leakage_flag", + Column( + leakage_flag, + name="leakage_flag", + ), + ) leakage = fit(fix_tab["separation_tile_centre"].to(u.deg).value) - fix_tab["leakage"] = leakage + fix_tab.replace_column( + "leakage", + Column( + leakage, + name="leakage", + ), + ) goodI = ~fix_tab["stokesI_fit_flag"] & ~fix_tab["channel_flag"] goodL = goodI & ~fix_tab["leakage_flag"] & (fix_tab["snr_polint"] > 5) @@ -134,18 +197,15 @@ def main(cat: str): outfit = cat.replace(ext, f".corrected.leakage.pkl") with open(outfit, "wb") as f: pickle.dump(fit, f) - log.info(f"Wrote leakage fit to {outfit}") + logger.info(f"Wrote leakage fit to {outfit}") - # outplot = cat.replace(ext, f'.corrected.leakage.pdf') - # log.info(f"Writing leakage plot to {outplot}") - # fig.savefig(outplot, dpi=300, bbox_inches='tight') - log.info(f"Writing corrected catalogue to {outfile}") + logger.info(f"Writing corrected catalogue to {outfile}") if ext == ".xml" or ext == ".vot": write_votable(fix_flag_tab, outfile) else: tab.write(outfile, overwrite=True) - log.info(f"{outfile} written to disk") - log.info("Done!") + logger.info(f"{outfile} written to disk") + logger.info("Done!") def cli(): @@ -156,21 +216,10 @@ def cli(): parser.add_argument("--debug", 
action="store_true", help="Print debug messages") args = parser.parse_args() - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) + logger.setLevel(logging.INFO) if args.debug: - log.basicConfig( - level=log.DEBUG, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) - + logger.setLevel(logging.DEBUG) main(cat=args.catalogue) diff --git a/scripts/make_links.py b/scripts/make_links.py index 68254d5e..4f7a6985 100755 --- a/scripts/make_links.py +++ b/scripts/make_links.py @@ -5,6 +5,10 @@ from glob import glob from shlex import split +from spiceracs.logger import logger, logging + +logger.setLevel(logging.INFO) + def main(indir, outdir): images = glob(f"{os.path.abspath(indir)}/image.restored.*.contcube.*.fits") @@ -13,13 +17,13 @@ def main(indir, outdir): name = os.path.basename(f) link = name.replace(".fits", ".conv.fits") cmd = f"ln -s {f} {os.path.abspath(outdir)}/{link}" - print(cmd) + logger.info(cmd) subprocess.run(split(cmd)) for f in weights: name = os.path.basename(f) cmd = f"ln -s {f} {os.path.abspath(outdir)}/{name}" - print(cmd) + logger.info(cmd) subprocess.run(split(cmd)) diff --git a/scripts/spica.py b/scripts/spica.py index f5726b74..8cf37171 100755 --- a/scripts/spica.py +++ b/scripts/spica.py @@ -9,8 +9,11 @@ import pkg_resources from astropy.table import Table +from spiceracs.logger import logger, logging from spiceracs.utils import try_mkdir +logger.setLevel(logging.INFO) + racs_area = os.path.abspath("/askapbuffer/payne/mcc381/RACS") # spice_area = os.path.abspath('/group/askap/athomson/projects/spiceracs/spica') spice_area = os.path.abspath("/scratch/ja3/athomson/spica") @@ -51,11 +54,8 @@ def mslist(cal_sb, name): - # os.system('module unload askapsoft') - # os.system('module load askapsoft') try: ms = glob(f"{racs_area}/{cal_sb}/RACS_test4_1.05_{name}/*beam00_*.ms")[0] - # print('ms',ms) except: raise Exception( f"Can't find '{racs_area}/{cal_sb}/RACS_test4_1.05_{name}/*beam00_*.ms'" @@ -65,15 +65,14 @@ def mslist(cal_sb, name): shlex.split(f"mslist --full {ms}"), capture_output=True, check=False ) if mslist_out.returncode > 0: - print(mslist_out.stderr.decode("utf-8")) - print(mslist_out.stdout.decode("utf-8")) + logger.error(mslist_out.stderr.decode("utf-8")) + logger.error(mslist_out.stdout.decode("utf-8")) mslist_out.check_returncode() date_out = sb.run( shlex.split("date +%Y-%m-%d-%H%M%S"), capture_output=True, check=True ) out = mslist_out.stderr.decode() + f"METADATA_IS_GOOD {date_out.stdout.decode()}" - # print(out) return out @@ -120,25 +119,22 @@ def main(copy=False, force=False, cal=False, mslist_dir=None, cube_image=False): spica_tab.sort("SBID") spica_tab.pprint_all() if cal: - print("The following row indcies are ready to image:") + logger.info("The following row indcies are ready to image:") sub_tab = spica_tab[spica_tab["Leakage cal"]] indxs = [] for row in sub_tab: indxs.append(row["Row index"]) indxs = np.array(indxs) - print(" ".join(indxs.astype(str))) + logger.info(" ".join(indxs.astype(str))) if mslist_dir is not None: mslist_dir = os.path.abspath(mslist_dir) - # print('mslist_dir',mslist_dir) for row in spica_tab: try: out = mslist( name=row["Field name"].replace("RACS_", ""), cal_sb=row["CAL SBID"] ) - # print('out',out) sbdir = f"{mslist_dir}/{row['SBID']}" - # print('sbdir',sbdir) try_mkdir(sbdir, verbose=False) outdir = 
f"{sbdir}/metadata" try_mkdir(outdir, verbose=False) @@ -146,14 +142,16 @@ def main(copy=False, force=False, cal=False, mslist_dir=None, cube_image=False): with open(outfile, "w") as f: f.write(out) except Exception as e: - print(e) + logger.error(e) continue if copy: for row in spica_tab: if row["Leakage cal"]: if row["Cube imaging"]: - print(f"Cube imaging done for {row['Field name']}. Skipping...") + logger.info( + f"Cube imaging done for {row['Field name']}. Skipping..." + ) continue else: copy_data.main( @@ -171,7 +169,7 @@ def main(copy=False, force=False, cal=False, mslist_dir=None, cube_image=False): for row in spica_tab: cmd = f"start_pipeline.py -e 0 -p /group/askap/athomson/projects/spiceracs/spica/racs_pipeline_cube.parset -o -m /group/askap/athomson/projects/spiceracs/spica/modules.txt -t /group/askap/athomson/projects/spiceracs/MSlists/{row['SBID']}/metadata/ -i {row['SBID']} -c {row['CAL SBID']} -f {row['Field name'].replace('RACS_', 'RACS_test4_1.05_')} -a ja3 -s" cmds.append(cmd) - print(f"Written imaging commands to '{cube_image}'") + logger.info(f"Written imaging commands to '{cube_image}'") with open(cube_image, "w") as f: f.write("\n".join(cmds)) return spica_tab diff --git a/scripts/tar_cubelets.py b/scripts/tar_cubelets.py index 731711be..3f45bcb1 100755 --- a/scripts/tar_cubelets.py +++ b/scripts/tar_cubelets.py @@ -9,6 +9,8 @@ from dask.distributed import Client from tqdm.auto import tqdm +from spiceracs.logger import logger + @delayed def tar_cubelets(out_dir: str, casda_dir: str, prefix: str) -> None: @@ -19,12 +21,12 @@ def tar_cubelets(out_dir: str, casda_dir: str, prefix: str) -> None: casda_dir (str): CASDA directory containing cubelets/ prefix (str): Prefix of cubelets to tar """ - print(f"Tarring {prefix}...") + logger.info(f"Tarring {prefix}...") with tarfile.open(os.path.join(out_dir, f"{prefix}_cubelets.tar"), "w") as tar: _cube_list = glob(os.path.join(casda_dir, "cubelets", f"{prefix}*.fits")) for cube in _cube_list: tar.add(cube, arcname=os.path.basename(cube)) - print(f"...done {prefix}!") + logger.info(f"...done {prefix}!") def main(casda_dir: str): @@ -43,21 +45,21 @@ def main(casda_dir: str): raise FileNotFoundError(f"Directory {casda_dir} does not contain cubelets/") cube_list = glob(os.path.join(casda_dir, "cubelets", "*.fits")) - print(f"{len(cube_list)} cublets to tar...") + logger.info(f"{len(cube_list)} cublets to tar...") sources = set( [os.path.basename(cube)[:13] for cube in tqdm(cube_list, desc="Sources")] ) - print(f"...into {len(sources)} sources") + logger.info(f"...into {len(sources)} sources") out_dir = os.path.join(casda_dir, "cubelets_tar") os.makedirs(out_dir, exist_ok=True) - print(f"Output directory: {out_dir}") + logger.info(f"Output directory: {out_dir}") outputs = [] for source in tqdm(sources, desc="Tarring"): outputs.append(tar_cubelets(out_dir, casda_dir, source)) dask.compute(*outputs) - print("Done!") + logger.info("Done!") if __name__ == "__main__": @@ -67,6 +69,12 @@ def main(casda_dir: str): parser.add_argument( "casda_dir", help="CASDA directory containing cublets/ to tar", type=str ) + parser.add_argument( + "-v", "--verbose", help="Increase output verbosity", action="store_true" + ) args = parser.parse_args() - # with Client() as client: + + if args.verbose: + logger.setLevel("INFO") + main(args.casda_dir) diff --git a/spiceracs/askap_surveys/racs b/spiceracs/askap_surveys/racs index 8feef9ff..de06d191 160000 --- a/spiceracs/askap_surveys/racs +++ b/spiceracs/askap_surveys/racs @@ -1 +1 @@ -Subproject commit 
8feef9ff5a209a5b9d980c17b7ee9262f86e2ba7 +Subproject commit de06d191172097dbe003665235fe7e41a26c40dd diff --git a/spiceracs/cleanup.py b/spiceracs/cleanup.py index eb5be3c2..467d7d45 100644 --- a/spiceracs/cleanup.py +++ b/spiceracs/cleanup.py @@ -1,14 +1,15 @@ #!/usr/bin/env python3 """DANGER ZONE: Purge directories of un-needed FITS files.""" -import logging as log +import logging import os import time from glob import glob -from typing import List +from typing import List, Union from dask import delayed from dask.distributed import Client, LocalCluster +from spiceracs.logger import logger from spiceracs.utils import chunk_dask @@ -29,7 +30,10 @@ def cleanup(workdir: str, stoke: str) -> None: def main( - datadir: str, client: Client, stokeslist: List[str] = None, verbose=True + datadir: str, + client: Client, + stokeslist: Union[List[str], None] = None, + verbose=True, ) -> None: """Clean up beam images @@ -71,7 +75,7 @@ def main( verbose=verbose, ) - log.info("Cleanup done!") + logger.info("Cleanup done!") def cli(): @@ -119,18 +123,7 @@ def cli(): verbose = args.verbose if verbose: - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) - else: - log.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) + logger.setLevel(logging.INFO) cluster = LocalCluster(n_workers=20) client = Client(cluster) diff --git a/spiceracs/columns_possum.py b/spiceracs/columns_possum.py index b8285beb..0dd10a97 100644 --- a/spiceracs/columns_possum.py +++ b/spiceracs/columns_possum.py @@ -368,4 +368,36 @@ "description": "Second moment complexity flag", "ucd": "meta.code", }, + "is_blended_flag": { + "description": "Component is within beamwidth of another component.", + "ucd": "meta.code", + }, + "blend_ratio": { + "description": "Ratio of total flux of this component to total flux of components that blend with it.", + "ucd": "phot.flux.density;arith.ratio", + }, + "N_blended": { + "description": "Number of components that blend with this component.", + "ucd": "meta.number", + }, + "catalog_name": { + "description": "Name of catalog", + "ucd": "meta.note", + }, + "obs_interval": { + "description": "Interval of observation", + "ucd": "time.interval", + }, + "stokesI_chi2_red": { + "description": "Reduced chi-squared of Stokes I fit", + "ucd": "stat.fit.chi2;phot.flux.density;phys.polarization.stokes.I", + }, + "stokesI_model_order": { + "description": "Order of Stokes I model", + "ucd": "stat.fit;meta.number", + }, + "stokesI_model_coef_err": { + "description": "Error in Stokes I model coefficients", + "ucd": "stat.error;stat.fit.param;phys.polarization.stokes.I", + }, } diff --git a/spiceracs/cutout.py b/spiceracs/cutout.py index 7f3ea10d..1342ffe4 100644 --- a/spiceracs/cutout.py +++ b/spiceracs/cutout.py @@ -3,7 +3,7 @@ import argparse import functools import json -import logging as log +import logging import os import shlex import subprocess @@ -14,7 +14,7 @@ from glob import glob from pprint import pformat from shutil import copyfile -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union import astropy.units as u import dask @@ -38,6 +38,7 @@ from spectral_cube.utils import SpectralCubeWarning from tqdm import tqdm, trange +from spiceracs.logger import logger from spiceracs.utils import ( MyEncoder, chunk_dask, @@ -74,7 +75,7 @@ def cutout( pad=3, verbose=False, 
dryrun=False, -) -> pymongo.UpdateOne: +) -> List[pymongo.UpdateOne]: """Perform a cutout. Args: @@ -111,10 +112,10 @@ def cutout( ".conv.fits", ".txt" ) copyfile(image, outfile) - log.info(f"Written to {outfile}") + logger.info(f"Written to {outfile}") if imtype == "image": - log.info(f"Reading {image}") + logger.info(f"Reading {image}") with warnings.catch_warnings(): warnings.simplefilter("ignore", AstropyWarning) cube = SpectralCube.read(image) @@ -160,7 +161,7 @@ def cutout( overwrite=True, output_verify="fix", ) - log.info(f"Written to {outfile}") + logger.info(f"Written to {outfile}") # Update database myquery = {"Source_ID": src_name} @@ -252,8 +253,8 @@ def get_args( dec_off = Longitude(majs[dec_i_min]) dec_lo = dec_min - dec_off except Exception as e: - log.debug(f"coords are {coords=}") - log.debug(f"comps are {comps=}") + logger.debug(f"coords are {coords=}") + logger.debug(f"comps are {comps=}") raise e args = [] @@ -319,13 +320,12 @@ def cutout_islands( directory: str, host: str, client: Client, - username: str = None, - password: str = None, - verbose=True, - pad=3, - stokeslist: List[str] = None, - verbose_worker=False, - dryrun=True, + username: Union[str, None] = None, + password: Union[str, None] = None, + pad: float = 3, + stokeslist: Union[List[str], None] = None, + verbose_worker: bool = False, + dryrun: bool = True, ) -> None: """Perform cutouts of RACS islands in parallel. @@ -344,7 +344,7 @@ def cutout_islands( """ if stokeslist is None: stokeslist = ["I", "Q", "U", "V"] - log.debug(f"Client is {client}") + logger.debug(f"Client is {client}") directory = os.path.abspath(directory) outdir = os.path.join(directory, "cutouts") @@ -377,7 +377,7 @@ def cutout_islands( try_mkdir(outdir) args = [] - for (island_id, island, comp, beam) in zip(island_ids, islands, comps, beams): + for island_id, island, comp, beam in zip(island_ids, islands, comps, beams): if len(comp) == 0: warnings.warn(f"Skipping island {island_id} -- no components found") continue @@ -397,9 +397,7 @@ def cutout_islands( flat_args = unpack(args) flat_args = client.compute(flat_args) - tqdm_dask( - flat_args, desc="Getting args", disable=(not verbose), total=len(islands) + 1 - ) + tqdm_dask(flat_args, desc="Getting args", total=len(islands) + 1) flat_args = flat_args.result() cuts = [] for arg in flat_args: @@ -425,16 +423,15 @@ def cutout_islands( client=client, task_name="cutouts", progress_text="Cutting out", - verbose=verbose, ) if not dryrun: _updates = [f.compute() for f in futures] updates = [val for sublist in _updates for val in sublist] - log.info("Updating database...") + logger.info("Updating database...") db_res = beams_col.bulk_write(updates, ordered=False) - log.info(pformat(db_res.bulk_api_result)) + logger.info(pformat(db_res.bulk_api_result)) - log.info("Cutouts Done!") + logger.info("Cutouts Done!") def main(args: argparse.Namespace, verbose=True) -> None: @@ -448,7 +445,7 @@ def main(args: argparse.Namespace, verbose=True) -> None: n_workers=12, threads_per_worker=1, dashboard_address=":9898" ) client = Client(cluster) - log.info(client) + logger.info(client) cutout_islands( field=args.field, directory=args.datadir, @@ -456,14 +453,13 @@ def main(args: argparse.Namespace, verbose=True) -> None: client=client, username=args.username, password=args.password, - verbose=verbose, pad=args.pad, stokeslist=args.stokeslist, verbose_worker=args.verbose_worker, dryrun=args.dryrun, ) - log.info("Done!") + logger.info("Done!") def cli() -> None: @@ -558,21 +554,12 @@ def cli() -> None: verbose = 
args.verbose if verbose: - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) - else: - log.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) + logger.setLevel(logging.INFO) test_db( - host=args.host, username=args.username, password=args.password, verbose=verbose + host=args.host, + username=args.username, + password=args.password, ) main(args, verbose=verbose) diff --git a/spiceracs/frion.py b/spiceracs/frion.py index 2bd71a60..8312aebd 100644 --- a/spiceracs/frion.py +++ b/spiceracs/frion.py @@ -1,12 +1,12 @@ #!/usr/bin/env python3 """Correct for the ionosphere in parallel""" -import logging as log +import logging import os import time from glob import glob from pprint import pformat from shutil import copyfile -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union import astropy.units as u import dask @@ -17,6 +17,7 @@ from dask.distributed import Client, LocalCluster, progress, wait from FRion import correct, predict +from spiceracs.logger import logger from spiceracs.utils import get_db, get_field_db, getfreq, test_db, tqdm_dask, try_mkdir @@ -84,7 +85,7 @@ def predict_worker( plotdir (str): Plot directory Returns: - str: Prediction file name + Tuple[str, pymongo.UpdateOne]: FRion prediction file and pymongo update query """ ifile = os.path.join(cutdir, beam["beams"][field]["i_file"]) i_dir = os.path.dirname(ifile) @@ -140,8 +141,8 @@ def main( outdir: str, host: str, client: Client, - username: str = None, - password: str = None, + username: Union[str, None] = None, + password: Union[str, None] = None, database=False, verbose=True, ): @@ -186,7 +187,7 @@ def main( field_datas = list(field_col.find({"FIELD_NAME": f"RACS_{field}"})) sbids = [f["CAL_SBID"] for f in field_datas] max_idx = np.argmax(sbids) - log.info(f"Using CAL_SBID {sbids[max_idx]}") + logger.info(f"Using CAL_SBID {sbids[max_idx]}") field_data = field_datas[max_idx] else: field_data = field_col.find_one({"FIELD_NAME": f"RACS_{field}"}) @@ -237,16 +238,16 @@ def main( futures, desc="Running FRion", disable=(not verbose), total=len(islands) * 3 ) if database: - log.info("Updating beams database...") + logger.info("Updating beams database...") updates = [f.compute() for f in futures] db_res = beams_col.bulk_write(updates, ordered=False) - log.info(pformat(db_res.bulk_api_result)) + logger.info(pformat(db_res.bulk_api_result)) - log.info("Updating island database...") + logger.info("Updating island database...") updates_arrays_cmp = [f.compute() for f in future_arrays] db_res = island_col.bulk_write(updates_arrays_cmp, ordered=False) - log.info(pformat(db_res.bulk_api_result)) + logger.info(pformat(db_res.bulk_api_result)) def cli(): @@ -325,25 +326,13 @@ def cli(): verbose = args.verbose if verbose: - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) - else: - log.basicConfig( - level=log.WARNING, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) + logger.setLevel(logging.INFO) cluster = LocalCluster( n_workers=10, processes=True, threads_per_worker=1, local_directory="/dev/shm" ) client = Client(cluster) - log.info(client) + logger.info(client) test_db( 
host=args.host, username=args.username, password=args.password, verbose=verbose diff --git a/spiceracs/init_database.py b/spiceracs/init_database.py index c6337cd8..24553fff 100644 --- a/spiceracs/init_database.py +++ b/spiceracs/init_database.py @@ -2,13 +2,13 @@ """Create the SPICE-RACS database""" import functools import json -import logging as log +import logging import os import sys import time from functools import partial from glob import glob -from typing import List, Tuple +from typing import Dict, List, Tuple, Union import matplotlib.pyplot as plt import numpy as np @@ -21,19 +21,21 @@ from astropy.table import Table, vstack from astropy.wcs import WCS from IPython import embed +from pymongo.results import InsertManyResult from tqdm import tqdm, trange +from spiceracs.logger import logger from spiceracs.utils import MyEncoder, get_db, get_field_db, getdata, test_db, yes_or_no -def source2beams(ra: float, dec: float, database: Table, max_sep=1) -> Table: +def source2beams(ra: float, dec: float, database: Table, max_sep: float = 1) -> Table: """Find RACS beams that contain a given source position Args: ra (float): RA of source in degrees. dec (float): DEC of source in degrees. database (dict): RACS database table. - max_sep (int, optional): Maximum seperation of source to beam centre in degrees. Defaults to 1. + max_sep (float, optional): Maximum seperation of source to beam centre in degrees. Defaults to 1. Returns: Table: Subset of RACS databsae table containing beams that contain the source. @@ -66,20 +68,19 @@ def ndix_unique(x: np.ndarray) -> Tuple[np.ndarray, List[np.ndarray]]: def cat2beams( - mastercat: Table, database: Table, max_sep=1, verbose=True + mastercat: Table, database: Table, max_sep: float = 1 ) -> Tuple[np.ndarray, np.ndarray, Angle]: """Find the separations between sources in the master catalogue and the RACS beams Args: mastercat (Table): Master catalogue table. database (Table): RACS database table. - max_sep (int, optional): Maxium source separation in degrees. Defaults to 1. - verbose (bool, optional): Verbose output. Defaults to True. + max_sep (float, optional): Maxium source separation in degrees. Defaults to 1. Returns: Tuple[np.ndarray, np.ndarray, Angle]: Output of astropy.coordinates.search_around_sky """ - log.info("Getting separations from beam centres...") + logger.info("Getting separations from beam centres...") c1 = SkyCoord(database["RA_DEG"] * u.deg, database["DEC_DEG"] * u.deg, frame="icrs") m_ra = mastercat["RA"] @@ -98,10 +99,9 @@ def source_database( islandcat: Table, compcat: Table, host: str, - username: str = None, - password: str = None, - verbose=True, -): + username: Union[str, None] = None, + password: Union[str, None] = None, +) -> Tuple[InsertManyResult, InsertManyResult]: """Insert sources into the database Following https://medium.com/analytics-vidhya/how-to-upload-a-pandas-dataframe-to-mongodb-ffa18c0953c1 @@ -112,33 +112,39 @@ def source_database( host (str): MongoDB host IP. username (str, optional): Mongo username. Defaults to None. password (str, optional): Mongo host. Defaults to None. - verbose (bool, optional): Verbose output. Defaults to True. + + Returns: + Tuple[InsertManyResult, InsertManyResult]: Results for the islands and components inserts. 
""" # Read in main catalogues # Use pandas and follow # https://medium.com/analytics-vidhya/how-to-upload-a-pandas-dataframe-to-mongodb-ffa18c0953c1 df_i = islandcat.to_pandas() if type(df_i["Source_ID"][0]) is bytes: - log.info("Decoding strings!") + logger.info("Decoding strings!") str_df = df_i.select_dtypes([object]) str_df = str_df.stack().str.decode("utf-8").unstack() for col in str_df: df_i[col] = str_df[col] source_dict_list = df_i.to_dict("records") - log.info("Loading islands into mongo...") + logger.info("Loading islands into mongo...") beams_col, island_col, comp_col = get_db( host=host, username=username, password=password ) - island_col.delete_many({}) # Delete previous database - island_col.insert_many(source_dict_list) + island_delete_res = island_col.delete_many({}) # Delete previous database + logger.warning( + f"Deleted {island_delete_res.deleted_count} documents from island collection" + ) + island_insert_res = island_col.insert_many(source_dict_list) + count = island_col.count_documents({}) - log.info("Done loading") - log.info(f"Total documents: {count}") + logger.info("Done loading") + logger.info(f"Total documents: {count}") df_c = compcat.to_pandas() if type(df_c["Source_ID"][0]) is bytes: - log.info("Decoding strings!") + logger.info("Decoding strings!") str_df = df_c.select_dtypes([object]) str_df = str_df.stack().str.decode("utf-8").unstack() for col in str_df: @@ -146,38 +152,72 @@ def source_database( source_dict_list = df_c.to_dict("records") - log.info("Loading components into mongo...") - beams_col, island_col, comp_col = get_db( - host=host, username=username, password=password + logger.info("Loading components into mongo...") + comp_delete_res = comp_col.delete_many({}) # Delete previous database + logger.warning( + f"Deleted {comp_delete_res.deleted_count} documents from component collection" ) - comp_col.delete_many({}) # Delete previous database - comp_col.insert_many(source_dict_list) + comp_insert_res = comp_col.insert_many(source_dict_list) count = comp_col.count_documents({}) - log.info("Done loading") - log.info(f"Total documents: {count}") + logger.info("Done loading") + logger.info(f"Total documents: {count}") + + return island_insert_res, comp_insert_res + +def beam_database( + islandcat: Table, + host: str, + username: Union[str, None] = None, + password: Union[str, None] = None, + epoch: int = 0, +) -> InsertManyResult: + """Insert beams into the database + + Args: + islandcat (Table): Island catalogue table. + host (str): MongoDB host IP. + username (str, optional): Mongo username. Defaults to None. + password (str, optional): Mongo host. Defaults to None. + epoch (int, optional): RACS epoch to use. Defaults to 0. -def beam_database(islandcat, host, username=None, password=None, verbose=True): + Returns: + InsertManyResult: Result of the insert. 
+ """ # Get pointing info from RACS database - racs_fields = get_catalogue(verbose=verbose) + racs_fields = get_catalogue( + epoch=epoch, + ) # Get beams - beam_list = get_beams(islandcat, racs_fields, verbose=verbose) - log.info("Loading into mongo...") + beam_list = get_beams(islandcat, racs_fields) + logger.info("Loading into mongo...") json_data = json.loads(json.dumps(beam_list, cls=MyEncoder)) beams_col, island_col, comp_col = get_db( host=host, username=username, password=password ) - beams_col.delete_many({}) # Delete previous databas - beams_col.insert_many(json_data) + delete_res = beams_col.delete_many({}) # Delete previous databas + logger.warning(f"Deleted {delete_res.deleted_count} documents from beam collection") + insert_res = beams_col.insert_many(json_data) count = beams_col.count_documents({}) - log.info("Done loading") - log.info(f"Total documents: {count}") + logger.info("Done loading") + logger.info(f"Total documents: {count}") + return insert_res + + +def get_catalogue(epoch: int = 0) -> Table: + """Get the RACS catalogue for a given epoch + + Args: + epoch (int, optional): Epoch number. Defaults to 0. -def get_catalogue(verbose=True): + Returns: + Table: RACS catalogue table. + + """ survey_dir = pkg_resources.resource_filename("spiceracs", "askap_surveys") - basedir = os.path.join(survey_dir, "racs", "db", "epoch_0") + basedir = os.path.join(survey_dir, "racs", "db", f"epoch_{epoch}") beamfiles = glob(os.path.join(basedir, "beam_inf*")) # Init first field @@ -205,14 +245,24 @@ def get_catalogue(verbose=True): tab.add_column(int(SBID), name="SBID", index=0) racs_fields = vstack([racs_fields, tab]) except TypeError: - log.warning(f"{SBID} failed...") + logger.warning(f"{SBID} failed...") continue return racs_fields -def get_beams(mastercat, database, verbose=True): +def get_beams(mastercat: Table, database: Table) -> List[Dict]: + """Get beams from the master catalogue + + Args: + mastercat (Table): Master catalogue table. + database (Table): RACS database table. + + Returns: + List[Dict]: List of beam dictionaries. + + """ # Get seperations on sky - seps = cat2beams(mastercat, database, max_sep=1, verbose=verbose) + seps = cat2beams(mastercat, database, max_sep=1) vals, ixs = ndix_unique(seps[1]) # Get DR1 fields @@ -225,9 +275,7 @@ def get_beams(mastercat, database, verbose=True): beam_list = [] for i, (val, idx) in enumerate( - tqdm( - zip(vals, ixs), total=len(vals), desc="Getting beams", disable=(not verbose) - ) + tqdm(zip(vals, ixs), total=len(vals), desc="Getting beams") ): beam_dict = {} ra = mastercat[val]["RA"] @@ -253,88 +301,133 @@ def get_beams(mastercat, database, verbose=True): "Source_Name": name, "Source_ID": isl_id, "n_fields": len(beam_dict.keys()), - "n_fields_DR1": sum([val["DR1"] for val in beam_dict.values()]), + "n_fields_DR1": np.sum([val["DR1"] for val in beam_dict.values()]), "beams": beam_dict, } ) return beam_list -def field_database(host, username, password, verbose=True): +def field_database( + host: str, username: Union[str, None], password: Union[str, None], epoch: int = 0 +) -> InsertManyResult: + """Reset and load the field database + + Args: + host (str): Mongo host + username (Union[str, None]): Mongo username + password (Union[str, None]): Mongo password + epoch (int, optional): RACS epoch number. Defaults to 0. + + Returns: + InsertManyResult: Field insert object. 
+ """ survey_dir = pkg_resources.resource_filename("spiceracs", "askap_surveys") - basedir = os.path.join(survey_dir, "racs", "db", "epoch_0") + basedir = os.path.join(survey_dir, "racs", "db", f"epoch_{epoch}") data_file = os.path.join(basedir, "field_data.csv") database = Table.read(data_file) df = database.to_pandas() field_list_dict = df.to_dict("records") - log.info("Loading fields into mongo...") + logger.info("Loading fields into mongo...") field_col = get_field_db(host, username=username, password=password) - field_col.delete_many({}) - field_col.insert_many(field_list_dict) + delete_res = field_col.delete_many({}) + logger.warning(f"Deleted documents: {delete_res.deleted_count}") + insert_res = field_col.insert_many(field_list_dict) count = field_col.count_documents({}) - log.info("Done loading") - log.info(f"Total documents: {count}") - - -def main(args, verbose=True): + logger.info("Done loading") + logger.info(f"Total documents: {count}") + + return insert_res + + +def main( + load: bool = False, + islandcat: Union[str, None] = None, + compcat: Union[str, None] = None, + host: str = "localhost", + username: Union[str, None] = None, + password: Union[str, None] = None, + field: bool = False, + epoch: int = 0, + force: bool = False, +) -> None: """Main script - Arguments: - args -- commandline args + Args: + load (bool, optional): Load the database. Defaults to False. + islandcat (Union[str, None], optional): Island catalogue. Defaults to None. + compcat (Union[str, None], optional): Component catalogue. Defaults to None. + host (str, optional): Mongo host. Defaults to "localhost". + username (Union[str, None], optional): Mongo username. Defaults to None. + password (Union[str, None], optional): Mongo password. Defaults to None. + field (bool, optional): Load the field database. Defaults to False. + epoch (int, optional): RACS epoch to load. Defaults to 0. + force (bool, optional): Force overwrite of database. Defaults to False. + + Raises: + ValueError: If load is True and islandcat or compcat are None. + """ + if force: + logger.critical("This will overwrite the database! 
ALL data will be lost!") + logger.critical("Sleeping for 30 seconds in case you want to cancel...") + time.sleep(30) + logger.critical("Continuing...you have been warned!") - if args.load: + if load: # Get database from master cat - if args.islandcat is None: - log.critical("Island catalogue is required!") + if islandcat is None: + logger.critical("Island catalogue is required!") islandcat = input("Enter catalogue file:") - else: - islandcat = args.islandcat - if args.compcat is None: - log.critical("Component catalogue is required!") + if compcat is None: + logger.critical("Component catalogue is required!") compcat = input("Enter catalogue file:") - else: - compcat = args.compcat # Get the master cat - log.info(f"Reading {islandcat}") + logger.info(f"Reading {islandcat}") island_cat = Table.read(islandcat) - log.info(f"Reading {compcat}") + logger.info(f"Reading {compcat}") comp_cat = Table.read(compcat) - log.critical("This will overwrite the source database!") - check_source = yes_or_no("Are you sure you wish to proceed?") - log.critical("This will overwrite the beams database!") - check_beam = yes_or_no("Are you sure you wish to proceed?") + logger.critical("This will overwrite the source database!") + check_source = ( + yes_or_no("Are you sure you wish to proceed?") if not force else True + ) + logger.critical("This will overwrite the beams database!") + check_beam = ( + yes_or_no("Are you sure you wish to proceed?") if not force else True + ) if check_source: source_database( islandcat=island_cat, compcat=comp_cat, - host=args.host, - username=args.username, - password=args.password, - verbose=verbose, + host=host, + username=username, + password=password, ) if check_beam: beam_database( islandcat=island_cat, - host=args.host, - username=args.username, - password=args.password, - verbose=verbose, + host=host, + username=username, + password=password, + epoch=epoch, ) - if args.field: - log.critical("This will overwrite the field database!") - check_field = yes_or_no("Are you sure you wish to proceed?") + if field: + logger.critical("This will overwrite the field database!") + check_field = ( + yes_or_no("Are you sure you wish to proceed?") if not force else True + ) if check_field: - field_database( - host=args.host, - username=args.username, - password=args.password, - verbose=verbose, + field_res = field_database( + host=host, + username=username, + password=password, ) else: - log.info("Nothing to do!") + logger.info("Nothing to do!") + + logger.info("Done!") def cli(): @@ -380,11 +473,11 @@ def cli(): ) parser.add_argument( - "--username", type=str, default=None, help="Username of mongodb." + "-u", "--username", type=str, default=None, help="Username of mongodb." ) parser.add_argument( - "--password", type=str, default=None, help="Password of mongodb." + "-p", "--password", type=str, default=None, help="Password of mongodb." 
) parser.add_argument( @@ -404,32 +497,37 @@ def cli(): ) parser.add_argument( + "-f", "--field", action="store_true", help="Load field table into database [False].", ) + parser.add_argument( + "-e", + "--epoch", + type=int, + default=0, + help="RACS epoch to load [0].", + ) + args = parser.parse_args() - verbose = args.verbose - if verbose: - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) - else: - log.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) - test_db( - host=args.host, username=args.username, password=args.password, verbose=verbose - ) + if args.verbose: + logger.setLevel(logging.INFO) + + test_db(host=args.host, username=args.username, password=args.password) - main(args, verbose=verbose) + main( + load=args.load, + islandcat=args.islandcat, + compcat=args.compcat, + host=args.host, + username=args.username, + password=args.password, + field=args.field, + epoch=args.epoch, + ) if __name__ == "__main__": diff --git a/spiceracs/linmos.py b/spiceracs/linmos.py index fbbaf260..dfd8e10d 100644 --- a/spiceracs/linmos.py +++ b/spiceracs/linmos.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """Run LINMOS on cutouts in parallel""" import ast -import logging as log +import logging import os import shlex import subprocess @@ -11,7 +11,7 @@ from glob import glob from logging import disable from pprint import pformat -from typing import List, Tuple +from typing import List, Tuple, Union import astropy import astropy.units as u @@ -29,6 +29,7 @@ from spectral_cube.utils import SpectralCubeWarning from spython.main import Client as sclient +from spiceracs.logger import logger from spiceracs.utils import chunk_dask, coord_to_string, get_db, test_db, tqdm_dask warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) @@ -116,7 +117,7 @@ def genparset( stoke: str, datadir: str, septab: Table, - holofile: str = None, + holofile: Union[str, None] = None, ) -> str: """Generate parset for LINMOS @@ -186,7 +187,7 @@ def genparset( """ if holofile is not None: - log.info(f"Using holography file {holofile} -- setting removeleakge to true") + logger.info(f"Using holography file {holofile} -- setting removeleakge to true") parset += f""" linmos.primarybeam = ASKAP_PB @@ -194,7 +195,7 @@ def genparset( linmos.removeleakage = true """ else: - log.warning("No holography file provided - not correcting leakage!") + logger.warning("No holography file provided - not correcting leakage!") with open(parset_file, "w") as f: f.write(parset) @@ -253,7 +254,7 @@ def linmos(parset: str, fieldname: str, image: str, verbose=False) -> pymongo.Up inner = os.path.basename(new_file) new_file = os.path.join(outer, inner) - log.info(f"Cube now in {workdir}/{inner}") + logger.info(f"Cube now in {workdir}/{inner}") query = {"Source_ID": source} newvalues = {"$set": {f"beams.{fieldname}.{stoke.lower()}_file": new_file}} @@ -280,11 +281,11 @@ def main( datadir: str, client: Client, host: str, - holofile: str = None, - username: str = None, - password: str = None, + holofile: Union[str, None] = None, + username: Union[str, None] = None, + password: Union[str, None] = None, yanda="1.3.0", - stokeslist: List[str] = None, + stokeslist: Union[List[str], None] = None, verbose=True, ) -> None: """Main script @@ -322,7 +323,7 @@ def main( beams_col, island_col, comp_col = get_db( host=host, 
username=username, password=password ) - log.debug(f"{beams_col = }") + logger.debug(f"{beams_col = }") # Query the DB query = { "$and": [{f"beams.{field}": {"$exists": True}}, {f"beams.{field}.DR1": True}] @@ -378,11 +379,11 @@ def main( ) updates = [f.compute() for f in futures] - log.info("Updating database...") + logger.info("Updating database...") db_res = beams_col.bulk_write(updates, ordered=False) - log.info(pformat(db_res.bulk_api_result)) + logger.info(pformat(db_res.bulk_api_result)) - log.info("LINMOS Done!") + logger.info("LINMOS Done!") def cli(): diff --git a/spiceracs/logger.py b/spiceracs/logger.py new file mode 100644 index 00000000..a9cd7d25 --- /dev/null +++ b/spiceracs/logger.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +"""Logging module for spiceracs""" + +import logging + +# Create logger +logging.captureWarnings(True) +logger = logging.getLogger("spiceracs") +logger.setLevel(logging.WARNING) + +# Create console handler and set level to debug +ch = logging.StreamHandler() +ch.setLevel(logging.DEBUG) + + +# Create formatter +# formatter = logging.Formatter( +# "SPICE: %(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s" +# ) +class CustomFormatter(logging.Formatter): + grey = "\x1b[38;20m" + blue = "\x1b[34;20m" + green = "\x1b[32;20m" + yellow = "\x1b[33;20m" + red = "\x1b[31;20m" + bold_red = "\x1b[31;1m" + reset = "\x1b[0m" + format_str = "%(asctime)s.%(msecs)03d %(module)s - %(funcName)s: %(message)s" + + FORMATS = { + logging.DEBUG: f"{blue}SPICE-%(levelname)s{reset} {format_str}", + logging.INFO: f"{green}SPICE-%(levelname)s{reset} {format_str}", + logging.WARNING: f"{yellow}SPICE-%(levelname)s{reset} {format_str}", + logging.ERROR: f"{red}SPICE-%(levelname)s{reset} {format_str}", + logging.CRITICAL: f"{bold_red}SPICE-%(levelname)s{reset} {format_str}", + } + + def format(self, record): + log_fmt = self.FORMATS.get(record.levelno) + formatter = logging.Formatter(log_fmt, "%Y-%m-%d %H:%M:%S") + return formatter.format(record) + + +# Add formatter to ch +ch.setFormatter(CustomFormatter()) + +# Add ch to logger +logger.addHandler(ch) diff --git a/spiceracs/makecat.py b/spiceracs/makecat.py index 4306a669..4d5d21b0 100644 --- a/spiceracs/makecat.py +++ b/spiceracs/makecat.py @@ -1,30 +1,198 @@ #!/usr/bin/env python3 """Make a SPICE-RACS catalogue""" -import logging as log +import logging import os import time import warnings +from functools import partial from pprint import pformat -from typing import Callable, Optional, Union +from typing import Callable, Optional, Tuple, TypeVar, Union import astropy.units as u +import dask.dataframe as dd import matplotlib.pyplot as plt import numpy as np import pandas as pd +from astropy.coordinates import SkyCoord from astropy.io import fits from astropy.io import votable as vot from astropy.stats import mad_std, sigma_clip from astropy.table import Column, Table from corner import hist2d +from dask.diagnostics import ProgressBar from IPython import embed from rmtable import RMTable from scipy.stats import lognorm, norm -from tqdm import tqdm, trange +from tqdm import tqdm, tqdm_pandas, trange from vorbin.voronoi_2d_binning import voronoi_2d_binning from spiceracs import columns_possum +from spiceracs.logger import logger from spiceracs.utils import get_db, get_field_db, latexify, test_db +ArrayLike = TypeVar( + "ArrayLike", np.ndarray, pd.Series, pd.DataFrame, SkyCoord, u.Quantity +) +TableLike = TypeVar("TableLike", RMTable, Table) + + +def combinate(data: ArrayLike) -> 
Tuple[ArrayLike, ArrayLike]: + """Return all combinations of data with itself + + Args: + data (ArrayLike): Data to combine. + + Returns: + Tuple[ArrayLike, ArrayLike]: Data_1 matched with Data_2 + """ + ix, iy = np.triu_indices(data.shape[0], k=1) + idx = np.vstack((ix, iy)).T + dx, dy = data[idx].swapaxes(0, 1) + return dx, dy + + +def flag_blended_components(cat: RMTable) -> RMTable: + """Identify blended components in a catalogue and flag them. + + Args: + cat (RMTable): Input catalogue + + Returns: + RMTable: Output catalogue with minor components flagged + """ + + def is_blended_component(sub_df: pd.DataFrame) -> pd.DataFrame: + """Return a boolean series indicating whether a component is the maximum + component in a source. + + Args: + sub_df (pd.DataFrame): DataFrame containing all components for a source + + Returns: + pd.DataFrame: DataFrame with a boolean column indicating whether a component + is blended and a float column indicating the ratio of the total flux. + + """ + # Skip single-component sources + if any(sub_df.N_Gaus == 1): + is_blended = pd.Series( + [False], + index=sub_df.index, + name="is_blended_flag", + dtype=bool, + ) + n_blended = pd.Series( + [0], + index=sub_df.index, + name="N_blended", + dtype=int, + ) + blend_ratio = pd.Series( + [np.nan], + index=sub_df.index, + name="blend_ratio", + dtype=float, + ) + else: + # Look up all separations between components + # We'll store: + # - is_blended: boolean array indicating whether a component + # is blended + # - n_blended: integer array indicating the number of components + # blended into a component + # - blend_ratio: float array indicating the ratio of the flux of a + # component to the total flux of all blended components + coords = SkyCoord(sub_df.ra, sub_df.dec, unit="deg") + beam = sub_df.beam_maj.max() * u.deg + is_blended_arr = np.zeros_like(sub_df.index, dtype=bool) + n_blended_arr = np.zeros_like(sub_df.index, dtype=int) + blend_ratio_arr = np.ones_like(sub_df.index, dtype=float) * np.nan + for i, coord in enumerate(coords): + seps = coord.separation(coords) + sep_flag = (seps < beam) & (seps > 0 * u.deg) + is_blended_arr[i] = np.any(sep_flag) + n_blended_arr[i] = np.sum(sep_flag) + blend_total_flux = ( + sub_df.total_I_flux[sep_flag].sum() + sub_df.total_I_flux[i] + ) + blend_ratio_arr[i] = sub_df.total_I_flux[i] / blend_total_flux + + is_blended = pd.Series( + is_blended_arr, + index=sub_df.index, + name="is_blended_flag", + dtype=bool, + ) + n_blended = pd.Series( + n_blended_arr, + index=sub_df.index, + name="N_blended", + dtype=int, + ) + blend_ratio = pd.Series( + blend_ratio_arr, + index=sub_df.index, + name="blend_ratio", + dtype=float, + ) + df = pd.DataFrame( + { + "is_blended_flag": is_blended, + "N_blended": n_blended, + "blend_ratio": blend_ratio, + }, + index=sub_df.index, + ) + return df + + df = cat.to_pandas() + df.set_index("cat_id", inplace=True) + ddf = dd.from_pandas(df, chunksize=1000) + grp = ddf.groupby("source_id") + logger.info("Identifying blended components...") + with ProgressBar(): + is_blended = grp.apply( + is_blended_component, + meta={ + "is_blended_flag": bool, + "N_blended": int, + "blend_ratio": float, + }, + ).compute() + is_blended = is_blended.reindex(cat["cat_id"]) + cat.add_column( + Column( + is_blended["is_blended_flag"], + name="is_blended_flag", + dtype=bool, + ), + index=-1, + ) + cat.add_column( + Column( + is_blended["blend_ratio"], + name="blend_ratio", + dtype=float, + ), + index=-1, + ) + cat.add_column( + Column( + is_blended["N_blended"], + 
name="N_blended", + dtype=int, + ), + index=-1, + ) + # Sanity check - no single-component sources should be flagged + assert np.array_equal(is_blended.index.values, cat["cat_id"].data), "Index mismatch" + assert not any( + cat["is_blended_flag"] & (cat["N_Gaus"] == 1) + ), "Single-component sources cannot be flagged as blended." + if "index" in cat.colnames: + cat.remove_column("index") + return cat + def lognorm_from_percentiles(x1, p1, x2, p2): """Return a log-normal distribuion X parametrized by: @@ -74,7 +242,12 @@ def sigma_add_fix(tab): med[i] = np.nan std[i] = np.nan - tab.add_column(Column(data=med, name="sigma_add")) + tab.add_column( + Column( + data=med, + name="sigma_add", + ) + ) tab.add_column(Column(data=std, name="sigma_add_err")) tab.remove_columns( [ @@ -207,7 +380,7 @@ def compute_local_rm_flag(good_cat: Table, big_cat: Table) -> Table: Returns: Table: Table with local RM flag """ - log.info("Computing voronoi bins and finding bad RMs") + logger.info("Computing voronoi bins and finding bad RMs") def sn_func(index, signal=None, noise=None): try: @@ -229,7 +402,7 @@ def sn_func(index, signal=None, noise=None): quiet=True, wvt=False, ) - log.info(f"Found {len(set(bin_number))} bins") + logger.info(f"Found {len(set(bin_number))} bins") df = good_cat.to_pandas() df.reset_index(inplace=True) df.set_index("cat_id", inplace=True) @@ -260,10 +433,17 @@ def masker(x): "local_rm_flag" ].description = "RM is statistically different from nearby RMs" + # Bring back the units + for col in cat_out.colnames: + if col in big_cat.colnames: + logger.debug(f"Resetting unit for {col}") + cat_out[col].unit = big_cat[col].unit + cat_out.units[col] = big_cat.units[col] + return cat_out -def cuts_and_flags(cat): +def cuts_and_flags(cat: RMTable) -> RMTable: """Cut out bad sources, and add flag columns A flag of 'True' means the source is bad. 
@@ -286,14 +466,15 @@ def cuts_and_flags(cat): cat.add_column(Column(data=chan_flag, name="channel_flag")) # Stokes I flag - cat["stokesI_fit_flag"] = ( + stokesI_fit_flag = ( cat["stokesI_fit_flag_is_negative"] + cat["stokesI_fit_flag_is_close_to_zero"] + cat["stokesI_fit_flag_is_not_finite"] ) + cat.add_column(Column(data=stokesI_fit_flag, name="stokesI_fit_flag")) # sigma_add flag - sigma_flag = cat["sigma_add"] > 1 + sigma_flag = cat["sigma_add"] > 10 * cat["sigma_add_err"] cat.add_column(Column(data=sigma_flag, name="complex_sigma_add_flag")) # M2_CC flag m2_flag = cat["rm_width"] > cat["rmsf_fwhm"] @@ -308,6 +489,9 @@ def cuts_and_flags(cat): cat_out = compute_local_rm_flag(good_cat=good_cat, big_cat=cat) + # Flag primary components + cat_out = flag_blended_components(cat_out) + # Restre units and metadata for col in cat.colnames: cat_out[col].unit = cat[col].unit @@ -403,7 +587,7 @@ def add_metadata(vo_table: vot.tree.Table, filename: str): # Add params for CASDA if len(vo_table.params) > 0: - log.warning(f"{filename} already has params - not adding") + logger.warning(f"{filename} already has params - not adding") return vo_table _, ext = os.path.splitext(filename) cat_name = ( @@ -455,7 +639,25 @@ def replace_nans(filename: str): # f.write(xml) -def write_votable(rmtab: RMTable, outfile: str) -> None: +def fix_blank_units(rmtab: TableLike) -> TableLike: + """Fix blank units in table + + Args: + rmtab (TableLike): TableLike + """ + for col in rmtab.colnames: + if rmtab[col].unit is None or rmtab[col].unit == u.Unit(""): + rmtab[col].unit = u.Unit("---") + if isinstance(rmtab, RMTable): + rmtab.units[col] = u.Unit("---") + if rmtab[col].unit is None or rmtab[col].unit == u.Unit(""): + rmtab[col].unit = u.Unit("---") + if isinstance(rmtab, RMTable): + rmtab.units[col] = u.Unit("---") + return rmtab + + +def write_votable(rmtab: TableLike, outfile: str) -> None: # Replace bad column names fix_columns = { "catalog": "catalog_name", @@ -465,6 +667,8 @@ def write_votable(rmtab: RMTable, outfile: str) -> None: for col_name, new_name in fix_columns.items(): if col_name in rmtab.colnames: rmtab.rename_column(col_name, new_name) + # Fix blank units + rmtab = fix_blank_units(rmtab) vo_table = vot.from_table(rmtab) vo_table.version = "1.3" vo_table = add_metadata(vo_table, outfile) @@ -496,16 +700,16 @@ def main( beams_col, island_col, comp_col = get_db( host=host, username=username, password=password ) - log.info("Starting beams collection query") + logger.info("Starting beams collection query") tick = time.time() query = { "$and": [{f"beams.{field}": {"$exists": True}}, {f"beams.{field}.DR1": True}] } all_island_ids = sorted(beams_col.distinct("Source_ID", query)) tock = time.time() - log.info(f"Finished beams collection query - {tock-tick:.2f}s") + logger.info(f"Finished beams collection query - {tock-tick:.2f}s") - log.info("Starting component collection query") + logger.info("Starting component collection query") tick = time.time() query = { "$and": [ @@ -526,7 +730,7 @@ def main( comps = list(comp_col.find(query, fields)) tock = time.time() - log.info(f"Finished component collection query - {tock-tick:.2f}s") + logger.info(f"Finished component collection query - {tock-tick:.2f}s") rmtab = RMTable() # type: RMTable # Add items to main cat using RMtable standard @@ -583,10 +787,10 @@ def main( alpha_dict = get_alpha(rmtab) rmtab.add_column(Column(data=alpha_dict["alphas"], name="spectral_index")) rmtab.add_column(Column(data=alpha_dict["alphas_err"], name="spectral_index_err")) - 
rmtab.add_column(Column(data=alpha_dict["betas"], name="spectral_curvature")) - rmtab.add_column( - Column(data=alpha_dict["betas_err"], name="spectral_curvature_err") - ) + # rmtab.add_column(Column(data=alpha_dict["betas"], name="spectral_curvature")) + # rmtab.add_column( + # Column(data=alpha_dict["betas_err"], name="spectral_curvature_err") + # ) # Add integration time field_col = get_field_db(host=host, username=username, password=password) @@ -634,6 +838,19 @@ def main( if type(rmtab[col][0]) == np.float_: rmtab[col][np.isinf(rmtab[col])] = np.nan + # Convert all mJy to Jy + for col in rmtab.colnames: + if rmtab[col].unit == u.mJy: + logger.debug(f"Converting {col} unit from {rmtab[col].unit} to {u.Jy}") + rmtab[col] = rmtab[col].to(u.Jy) + rmtab.units[col] = u.Jy + if rmtab[col].unit == u.mJy / u.beam: + logger.debug( + f"Converting {col} unit from {rmtab[col].unit} to {u.Jy / u.beam}" + ) + rmtab[col] = rmtab[col].to(u.Jy / u.beam) + rmtab.units[col] = u.Jy / u.beam + # Verify table rmtab.add_missing_columns() rmtab.verify_standard_strings() @@ -648,18 +865,18 @@ def main( rmtab.verify_ucds() if outfile is None: - log.info(pformat(rmtab)) + logger.info(pformat(rmtab)) if outfile is not None: - log.info(f"Writing {outfile} to disk") + logger.info(f"Writing {outfile} to disk") _, ext = os.path.splitext(outfile) if ext == ".xml" or ext == ".vot": write_votable(rmtab, outfile) else: rmtab.write(outfile, overwrite=True) - log.info(f"{outfile} written to disk") + logger.info(f"{outfile} written to disk") - log.info("Done!") + logger.info("Done!") def cli(): @@ -736,23 +953,10 @@ def cli(): verbose = args.verbose if verbose: - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) - else: - log.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) + logger.setLevel(logging.INFO) host = args.host - test_db( - host=args.host, username=args.username, password=args.password, verbose=verbose - ) + test_db(host=args.host, username=args.username, password=args.password) main( field=args.field, diff --git a/spiceracs/merge_fields.py b/spiceracs/merge_fields.py index 478f6fe0..df013682 100644 --- a/spiceracs/merge_fields.py +++ b/spiceracs/merge_fields.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 """Merge multiple RACS fields""" -import logging as log +import logging import os import time from pprint import pformat, pprint from shutil import copyfile -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Union import pymongo from dask import delayed, distributed @@ -13,6 +13,7 @@ from tqdm import tqdm from spiceracs.linmos import get_yanda, linmos +from spiceracs.logger import logger from spiceracs.utils import chunk_dask, get_db, test_db, tqdm_dask, try_mkdir @@ -180,7 +181,6 @@ def merge_multiple_fields( merge_name: str, image: str, ) -> list: - # Find all islands with the given fields that overlap another field query = { "$or": [ @@ -215,13 +215,12 @@ def main( output_dir: str, client: Client, host: str, - username: str = None, - password: str = None, + username: Union[str, None] = None, + password: Union[str, None] = None, yanda="1.3.0", verbose: bool = True, ) -> str: - - log.debug(f"{fields=}") + logger.debug(f"{fields=}") assert len(fields) == len( field_dirs @@ -281,12 +280,12 @@ def main( m._doc["$set"].update({f"beams.{merge_name}.DR1": True}) 
db_res_single = beams_col.bulk_write(singleton_comp, ordered=False) - log.info(pformat(db_res_single.bulk_api_result)) + logger.info(pformat(db_res_single.bulk_api_result)) db_res_multiple = beams_col.bulk_write(multiple_comp, ordered=False) - log.info(pformat(db_res_multiple.bulk_api_result)) + logger.info(pformat(db_res_multiple.bulk_api_result)) - log.info("LINMOS Done!") + logger.info("LINMOS Done!") return inter_dir diff --git a/spiceracs/process_region.py b/spiceracs/process_region.py index 9c238cc4..9a500c62 100644 --- a/spiceracs/process_region.py +++ b/spiceracs/process_region.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """SPICE-RACS multi-field pipeline""" -import logging as log +import logging import os from time import sleep @@ -19,6 +19,7 @@ from prefect.engine.executors import DaskExecutor from spiceracs import merge_fields, process_spice +from spiceracs.logger import logger from spiceracs.utils import port_forward, test_db @@ -80,7 +81,7 @@ def main(args: configargparse.Namespace) -> None: cluster = SLURMCluster( **config, ) - log.debug(f"Submitted scripts will look like: \n {cluster.job_script()}") + logger.debug(f"Submitted scripts will look like: \n {cluster.job_script()}") # Request 15 nodes cluster.scale(jobs=15) @@ -91,12 +92,11 @@ def main(args: configargparse.Namespace) -> None: host=args.host, username=args.username, password=args.password, - verbose=args.verbose, ) args_yaml = yaml.dump(vars(args)) args_yaml_f = os.path.abspath(f"{args.merge_name}-config-{Time.now().fits}.yaml") - log.info(f"Saving config to '{args_yaml_f}'") + logger.info(f"Saving config to '{args_yaml_f}'") with open(args_yaml_f, "w") as f: f.write(args_yaml) @@ -108,7 +108,7 @@ def main(args: configargparse.Namespace) -> None: port_forward(port, p) # Prin out Dask client info - log.info(client.scheduler_info()["services"]) + logger.info(client.scheduler_info()["services"]) # Define flow inter_dir = os.path.join(os.path.abspath(args.output_dir), args.merge_name) with Flow(f"SPICE-RACS: {args.merge_name}") as flow: @@ -479,25 +479,9 @@ def cli(): verbose = args.verbose if verbose: - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) + logger.setLevel(logger.INFO) if args.debugger: - log.basicConfig( - level=log.DEBUG, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) - else: - log.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) + logger.setLevel(logger.DEBUG) main(args) diff --git a/spiceracs/process_spice.py b/spiceracs/process_spice.py index c350bef2..24bcf341 100644 --- a/spiceracs/process_spice.py +++ b/spiceracs/process_spice.py @@ -1,6 +1,6 @@ #!/usr/bin/env python3 """SPICE-RACS single-field pipeline""" -import logging as log +import logging import os import socket from time import sleep @@ -29,6 +29,7 @@ rmclean_oncuts, rmsynth_oncuts, ) +from spiceracs.logger import logger from spiceracs.utils import port_forward, test_db @@ -210,7 +211,7 @@ def main(args: configargparse.Namespace) -> None: cluster = SLURMCluster( **config, ) - log.debug(f"Submitted scripts will look like: \n {cluster.job_script()}") + logger.debug(f"Submitted scripts will look like: \n {cluster.job_script()}") # Request 15 nodes cluster.scale(jobs=15) @@ -221,12 +222,11 @@ def main(args: configargparse.Namespace) -> 
None: host=args.host, username=args.username, password=args.password, - verbose=args.verbose, ) args_yaml = yaml.dump(vars(args)) args_yaml_f = os.path.abspath(f"{args.field}-config-{Time.now().fits}.yaml") - log.info(f"Saving config to '{args_yaml_f}'") + logger.info(f"Saving config to '{args_yaml_f}'") with open(args_yaml_f, "w") as f: f.write(args_yaml) @@ -238,7 +238,7 @@ def main(args: configargparse.Namespace) -> None: port_forward(port, p) # Prin out Dask client info - log.info(client.scheduler_info()["services"]) + logger.info(client.scheduler_info()["services"]) # Define flow with Flow(f"SPICE-RACS: {args.field}") as flow: @@ -652,18 +652,7 @@ def cli(): verbose = args.verbose if verbose: - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) - else: - log.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) + logger.setLevel(logger.INFO) main(args) diff --git a/spiceracs/rmclean_oncuts.py b/spiceracs/rmclean_oncuts.py index 414201b8..5144a520 100644 --- a/spiceracs/rmclean_oncuts.py +++ b/spiceracs/rmclean_oncuts.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 """Run RM-synthesis on cutouts in parallel""" import json -import logging as log +import logging import os import sys import time @@ -9,6 +9,7 @@ from glob import glob from pprint import pformat from shutil import copyfile +from typing import List, Union import dask import matplotlib.pyplot as plt @@ -27,6 +28,7 @@ from spectral_cube import SpectralCube from tqdm import tqdm, trange +from spiceracs.logger import logger from spiceracs.utils import MyEncoder, chunk_dask, get_db, getfreq, test_db, tqdm_dask @@ -60,9 +62,8 @@ def rmclean1d( iname = comp["Source_ID"] cname = comp["Gaussian_ID"] - log.debug(f"Working on {comp}") + logger.debug(f"Working on {comp}") try: - rm1dfiles = comp["rm1dfiles"] fdfFile = os.path.join(outdir, f"{rm1dfiles['FDF_dirty']}") rmsfFile = os.path.join(outdir, f"{rm1dfiles['RMSF']}") @@ -73,9 +74,9 @@ def rmclean1d( # Sanity checks for f in [weightFile, fdfFile, rmsfFile, rmSynthFile]: - log.debug(f"Checking {os.path.abspath(f)}") + logger.debug(f"Checking {os.path.abspath(f)}") if not os.path.exists(f): - log.fatal("File does not exist: '{:}'.".format(f)) + logger.fatal("File does not exist: '{:}'.".format(f)) sys.exit() nBits = 32 mDict, aDict = do_RMclean_1D.readFiles( @@ -131,8 +132,8 @@ def rmclean1d( }, } except KeyError: - log.critical("Failed to load data! RM-CLEAN not applied to component!") - log.critical(f"Island is {iname}, component is {cname}") + logger.critical("Failed to load data! 
RM-CLEAN not applied to component!") + logger.critical(f"Island is {iname}, component is {cname}") myquery = {"Gaussian_ID": cname} newvalues = { @@ -216,14 +217,14 @@ def main( outdir: str, host: str, client: Client, - username: str = None, - password: str = None, + username: Union[str, None] = None, + password: Union[str, None] = None, dimension="1d", verbose=True, database=False, savePlots=True, validate=False, - limit: int = None, + limit: Union[int, None] = None, cutoff: float = -3, maxIter=10000, gain=0.1, @@ -309,7 +310,7 @@ def main( outputs = [] if dimension == "1d": - log.info(f"Running RM-CLEAN on {n_comp} components") + logger.info(f"Running RM-CLEAN on {n_comp} components") for i, comp in enumerate(tqdm(components, total=n_comp)): if i > n_comp + 1: break @@ -328,7 +329,7 @@ def main( outputs.append(output) elif dimension == "3d": - log.info(f"Running RM-CLEAN on {n_island} islands") + logger.info(f"Running RM-CLEAN on {n_island} islands") for i, island in enumerate(islands): if i > n_island + 1: @@ -352,15 +353,15 @@ def main( ) if database: - log.info("Updating database...") + logger.info("Updating database...") updates = [f.compute() for f in futures] if dimension == "1d": db_res = comp_col.bulk_write(updates, ordered=False) - log.info(pformat(db_res.bulk_api_result)) + logger.info(pformat(db_res.bulk_api_result)) elif dimension == "3d": db_res = island_col.bulk_write(updates, ordered=False) - log.info(pformat(db_res.bulk_api_result)) - log.info("RM-CLEAN done!") + logger.info(pformat(db_res.bulk_api_result)) + logger.info("RM-CLEAN done!") def cli(): @@ -504,26 +505,13 @@ def cli(): ) if rmv: - log.basicConfig( - level=log.DEBUG, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, + logger.setLevel( + level=logger.DEBUG, ) elif verbose: - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, - ) - else: - log.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - force=True, + logger.setLevel( + level=logger.INFO, ) - main( field=args.field, outdir=args.outdir, diff --git a/spiceracs/rmsynth_oncuts.py b/spiceracs/rmsynth_oncuts.py index 8f1696b9..77fa4b21 100644 --- a/spiceracs/rmsynth_oncuts.py +++ b/spiceracs/rmsynth_oncuts.py @@ -2,7 +2,7 @@ """Run RM-CLEAN on cutouts in parallel""" import functools import json -import logging as log +import logging import os import pdb import subprocess @@ -38,6 +38,7 @@ from spectral_cube import SpectralCube from tqdm import tqdm, trange +from spiceracs.logger import logger from spiceracs.utils import ( MyEncoder, chunk_dask, @@ -103,7 +104,7 @@ def rmsynthoncut3d( dataI = np.squeeze(dataI) if np.isnan(dataI).all() or np.isnan(dataQ).all() or np.isnan(dataU).all(): - log.critical(f"Cubelet {iname} is entirely NaN") + logger.critical(f"Cubelet {iname} is entirely NaN") myquery = {"Source_ID": iname} badvalues = { "$set": { @@ -314,7 +315,7 @@ def rmsynthoncut1d( dataI = np.squeeze(dataI) if np.isnan(dataI).all() or np.isnan(dataQ).all() or np.isnan(dataU).all(): - log.critical(f"Entire data is NaN for {iname}") + logger.critical(f"Entire data is NaN for {iname}") myquery = {"Gaussian_ID": cname} badvalues = {"$set": {"rmsynth1d": False}} return pymongo.UpdateOne(myquery, badvalues) @@ -378,13 +379,13 @@ def rmsynthoncut1d( amplitude = tt0_p x_0 = 
mfs_head["RESTFREQ"] - log.debug(f"alpha is {alpha}") + logger.debug(f"alpha is {alpha}") model_I = models.PowerLaw1D(amplitude=amplitude, x_0=x_0, alpha=alpha) modStokesI = model_I(freq) model_repr = model_I.__repr__() elif do_own_fit: - log.debug(f"Doing own fit") + logger.debug(f"Doing own fit") fit_dict = fit_pl(freq=freq, flux=iarr, fluxerr=rmsi, nterms=abs(polyOrd)) alpha = None amplitude = None @@ -400,7 +401,7 @@ def rmsynthoncut1d( modStokesI = None if np.sum(np.isfinite(qarr)) < 2 or np.sum(np.isfinite(uarr)) < 2: - log.critical(f"{cname} QU data is all NaNs.") + logger.critical(f"{cname} QU data is all NaNs.") myquery = {"Gaussian_ID": cname} badvalues = {"$set": {"rmsynth1d": False}} return pymongo.UpdateOne(myquery, badvalues) @@ -410,7 +411,7 @@ def rmsynthoncut1d( data = [np.array(freq), iarr, qarr, uarr, rmsi, rmsq, rmsu] if np.isnan(iarr).all(): - log.critical(f"{cname} I data is all NaNs.") + logger.critical(f"{cname} I data is all NaNs.") myquery = {"Gaussian_ID": cname} badvalues = {"$set": {"rmsynth1d": False}} return pymongo.UpdateOne(myquery, badvalues) @@ -418,7 +419,7 @@ def rmsynthoncut1d( # Run 1D RM-synthesis on the spectra np.savetxt(f"{prefix}.dat", np.vstack(data).T, delimiter=" ") try: - log.debug(f"Using {fit_function} to fit Stokes I") + logger.debug(f"Using {fit_function} to fit Stokes I") mDict, aDict = do_RMsynth_1D.run_rmsynth( data=data, polyOrd=polyOrd, @@ -723,7 +724,6 @@ def main( ion: bool = False, do_own_fit: bool = False, ) -> None: - outdir = os.path.abspath(outdir) outdir = os.path.join(outdir, "cutouts") @@ -792,7 +792,7 @@ def main( outputs = [] if validate: - log.info(f"Running RMsynth on {n_comp} components") + logger.info(f"Running RMsynth on {n_comp} components") # We don't run this in parallel! for i, comp_id in enumerate(component_ids): output = rmsynthoncut_i( @@ -811,7 +811,7 @@ def main( output.compute() elif dimension == "1d": - log.info(f"Running RMsynth on {n_comp} components") + logger.info(f"Running RMsynth on {n_comp} components") for i, (_, comp) in tqdm( enumerate(components.iterrows()), total=n_comp, @@ -848,7 +848,7 @@ def main( outputs.append(output) elif dimension == "3d": - log.info(f"Running RMsynth on {n_island} islands") + logger.info(f"Running RMsynth on {n_island} islands") for i, island_id in enumerate(island_ids): if i > n_island + 1: @@ -881,18 +881,18 @@ def main( ) if database: - log.info("Updating database...") + logger.info("Updating database...") updates = [f.compute() for f in futures] # Remove None values updates = [u for u in updates if u is not None] - log.info("Sending updates to database...") + logger.info("Sending updates to database...") if dimension == "1d": db_res = comp_col.bulk_write(updates, ordered=False) - log.info(pformat(db_res.bulk_api_result)) + logger.info(pformat(db_res.bulk_api_result)) elif dimension == "3d": db_res = island_col.bulk_write(updates, ordered=False) - log.info(pformat(db_res.bulk_api_result)) - log.info("RMsynth done!") + logger.info(pformat(db_res.bulk_api_result)) + logger.info("RMsynth done!") def cli(): @@ -1100,29 +1100,16 @@ def cli(): verbose = args.verbose rmv = args.rm_verbose if rmv: - log.basicConfig( - level=log.DEBUG, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) + logger.setLevel(logger.DEBUG) elif verbose: - log.basicConfig( - level=log.INFO, - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - else: - 
log.basicConfig( - format="%(asctime)s.%(msecs)03d %(levelname)s %(module)s - %(funcName)s: %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) + logger.setLevel(logger.INFO) cluster = LocalCluster( # n_workers=12, processes=True, threads_per_worker=1, local_directory="/dev/shm" ) client = Client(cluster) - log.debug(client) + logger.debug(client) test_db( host=args.host, username=args.username, password=args.password, verbose=verbose diff --git a/spiceracs/utils.py b/spiceracs/utils.py index 80d63c22..14fe658a 100644 --- a/spiceracs/utils.py +++ b/spiceracs/utils.py @@ -3,7 +3,7 @@ import dataclasses import functools import json -import logging as log +import logging import os import shlex import stat @@ -45,6 +45,8 @@ from tornado.ioloop import IOLoop from tqdm.auto import tqdm, trange +from spiceracs.logger import logger + warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) warnings.simplefilter("ignore", category=AstropyWarning) @@ -68,10 +70,10 @@ def chi_squared(model: np.ndarray, data: np.ndarray, error: np.ndarray) -> float def best_aic_func(aics: np.ndarray, n_param: np.ndarray) -> Tuple[float, int, int]: """Find the best AIC for a set of AICs using Occam's razor.""" # Find the best AIC - best_aic_idx = np.nanargmin(aics) - best_aic = aics[best_aic_idx] - best_n = n_param[best_aic_idx] - log.debug(f"Lowest AIC is {best_aic}, with {best_n} params.") + best_aic_idx = int(np.nanargmin(aics)) + best_aic = float(aics[best_aic_idx]) + best_n = int(n_param[best_aic_idx]) + logger.debug(f"Lowest AIC is {best_aic}, with {best_n} params.") # Check if lower have diff < 2 in AIC aic_abs_diff = np.abs(aics - best_aic) bool_min_idx = np.zeros_like(aics).astype(bool) @@ -79,14 +81,16 @@ def best_aic_func(aics: np.ndarray, n_param: np.ndarray) -> Tuple[float, int, in potential_idx = (aic_abs_diff[~bool_min_idx] < 2) & ( n_param[~bool_min_idx] < best_n ) - if any(potential_idx): - best_n = np.min(n_param[~bool_min_idx][potential_idx]) - best_aic_idx = np.where(n_param == best_n)[0][0] - best_aic = aics[best_aic_idx] - log.debug( - f"Model within 2 of lowest AIC found. Occam says to take AIC of {best_aic}, with {best_n} params." - ) - return best_aic, best_n, best_aic_idx + if not any(potential_idx): + return best_aic, best_n, best_aic_idx + + bestest_n = int(np.min(n_param[~bool_min_idx][potential_idx])) + bestest_aic_idx = int(np.where(n_param == bestest_n)[0][0]) + bestest_aic = float(aics[bestest_aic_idx]) + logger.debug( + f"Model within 2 of lowest AIC found. Occam says to take AIC of {bestest_aic}, with {bestest_n} params." + ) + return bestest_aic, bestest_n, bestest_aic_idx # Stolen from GLEAM-X - thanks Uncle Timmy! 
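The refactored best_aic_func above applies an Occam's-razor tie-break: it first finds the minimum-AIC model, then prefers any simpler model whose AIC lies within 2 of that minimum. A worked example with made-up AIC values:

import numpy as np

from spiceracs.utils import best_aic_func

# AICs for 0-, 1- and 2-term Stokes I fits (illustrative values)
aics = np.array([10.0, 8.5, 8.0])
n_param = np.array([0, 1, 2])

# The 2-term fit has the lowest AIC, but the 1-term fit is within 2 of it,
# so the simpler model wins the tie-break.
best_aic, best_n, best_idx = best_aic_func(aics, n_param)
print(best_aic, best_n, best_idx)  # -> 8.5 1 1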
@@ -200,7 +204,7 @@ def fit_pl( absolute_sigma=True, ) except RuntimeError: - log.critical(f"Failed to fit {n}-term power law") + logger.critical(f"Failed to fit {n}-term power law") continue best, covar = fit_res @@ -223,11 +227,11 @@ def fit_pl( # Flag if model is negative is_negative = (model_arr < 0).any() if is_negative: - log.warning(f"Stokes I flag: Model {n} is negative") + logger.warning(f"Stokes I flag: Model {n} is negative") # Flag if model is NaN or Inf is_not_finite = ~np.isfinite(model_arr).all() if is_not_finite: - log.warning(f"Stokes I flag: Model {n} is not finite") + logger.warning(f"Stokes I flag: Model {n} is not finite") # # Flag if model and data are statistically different residuals = flux[goodchan] - model_arr[goodchan] # Assume errors on resdiuals are the same as the data @@ -238,14 +242,14 @@ def fit_pl( ks, pval = normaltest(residuals_norm) is_not_normal = pval < 1e-6 # 1 in a million chance of being unlucky if is_not_normal: - log.warning( + logger.warning( f"Stokes I flag: Model {n} is not normally distributed - {pval=}, {ks=}" ) # Test if model is close to 0 within 1 sigma is_close_to_zero = (model_arr[goodchan] / fluxerr[goodchan] < 1).any() if is_close_to_zero: - log.warning(f"Stokes I flag: Model {n} is close (1sigma) to 0") + logger.warning(f"Stokes I flag: Model {n} is close (1sigma) to 0") fit_flag = { "is_negative": is_negative, "is_not_finite": is_not_finite, @@ -253,14 +257,14 @@ def fit_pl( "is_close_to_zero": is_close_to_zero, } save_dict[n]["fit_flags"] = fit_flag - log.debug(f"{n}: {aic}") + logger.debug(f"{n}: {aic}") # Now find the best model best_aic, best_n, best_aic_idx = best_aic_func( np.array([save_dict[n]["aics"] for n in range(nterms + 1)]), np.array([n for n in range(nterms + 1)]), ) - log.debug(f"Best fit: {best_n}, {best_aic}") + logger.debug(f"Best fit: {best_n}, {best_aic}") best_p = save_dict[best_n]["params"] best_e = save_dict[best_n]["errors"] best_m = save_dict[best_n]["models"] @@ -288,7 +292,7 @@ def fit_pl( chi_sq_red=chi_sq_red, ) except Exception as e: - log.critical(f"Failed to fit power law: {e}") + logger.critical(f"Failed to fit power law: {e}") return dict( best_n=np.nan, best_p=[np.nan], @@ -334,9 +338,9 @@ def chunk_dask( futures = client.persist(outputs_chunk) # dumb solution for https://github.com/dask/distributed/issues/4831 if i == 0: - log.debug("I sleep!") + logger.debug("I sleep!") time.sleep(10) - log.debug("I awake!") + logger.debug("I awake!") tqdm_dask(futures, desc=progress_text, disable=(not verbose)) chunk_outputs.extend(futures) return chunk_outputs @@ -374,8 +378,8 @@ def latexify(fig_width=None, fig_height=None, columns=1): MAX_HEIGHT_INCHES = 8.0 if fig_height > MAX_HEIGHT_INCHES: - print( - "WARNING: fig_height too large:" + logger.waning( + "fig_height too large:" + fig_height + "so will reduce to" + MAX_HEIGHT_INCHES @@ -399,7 +403,9 @@ def latexify(fig_width=None, fig_height=None, columns=1): matplotlib.rcParams.update(params) -def delayed_to_da(list_of_delayed: List[Delayed], chunk: int = None) -> da.Array: +def delayed_to_da( + list_of_delayed: List[Delayed], chunk: Union[int, None] = None +) -> da.Array: """Convert list of delayed arrays to a dask array Args: @@ -519,8 +525,8 @@ def coord_to_string(coord: SkyCoord) -> Tuple[str, str]: def test_db( - host: str, username: str = None, password: str = None, verbose=True -) -> None: + host: str, username: Union[str, None] = None, password: Union[str, None] = None +) -> bool: """Test connection to MongoDB Args: @@ -529,10 +535,14 @@ def test_db( 
password (str, optional): Mongo password. Defaults to None. verbose (bool, optional): Verbose output. Defaults to True. + + Returns: + bool: True if connection succesful + Raises: Exception: If connection fails. """ - log.info("Testing MongoDB connection...") + logger.info("Testing MongoDB connection...") # default connection (ie, local) with pymongo.MongoClient( host=host, @@ -545,8 +555,10 @@ def test_db( dbclient.list_database_names() except pymongo.errors.ServerSelectionTimeoutError: raise Exception("Please ensure 'mongod' is running") - else: - log.info("MongoDB connection succesful!") + + logger.info("MongoDB connection succesful!") + + return True def get_db( @@ -576,9 +588,7 @@ def get_db( return beams_col, island_col, comp_col -def get_field_db( - host: str, username=None, password=None -) -> pymongo.collection.Collection: +def get_field_db(host: str, username=None, password=None) -> Collection: """Get MongoDBs Args: @@ -646,7 +656,7 @@ def port_forward(port: int, target: str) -> None: port (int): port to forward target (str): Target host """ - log.info(f"Forwarding {port} from localhost to {target}") + logger.info(f"Forwarding {port} from localhost to {target}") cmd = f"ssh -N -f -R {port}:localhost:{port} {target}" command = shlex.split(cmd) output = subprocess.Popen(command) @@ -662,9 +672,9 @@ def try_mkdir(dir_path: str, verbose=True): # Create output dir if it doesn't exist try: os.mkdir(dir_path) - log.info(f"Made directory '{dir_path}'.") + logger.info(f"Made directory '{dir_path}'.") except FileExistsError: - log.info(f"Directory '{dir_path}' exists.") + logger.info(f"Directory '{dir_path}' exists.") def try_symlink(src: str, dst: str, verbose=True): @@ -678,9 +688,9 @@ def try_symlink(src: str, dst: str, verbose=True): # Create output dir if it doesn't exist try: os.symlink(src, dst) - log.info(f"Made symlink '{dst}'.") + logger.info(f"Made symlink '{dst}'.") except FileExistsError: - log.info(f"Symlink '{dst}' exists.") + logger.info(f"Symlink '{dst}' exists.") def head2dict(h: fits.Header) -> Dict[str, Any]: @@ -752,7 +762,9 @@ def cpu_to_use(max_cpu: int, count: int) -> int: return np.max(factors_arr[factors_arr <= max_cpu]) -def getfreq(cube: str, outdir: str = None, filename: str = None): +def getfreq( + cube: str, outdir: Union[str, None] = None, filename: Union[str, None] = None +): """Get list of frequencies from FITS data. Gets the frequency list from a given cube. 
Can optionally save @@ -790,7 +802,7 @@ def getfreq(cube: str, outdir: str = None, filename: str = None): outfile = f"{outdir}/frequencies.txt" else: outfile = f"{outdir}/{filename}" - log.info(f"Saving to {outfile}") + logger.info(f"Saving to {outfile}") np.savetxt(outfile, np.array(freq)) return freq, outfile # Type: Tuple[u.Quantity, str] @@ -811,7 +823,7 @@ def gettable(tabledir: str, keyword: str, verbose=True) -> Tuple[Table, str]: # Glob out the necessary files files = glob(f"{tabledir}/*.{keyword}*.xml") # Selvay VOTab filename = files[0] - log.info(f"Getting table data from {filename}...") + logger.info(f"Getting table data from {filename}...") # Get selvay data from VOTab table = Table.read(filename, format="votable") @@ -858,8 +870,8 @@ def getdata(cubedir="./", tabledir="./", mapdata=None, verbose=True): i_tab, voisle = gettable(tabledir, "islands", verbose=verbose) # Selvay VOTab components, tablename = gettable(tabledir, "components", verbose=verbose) - log.info(f"Getting spectral data from: {cubes}\n") - log.info(f"Getting source location data from: {selavyfits}\n") + logger.info(f"Getting spectral data from: {cubes}\n") + logger.info(f"Getting source location data from: {selavyfits}\n") # Read data using Spectral cube i_taylor = SpectralCube.read(selavyfits, mode="denywrite") diff --git a/submit/casda_prepare.sh b/submit/casda_prepare.sh index dc630570..f865c50f 100644 --- a/submit/casda_prepare.sh +++ b/submit/casda_prepare.sh @@ -10,14 +10,14 @@ #SBATCH --ntasks=1000 #SBATCH --ntasks-per-node=10 ##SBATCH --time=0-00:45:00 # For cut -#SBATCH --time=0-00:10:00 # For test -##SBATCH --time=0-01:45:00 # For full +##SBATCH --time=0-00:10:00 # For test +#SBATCH --time=0-01:45:00 # For full -# conda activate spice +prep_type=full conda activate spice data_dir=/group/askap/athomson/projects/spiceracs/DR1/full_spica polcat=/group/askap/athomson/projects/spiceracs/DR1/spice-racs.dr1.corrected.xml -prep_type=test + cd $data_dir srun -n $SLURM_NTASKS casda_prepare.py $data_dir $polcat $prep_type --convert-spectra --convert-cubes --convert-plots -v --mpi --batch_size 10_000 diff --git a/tests/data/gaussians_RACS_1237+12A_3165.fits b/tests/data/gaussians_RACS_1237+12A_3165.fits new file mode 100644 index 00000000..f6ac013a --- /dev/null +++ b/tests/data/gaussians_RACS_1237+12A_3165.fits @@ -0,0 +1,9 @@ +SIMPLE = T / conforms to FITS standard BITPIX = 8 / array data type NAXIS = 0 / number of array dimensions EXTEND = T END XTENSION= 'BINTABLE' / binary table extension BITPIX = 8 / array data type NAXIS = 2 / number of array dimensions NAXIS1 = 300 / length of dimension 1 NAXIS2 = 3 / length of dimension 2 PCOUNT = 0 / number of group parameters GCOUNT = 1 / number of groups TFIELDS = 35 / number of table fields TTYPE1 = 'Gaussian_ID' TFORM1 = '19A ' TTYPE2 = 'Source_ID' TFORM2 = '19A ' TTYPE3 = 'Tile_ID ' TFORM3 = '13A ' TTYPE4 = 'SBID ' TFORM4 = 'K ' TTYPE5 = 'Obs_Start_Time' TFORM5 = 'D ' TTYPE6 = 'N_Gaus ' TFORM6 = 'K ' TTYPE7 = 'RA ' TFORM7 = 'D ' TUNIT7 = 'deg ' TTYPE8 = 'Dec ' TFORM8 = 'D ' TUNIT8 = 'deg ' TTYPE9 = 'E_RA ' TFORM9 = 'D ' TUNIT9 = 'arcsec ' TTYPE10 = 'E_Dec ' TFORM10 = 'D ' TUNIT10 = 'arcsec ' TTYPE11 = 'Total_flux_Gaussian' TFORM11 = 'D ' TUNIT11 = 'mJy ' TTYPE12 = 'E_Total_flux_Gaussian_PyBDSF' TFORM12 = 'D ' TUNIT12 = 'mJy ' TTYPE13 = 'E_Total_flux_Gaussian' TFORM13 = 'D ' TUNIT13 = 'mJy ' TTYPE14 = 'Total_flux_Source' TFORM14 = 'D ' TUNIT14 = 'mJy ' TTYPE15 = 'E_Total_flux_Source_PyBDSF' TFORM15 = 'D ' TUNIT15 = 'mJy ' TTYPE16 = 'E_Total_flux_Source' TFORM16 = 
'D ' TUNIT16 = 'mJy ' TTYPE17 = 'Peak_flux' TFORM17 = 'D ' TUNIT17 = 'beam-1 mJy' TTYPE18 = 'E_Peak_flux' TFORM18 = 'D ' TUNIT18 = 'beam-1 mJy' TTYPE19 = 'Maj ' TFORM19 = 'D ' TUNIT19 = 'arcsec ' TTYPE20 = 'E_Maj ' TFORM20 = 'D ' TUNIT20 = 'arcsec ' TTYPE21 = 'Min ' TFORM21 = 'D ' TUNIT21 = 'arcsec ' TTYPE22 = 'E_Min ' TFORM22 = 'D ' TUNIT22 = 'arcsec ' TTYPE23 = 'PA ' TFORM23 = 'D ' TUNIT23 = 'deg ' TTYPE24 = 'E_PA ' TFORM24 = 'D ' TUNIT24 = 'deg ' TTYPE25 = 'DC_Maj ' TFORM25 = 'D ' TUNIT25 = 'arcsec ' TTYPE26 = 'E_DC_Maj' TFORM26 = 'D ' TUNIT26 = 'arcsec ' TTYPE27 = 'DC_Min ' TFORM27 = 'D ' TUNIT27 = 'arcsec ' TTYPE28 = 'E_DC_Min' TFORM28 = 'D ' TUNIT28 = 'arcsec ' TTYPE29 = 'DC_PA ' TFORM29 = 'D ' TUNIT29 = 'deg ' TTYPE30 = 'E_DC_PA ' TFORM30 = 'D ' TUNIT30 = 'deg ' TTYPE31 = 'S_Code ' TFORM31 = '1A ' TTYPE32 = 'Separation_Tile_Centre' TFORM32 = 'D ' TUNIT32 = 'deg ' TTYPE33 = 'Noise ' TFORM33 = 'D ' TUNIT33 = 'beam-1 mJy' TTYPE34 = 'Gal_lon ' TFORM34 = 'D ' TUNIT34 = 'deg ' TTYPE35 = 'Gal_lat ' TFORM35 = 'D ' TUNIT35 = 'deg ' EXTNAME = 'RACS_DR1_Gaussians_GalacticCut_v2021_08' TUCD1 = 'meta.id;meta.main' TCOMM2 = 'Source ID as a combination of RACS tile ID and PyBDSF Src_ID' TUCD2 = 'meta.id.parent' TCOMM3 = 'Tile ID in which source was found' TUCD3 = 'obs.field' TCOMM4 = 'Scheduling block associated with observation' TUCD4 = 'obs.sequence' TCOMM5 = 'Start time of observation (MJD)' TUCD5 = 'time.start;obs' TCOMM6 = 'Number of Gaussian components used to fit the source' TUCD6 = 'meta.number' TCOMM7 = 'J2000 right ascension in decimal degrees' TUCD7 = 'pos.eq.ra;meta.main' TCOMM8 = 'J2000 declination in decimal degrees' TUCD8 = 'pos.eq.dec;meta.main' TCOMM9 = 'Error in right ascension' TUCD9 = 'stat.error;pos.eq.ra' TCOMM10 = 'Error in declination' TUCD10 = 'stat.error;pos.eq.dec' TCOMM11 = 'Total flux density of Gaussian component' TUCD11 = 'phot.flux.density;em.radio;stat.fit' TUCD12 = 'stat.error;phot.flux.density;em.radio;stat.fit' TCOMM13 = 'Error in total flux density of Gaussian component' TUCD13 = 'stat.error;phot.flux.density;em.radio;stat.fit' TCOMM14 = 'Total flux density of source' TUCD14 = 'phot.flux.density;em.radio;stat.fit' TUCD15 = 'stat.error;phot.flux.density;em.radio;stat.fit' TCOMM16 = 'Error in total flux density of source' TUCD16 = 'stat.error;phot.flux.density;em.radio;stat.fit' TCOMM17 = 'Modelled peak flux density for Gaussian component' TUCD17 = 'phot.flux.density;stat.max;em.radio;stat.fit' TCOMM18 = 'Error in modelled peak flux density for Gaussian component' TUCD18 = 'stat.error;phot.flux.density;stat.max;em.radio;stat.fit' TCOMM19 = 'FWHM major axis before deconvolution' TUCD19 = 'phys.angSize.smajAxis;em.radio;stat.fit' TCOMM20 = 'Error in major axis before deconvolution' TUCD20 = 'stat.error;phys.angSize.smajAxis;em.radio' TCOMM21 = 'FWHM minor axis before deconvolution' TUCD21 = 'phys.angSize.sminAxis;em.radio;stat.fit' TCOMM22 = 'Error in minor axis before deconvolution' TUCD22 = 'stat.error;phys.angSize.sminAxis;em.radio' TCOMM23 = 'Position angle before deconvolution' TUCD23 = 'phys.angSize;pos.posAng;em.radio;stat.fit' TCOMM24 = 'Error in position angle before deconvolution' TUCD24 = 'stat.error;phys.angSize;pos.posAng;em.radio' TCOMM25 = 'FWHM major axis after deconvolution' TUCD25 = 'phys.angSize.smajAxis;em.radio;askap:meta.deconvolved' TCOMM26 = 'Error in major axis after deconvolution' TUCD26 = 'stat.error;phys.angSize.smajAxis;em.radio;askap:meta.deconvolved' TCOMM27 = 'FWHM minor axis after deconvolution' TUCD27 = 
'phys.angSize.sminAxis;em.radio;askap:meta.deconvolved' TCOMM28 = 'Error in minor axis after deconvolution' TUCD28 = 'stat.error;phys.angSize.sminAxis;em.radio;askap:meta.deconvolved' TCOMM29 = 'Position angle after deconvolution' TUCD29 = 'phys.angSize;pos.posAng;em.radio;askap:meta.deconvolved' TCOMM30 = 'Error in position angle after deconvolution' TUCD30 = 'stat.error;phys.angSize;pos.posAng;em.radio;askap:meta.deconvolved' TCOMM31 = 'Code indicating single(S), multiple(M) or complex(C) source' TUCD31 = 'meta.code.class' TCOMM32 = 'Angular separation between Gaussian component and tile centre' TUCD32 = 'pos.angDistance' TCOMM33 = 'The rms noise within island boundary from PyBDSF Isl_rms column' TUCD33 = 'stat.stdev;phot.flux.density' TCOMM34 = 'Galactic longitude' TUCD34 = 'pos.galactic.lon' TCOMM35 = 'Galactic latitude' TUCD35 = 'pos.galactic.lat' DATE-HDU= '2022-12-19T02:47:57' STILVERS= '4.0-4 ' STILCLAS= 'uk.ac.starlink.votable.FitsPlusTableWriter' END RACS_1237+12A_3630RACS_1237+12A_3165RACS_1237+12A5@2o5@gv@(zp?Q?(\)@MbM@|G{@?j~AɉnP@@ɓnP@GFffff@glC@C\(\?\(@:Q?У +=p@/8Q?zG@=R?\(@(\)?У +=p@/8Q?zGM?E@fmhr @qYW!@R;RACS_1237+12A_3631RACS_1237+12A_3165RACS_1237+12A5@2o5@gvj@(g†?Q?@`A@z@AɉnP@@ɓnP@%@glC@@fffff?333333@<zG?Q@WS33333?\(@4 +=p +?333333@)?Q@WS33333?\(M?/w@fmhr @q5?|@Rie'_RACS_1237+12A_3632RACS_1237+12A_3165RACS_1237+12A5@2o5@gvwȎz@(Fs?Q?Q@?VQ@;dZ@@JoAɉnP@@ɓnP@ج1'@hQ\(@C\)?\(@@p +=? +=p@Yfffff?(\)@>\(?\(@6 +=p? +=p@Yfffff?(\)M?n@fmhr @q`c'@RRiY_ \ No newline at end of file diff --git a/tests/data/source_RACS_1237+12A_3165.fits b/tests/data/source_RACS_1237+12A_3165.fits new file mode 100644 index 00000000..a6af3a55 --- /dev/null +++ b/tests/data/source_RACS_1237+12A_3165.fits @@ -0,0 +1,4 @@ +SIMPLE = T / conforms to FITS standard BITPIX = 8 / array data type NAXIS = 0 / number of array dimensions EXTEND = T END XTENSION= 'BINTABLE' / binary table extension BITPIX = 8 / array data type NAXIS = 2 / number of array dimensions NAXIS1 = 287 / length of dimension 1 NAXIS2 = 1 / length of dimension 2 PCOUNT = 0 / number of group parameters GCOUNT = 1 / number of groups TFIELDS = 33 / number of table fields TTYPE1 = 'Source_Name' TFORM1 = '28A ' TTYPE2 = 'Source_ID' TFORM2 = '19A ' TTYPE3 = 'Tile_ID ' TFORM3 = '13A ' TTYPE4 = 'SBID ' TFORM4 = 'K ' TNULL4 = 999999 TTYPE5 = 'Obs_Start_Time' TFORM5 = 'D ' TTYPE6 = 'N_Gaus ' TFORM6 = 'K ' TNULL6 = 999999 TTYPE7 = 'RA ' TFORM7 = 'D ' TUNIT7 = 'deg ' TTYPE8 = 'Dec ' TFORM8 = 'D ' TUNIT8 = 'deg ' TTYPE9 = 'E_RA ' TFORM9 = 'D ' TUNIT9 = 'arcsec ' TTYPE10 = 'E_Dec ' TFORM10 = 'D ' TUNIT10 = 'arcsec ' TTYPE11 = 'Total_flux_Source' TFORM11 = 'D ' TUNIT11 = 'mJy ' TTYPE12 = 'E_Total_flux_Source_PyBDSF' TFORM12 = 'D ' TUNIT12 = 'mJy ' TTYPE13 = 'E_Total_flux_Source' TFORM13 = 'D ' TUNIT13 = 'mJy ' TTYPE14 = 'Peak_flux' TFORM14 = 'D ' TUNIT14 = 'beam-1 mJy' TTYPE15 = 'E_Peak_flux' TFORM15 = 'D ' TUNIT15 = 'beam-1 mJy' TTYPE16 = 'Maj ' TFORM16 = 'D ' TUNIT16 = 'arcsec ' TTYPE17 = 'E_Maj ' TFORM17 = 'D ' TUNIT17 = 'arcsec ' TTYPE18 = 'Min ' TFORM18 = 'D ' TUNIT18 = 'arcsec ' TTYPE19 = 'E_Min ' TFORM19 = 'D ' TUNIT19 = 'arcsec ' TTYPE20 = 'PA ' TFORM20 = 'D ' TUNIT20 = 'deg ' TTYPE21 = 'E_PA ' TFORM21 = 'D ' TUNIT21 = 'deg ' TTYPE22 = 'DC_Maj ' TFORM22 = 'D ' TUNIT22 = 'arcsec ' TTYPE23 = 'E_DC_Maj' TFORM23 = 'D ' TUNIT23 = 'arcsec ' TTYPE24 = 'DC_Min ' TFORM24 = 'D ' TUNIT24 = 'arcsec ' TTYPE25 = 'E_DC_Min' TFORM25 = 'D ' TUNIT25 = 'arcsec ' TTYPE26 = 'DC_PA ' TFORM26 = 'D ' TUNIT26 = 'deg ' TTYPE27 = 'E_DC_PA ' TFORM27 = 'D ' TUNIT27 = 
'deg ' TTYPE28 = 'S_Code ' TFORM28 = '1A ' TTYPE29 = 'Separation_Tile_Centre' TFORM29 = 'D ' TUNIT29 = 'deg ' TTYPE30 = 'Noise ' TFORM30 = 'D ' TUNIT30 = 'beam-1 mJy' TTYPE31 = 'Gal_lon ' TFORM31 = 'D ' TUNIT31 = 'deg ' TTYPE32 = 'Gal_lat ' TFORM32 = 'D ' TUNIT32 = 'deg ' TTYPE33 = 'Flag_Close' TFORM33 = '2A ' ID = 'RACS_DR1_Sources_GalacticCut_v2021_08' NAME = 'RACS_DR1_Sources_GalacticCut_v2021_08' END RACS-DR1 J123049.4+122322RACS_1237+12A_3165RACS_1237+12A5@2o5@gv9@(s} ??zG{AɉnP@@ɓnP@?|@fmhr @IGz?ə@@b\(?@]fffff?=p +=@F`?ə@50 +=q?@]fffff?=p +=M?K]c@fmhr @qzKUh@RW$- \ No newline at end of file diff --git a/tests/function_tests.py b/tests/function_tests.py deleted file mode 100644 index 906311d4..00000000 --- a/tests/function_tests.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Tests for functions.""" - -import unittest - - -class UnitTests(unittest.TestCase): - def test_somthing(self): - self.assertEqual(0,0) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file diff --git a/tests/init_test.py b/tests/init_test.py new file mode 100644 index 00000000..59ce9097 --- /dev/null +++ b/tests/init_test.py @@ -0,0 +1,98 @@ +# """Tests for SPICE init.""" + +# import unittest +# import subprocess as sp +# import socket +# import time +# from typing import Tuple, Union +# import os +# import shutil + +# import pymongo +# from spiceracs.logger import logger + +# logger.setLevel("DEBUG") + +# from spiceracs.init_database import ( +# main, +# source2beams, +# ndix_unique, +# cat2beams, +# source_database, +# beam_database, +# get_catalogue, +# get_beams, +# field_database, +# ) + +# SOURCE_CAT = "data/source_RACS_1237+12A_3165.fits" +# GAUSS = "data/gaussians_RACS_1237+12A_3165.fits" +# EPOCH = "data/spiceracs_test" +# FIELD = "data/RACS_1237+12A" + +# def start_mongodb(port: int = 27017) -> Tuple[str, int]: +# """Start a local MongoDB instance.""" +# return +# if os.path.exists("data/testdb"): +# shutil.rmtree("data/testdb") + +# os.makedirs("data/testdb") + +# cmd = f"mongod --dbpath data/testdb --port {port} --fork --auth --logpath data/testdb/mongodb.log" +# logger.debug("Starting mongo...") +# sp.run(cmd.split(), check=True) +# logger.debug("Mongo started. 
Sleeping for 5 seconds...") +# time.sleep(5) +# logger.debug("Mongo should be ready.") +# host = socket.gethostbyname(socket.gethostname()) +# return host, port + +# def create_mongo_admin( +# host: str, +# port: int=27017, +# username: str="admin", +# password: str="admin", +# ) -> None: + +# """Create an admin user in a local MongoDB instance.""" +# return +# client = pymongo.MongoClient(host, port) +# db = client.admin +# cmd = db.command("createUser", username, pwd=password, roles=["root"]) +# logger.debug(cmd) + + +# def stop_mongodb() -> None: +# """Stop a local MongoDB instance.""" +# return +# cmd = f"mongod --dbpath data/testdb --shutdown" +# logger.debug("Stopping mongo...") +# p = sp.Popen(cmd.split(),) +# logger.debug("Mongo stopped.") + +# class TestInit(unittest.TestCase): +# host, port = start_mongodb() +# create_mongo_admin(host, port) + + +# def test_main(self): +# main( +# load=True, +# islandcat=SOURCE_CAT, +# compcat=GAUSS, +# host=self.host, +# username="admin", +# password="admin", +# field=FIELD, +# epoch=EPOCH, +# force=True, +# ) +# assert True + + +# logger.debug("BOOP") + +# stop_mongodb() + +# if __name__ == '__main__': +# unittest.main() diff --git a/tests/unit_test.py b/tests/unit_test.py new file mode 100644 index 00000000..d1d1870a --- /dev/null +++ b/tests/unit_test.py @@ -0,0 +1,40 @@ +"""Tests for functions.""" + +import unittest + +from spiceracs.rmsynth_oncuts import ( + rmsynthoncut3d, + rms_1d, + estimate_noise_annulus, + rmsynthoncut1d, + rmsynthoncut_i, +) + +# Test functions within spiceracs.rmsyth_oncuts + +class TestRmsynthOncuts(unittest.TestCase): + """Test rmsynth_oncuts functions.""" + + def test_rmsynthoncut3d(self): + """Test rmsynthoncut3d.""" + pass + + def test_rms_1d(self): + """Test rms_1d.""" + pass + + def test_estimate_noise_annulus(self): + """Test estimate_noise_annulus.""" + pass + + def test_rmsynthoncut1d(self): + """Test rmsynthoncut1d.""" + pass + + def test_rmsynthoncut_i(self): + """Test rmsynthoncut_i.""" + pass + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file
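The placeholder tests above could be fleshed out along these lines. As a minimal sketch, here is a concrete test for combinate() from spiceracs.makecat, whose implementation appears earlier in this changeset and which should return every unique pair of its input:

import unittest

import numpy as np

from spiceracs.makecat import combinate


class TestCombinate(unittest.TestCase):
    def test_combinate_pairs(self):
        data = np.array([1, 2, 3])
        dx, dy = combinate(data)
        # Unique unordered pairs of [1, 2, 3]: (1, 2), (1, 3), (2, 3)
        np.testing.assert_array_equal(dx, np.array([1, 1, 2]))
        np.testing.assert_array_equal(dy, np.array([2, 3, 3]))


if __name__ == "__main__":
    unittest.main()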