Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding AU Mic demo #14

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 40 additions & 22 deletions fleck/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def rotation_model(self, f0=0, t0_rot=0, u1=0, u2=0):
spot_model : array
Flux as a function of time and wavelength
"""
u1 = jnp.atleast_1d(u1)
u2 = jnp.atleast_1d(u2)
u_ld = jnp.column_stack([u1, u2])

(
spot_position_x, spot_position_y, spot_position_z,
major_axis, minor_axis, angle, rad, contrast
Expand All @@ -135,22 +139,32 @@ def rotation_model(self, f0=0, t0_rot=0, u1=0, u2=0):
)

radial_coord = 1 - jnp.geomspace(1e-5, 1, 100)[::-1]

unspotted_total_flux = trapezoid(
y=(
2 * np.pi * radial_coord *
self.limb_darkening(radial_coord, u1, u2)
),
2 * np.pi * radial_coord[:, None] *
self.limb_darkening(
radial_coord[:, None], *u_ld.T
)
).T,
x=radial_coord
)

limb_dark = self.limb_darkening(
mu,
u1=u1[None, None, :, None],
u2=u2[None, None, :, None]
)

# Morris 2020 Eqn 6-7
spot_model = f0 - jnp.sum(
np.pi * rad ** 2 *
(1 - contrast) *
self.limb_darkening(mu, u1, u2) *
mask_behind_star,
limb_dark *
mask_behind_star /
unspotted_total_flux[None, None, :, None],
axis=1
) / unspotted_total_flux
)
f_S = rad ** 2 * mu * (spot_position_z < 0).astype(int)

return spot_model, f_S
Expand Down Expand Up @@ -272,24 +286,22 @@ def add_spot(self, lon, lat, rad, contrast=None, temperature=None, spectrum=None
grid is ``ActiveStar.phot``
"""
if contrast is None and spectrum is None and temperature is not None:
self.phot = self._blackbody(self.wavelength, self.T_eff)
if self.phot is None:
self.phot = self._blackbody(self.wavelength, self.T_eff)
spectrum = self._blackbody(self.wavelength, temperature)

for attr, new_value in zip("lon, lat, rad, spectrum, temperature".split(', '),
[lon, lat, rad, spectrum, temperature]):

prop = getattr(self, attr)

if not hasattr(new_value, 'ndim'):
new_value = jnp.array([new_value])
new_value = jnp.atleast_1d(new_value)

if prop is not None:
if prop.ndim > 1 or (len(prop) > 1 and len(prop) == len(new_value)):
if attr == 'spectrum':
if len(prop):
new_value = jnp.vstack([prop, new_value])
else:
new_value = jnp.concatenate([prop, new_value])

setattr(self, attr, new_value)
elif len(prop.shape) and len(new_value.shape):
new_value = jnp.concatenate([prop, new_value])
setattr(self, attr, new_value)

@jit
def _blackbody(self, wavelength_meters, temperature):
Expand Down Expand Up @@ -439,11 +451,17 @@ def transit_model(self, t0, period, rp, a, inclination,
jnp.sum(spot_coverages * spot_spectra, axis=1)
)

# if rp is scalar, turn it into a vector with length == N_wavelengths:
rp = rp * jnp.ones(u1.shape[0])

transit = vmap(
lambda u_ld: jaxoplanet.core.light_curve(
u1=u_ld[0], u2=u_ld[1], b=jnp.hypot(X, Y), r=rp
lambda u_ld, rp: jaxoplanet.core.light_curve(
u1=u_ld[0],
u2=u_ld[1],
b=jnp.hypot(X, Y),
r=rp
), in_axes=0, out_axes=1
)(u_ld)
)(u_ld, rp)

contaminated_transit = (
time_series_spectrum - jnp.abs(transit) * self.phot[None, :]
Expand All @@ -463,15 +481,15 @@ def transit_model(self, t0, period, rp, a, inclination,
spot_position_x - Y[:, None, None, None]
)
occultation_possible = jnp.squeeze(
(planet_spot_distance < (major_axis + rp)) &
(planet_spot_distance < (major_axis + rp.mean())) &
(spot_position_z < 0)
)

@jit
def time_step(
carry, j, X=X, Y=Y, spot_position_y=spot_position_y,
spot_position_x=spot_position_x, major_axis=major_axis,
minor_axis=minor_axis, rp=rp, angle=angle,
minor_axis=minor_axis, rp=rp.mean(), angle=angle,
occultation_possible=occultation_possible
):
return carry, lax.cond(
Expand Down Expand Up @@ -509,7 +527,7 @@ def time_step(

return (
out_of_transit[..., 0] * (contaminated_transit + scaled_occultation),
apparent_rprs2, X, Y,
apparent_rprs2, out_of_transit[..., 0], (contaminated_transit + scaled_occultation),
spectrum_at_transit
)

Expand Down
Loading
Loading