Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable MPI to run on arbitrary number of baselines #16

Merged
merged 6 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,7 @@ dmypy.json

# Pyre type checker
.pyre/

.idea/
results*/
test-data/
5 changes: 1 addition & 4 deletions hydra_pspec/pspec.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import numpy as np
import scipy as sp
from scipy.stats import mode
from scipy.signal.windows import blackmanharris as BH
from scipy.stats import invgamma
from scipy.optimize import minimize, Bounds

from multiprocess import Pool, current_process
from . import utils
import os, time
import time


def sample_S(s=None, sk=None, prior=None):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "hydra-pspec"
version = "0.0.0"
dependencies = ["numpy", "scipy", "multiprocess", "pyuvdata", "astropy", "jsonargparse", "matplotlib"]
dependencies = ["numpy", "scipy", "multiprocess", "pyuvdata", "astropy", "jsonargparse", "matplotlib", "mpi4py"]

[tool.setuptools]
py-modules = ['dpss', 'utils', 'oqe', 'pspec', 'lssa']
149 changes: 85 additions & 64 deletions run-hydra-pspec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@

import numpy as np
import scipy
import scipy.special
from pathlib import Path
from pprint import pprint
Expand All @@ -12,7 +11,6 @@
from jsonargparse import ArgumentParser, ActionConfigFile
from jsonargparse.typing import Path_fr, Path_dw
from pyuvdata import UVData
from astropy import units
from astropy.units import Quantity

import hydra_pspec as hp
Expand Down Expand Up @@ -266,6 +264,27 @@ def check_load_path(fp):
data = np.load(fp)
return fp_is_dir, data

def split_data_for_scatter(data: list, n_ranks: int) -> list:
"""Split a list into a list of lists for MPI scattering"""
data_length = len(data)
quot, rem = divmod(data_length, n_ranks)

if quot == 0:
print(f"Error: Number of baselines ({data_length}) should be >= number of MPI ranks ({size})!")
sys.stdout.flush()
comm.Abort()

# determine the size of each sub-task
counts = [quot + 1 if n < rem else quot for n in range(n_ranks)]

# determine the starting and ending indices of each sub-task
starts = [sum(counts[:n]) for n in range(n_ranks)]
ends = [sum(counts[:n + 1]) for n in range(n_ranks)]

# converts data into a list of arrays
scatter_data = [data[starts[n]:ends[n]] for n in range(n_ranks)]
return scatter_data


if rank == 0:
if "config" in args.__dict__:
Expand Down Expand Up @@ -446,70 +465,72 @@ def check_load_path(fp):
if args.noise_cov:
bl_data_weights["N"] = noise_cov
all_data_weights.append(bl_data_weights)
all_data_weights = split_data_for_scatter(all_data_weights, size)
else:
all_data_weights = None

# Send per-baseline visibilities to each process
data = comm.scatter(all_data_weights)
antpair = data["antpair"]
d = data["d"]
w = ~data["w"]
fgmodes = data["fgmodes"]
S_initial = data["S_initial"]
Ninv = data["Ninv"]

# Create a subdirectory in out_dir for each baseline
out_dir = data["out_dir"]
bl_str = f"{antpair[0]}-{antpair[1]}"
out_dir /= bl_str
out_dir.mkdir(exist_ok=True, parents=True)

# Power spectrum prior
# This has shape (2, Ndelays). The first dimension is for the upper and
# lower prior bounds respectively. If the prior for a given delay is
# set to zero, no prior is applied. Otherwise, the solution is restricted
# to be within the range ps_prior[1] < soln < ps_prior[0].
Nfreqs = d.shape[1]
ps_prior = np.zeros((2, Nfreqs))
if args.ps_prior_lo != 0 or args.ps_prior_hi != 0:
ps_prior_inds = slice(
Nfreqs//2 - args.n_ps_prior_bins,
Nfreqs//2 + args.n_ps_prior_bins + 1
)
ps_prior[0, ps_prior_inds] = args.ps_prior_hi
ps_prior[1, ps_prior_inds] = args.ps_prior_lo
list_of_baselines = comm.scatter(all_data_weights)
for data in list_of_baselines:
antpair = data["antpair"]
d = data["d"]
w = ~data["w"]
fgmodes = data["fgmodes"]
S_initial = data["S_initial"]
Ninv = data["Ninv"]

# Create a subdirectory in out_dir for each baseline
out_dir = data["out_dir"]
bl_str = f"{antpair[0]}-{antpair[1]}"
out_dir /= bl_str
out_dir.mkdir(exist_ok=True, parents=True)

if rank == 0:
verbose = args.verbose
else:
verbose = False
if verbose:
print("Printing status messages for:")
print(f"Rank: {rank}")
print(f"Baseline: {antpair}", end="\n\n")

# Run Gibbs sampler
# signal_cr = (Niter, Ntimes, Nfreqs) [complex]
# signal_S = (Nfreqs, Nfreqs) [complex]
# signal_ps = (Niter, Nfreqs) [float]
# fg_amps = (Niter, Ntimes, Nfgmodes) [complex]
start = time.time()
signal_cr, signal_S, signal_ps, fg_amps, chisq, ln_post = \
hp.pspec.gibbs_sample_with_fg(
d,
w[0], # FIXME: add functionality for time-dependent flags
S_initial,
fgmodes,
Ninv,
ps_prior,
Niter=args.Niter,
seed=args.seed,
map_estimate=args.map_estimate,
verbose=verbose,
nproc=args.Nproc,
write_Niter=args.write_Niter,
out_dir=out_dir
)
print(f"Sampling complete!", end="\n\n")
elapsed = time.time() - start
print(f"Time elapsed: {elapsed} s")
# Power spectrum prior
# This has shape (2, Ndelays). The first dimension is for the upper and
# lower prior bounds respectively. If the prior for a given delay is
# set to zero, no prior is applied. Otherwise, the solution is restricted
# to be within the range ps_prior[1] < soln < ps_prior[0].
Nfreqs = d.shape[1]
ps_prior = np.zeros((2, Nfreqs))
if args.ps_prior_lo != 0 or args.ps_prior_hi != 0:
ps_prior_inds = slice(
Nfreqs//2 - args.n_ps_prior_bins,
Nfreqs//2 + args.n_ps_prior_bins + 1
)
ps_prior[0, ps_prior_inds] = args.ps_prior_hi
ps_prior[1, ps_prior_inds] = args.ps_prior_lo

if rank == 0:
verbose = args.verbose
else:
verbose = False
if verbose:
print("Printing status messages for:")
print(f"Rank: {rank}")
print(f"Baseline: {antpair}", end="\n\n")

# Run Gibbs sampler
# signal_cr = (Niter, Ntimes, Nfreqs) [complex]
# signal_S = (Nfreqs, Nfreqs) [complex]
# signal_ps = (Niter, Nfreqs) [float]
# fg_amps = (Niter, Ntimes, Nfgmodes) [complex]
start = time.time()
signal_cr, signal_S, signal_ps, fg_amps, chisq, ln_post = \
hp.pspec.gibbs_sample_with_fg(
d,
w[0], # FIXME: add functionality for time-dependent flags
S_initial,
fgmodes,
Ninv,
ps_prior,
Niter=args.Niter,
seed=args.seed,
map_estimate=args.map_estimate,
verbose=verbose,
nproc=args.Nproc,
write_Niter=args.write_Niter,
out_dir=out_dir
)
print(f"Sampling complete!", end="\n\n")
elapsed = time.time() - start
print(f"Time elapsed: {elapsed} s")