Skip to content

Commit

Permalink
code style updates to stellar_wind.py (#816)
Browse files Browse the repository at this point in the history
  • Loading branch information
rieder authored Apr 4, 2023
1 parent 62ba260 commit c6224cc
Showing 1 changed file with 73 additions and 51 deletions.
124 changes: 73 additions & 51 deletions src/amuse/ext/stellar_wind.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,49 +157,52 @@ def __init__(self, *args, **kwargs):
self._private.track_mechanical_energy = False
self._private.new_unset_lmech_particles = False

self._private.attribute_names = set(["lost_mass",
"wind_release_time",
"mu",
"mass",
"radius",
"age",
"temperature",
"luminosity",
"stellar_type",
"x",
"y",
"z",
"vx",
"vy",
"vz",
"wind_mass_loss_rate",
"initial_wind_velocity",
"terminal_wind_velocity",
"mass_loss_type",
])
self._private.defaults = dict(lost_mass=0 | units.MSun,
mass=0 | units.MSun,
radius=0 | units.RSun,
age=0 | units.Myr,
temperature=0 | units.K ,
luminosity = 0 | units.LSun,
stellar_type = 1 | units.stellar_type,
x=0 | units.m,
y=0 | units.m,
z=0 | units.m,
vx=0 | units.ms,
vy=0 | units.ms,
vz=0 | units.ms,
wind_mass_loss_rate=0 | units.MSun/units.yr,
initial_wind_velocity=0 | units.ms,
terminal_wind_velocity=0 | units.ms,
mechanical_energy=0 | units.J,
mass_loss_type="wind",
)
self._private.attribute_names = set(
[
"lost_mass",
"wind_release_time",
"mu",
"mass",
"radius",
"age",
"temperature",
"luminosity",
"stellar_type",
"x",
"y",
"z",
"vx",
"vy",
"vz",
"wind_mass_loss_rate",
"initial_wind_velocity",
"terminal_wind_velocity",
"mass_loss_type",
]
)
self._private.defaults = dict(
lost_mass=0 | units.MSun,
mass=0 | units.MSun,
radius=0 | units.RSun,
age=0 | units.Myr,
temperature=0 | units.K,
luminosity=0 | units.LSun,
stellar_type=1 | units.stellar_type,
x=0 | units.m,
y=0 | units.m,
z=0 | units.m,
vx=0 | units.ms,
vy=0 | units.ms,
vz=0 | units.ms,
wind_mass_loss_rate=0 | units.MSun/units.yr,
initial_wind_velocity=0 | units.ms,
terminal_wind_velocity=0 | units.ms,
mechanical_energy=0 | units.J,
mass_loss_type="wind",
)

self.set_global_mu()


def add_particles(self, particles, *args, **kwargs):
new_particles = super(StarsWithMassLoss, self).add_particles(
particles, *args, **kwargs)
Expand All @@ -221,7 +224,10 @@ def add_particles_to_store(self, keys, attributes=[], values=[]):
attributes.append(attr)

if attr == "wind_release_time":
value = self.collection_attributes.timestamp or 0 | units.yr
value = (
self.collection_attributes.timestamp
or 0 | units.yr
)
elif attr == "previous_age":
if "age" in attributes:
value = values[attributes.index("age")]
Expand All @@ -246,14 +252,17 @@ def add_particles_to_store(self, keys, attributes=[], values=[]):
if attributes[i] in good_attributes]
attributes = [a for a in attributes if a in good_attributes]

super(StarsWithMassLoss, self).add_particles_to_store(keys, attributes, values)
super(StarsWithMassLoss, self).add_particles_to_store(
keys, attributes, values)

def set_values_in_store(self, indices, attributes, list_of_values_to_set):
for attr in attributes:
if attr not in self._private.attribute_names:
raise AttributeError("You tried to set attribute '{0}'"
raise AttributeError(
"You tried to set attribute '{0}'"
" but this attribute is not accepted for this set."
.format(attr))
.format(attr)
)

# TODO
super(StarsWithMassLoss, self).set_values_in_store(
Expand Down Expand Up @@ -297,7 +306,9 @@ def evolve_mass_loss(self, time):

def track_mechanical_energy(self, track=True):
self._private.track_mechanical_energy = track
mech_attrs = set(["mechanical_energy", "previous_mechanical_luminosity"])
mech_attrs = set(
["mechanical_energy", "previous_mechanical_luminosity"]
)
if track:
self._private.attribute_names |= mech_attrs
else:
Expand Down Expand Up @@ -699,7 +710,10 @@ def scaling(self, star):
if star.acc_cutoff is not None:
denominator = denominator - 1./star.acc_cutoff

numerator = star.terminal_wind_velocity**2 - star.initial_wind_velocity**2
numerator = (
star.terminal_wind_velocity**2
- star.initial_wind_velocity**2
)
return 0.5 * numerator / denominator

def acceleration_from_radius(self, r, star):
Expand All @@ -719,7 +733,10 @@ def scaling(self, star):
if star.acc_cutoff is not None:
denominator = denominator - 1./star.acc_cutoff

numerator = star.terminal_wind_velocity**2 - star.initial_wind_velocity**2
numerator = (
star.terminal_wind_velocity**2
- star.initial_wind_velocity**2
)
return 0.5 * numerator / denominator

def fix_acc_start_cutoff(self, r, acc, star):
Expand Down Expand Up @@ -982,8 +999,9 @@ def acceleration(self, star, radii):
if star.acc_cutoff is not None:
i_acc = i_acc & (radii < star.acc_cutoff)


accelerations[i_acc] += self.acc_function.acceleration_from_radius(radii[i_acc], star)
accelerations[i_acc] += self.acc_function.acceleration_from_radius(
radii[i_acc], star
)

if self.compensate_pressure:
if self.staging_radius is not None:
Expand All @@ -996,14 +1014,18 @@ def acceleration(self, star, radii):
if self.compensate_gravity:
if star.grav_acc_cutoff is not None:
indices = radii < star.grav_acc_cutoff
accelerations[indices] += constants.G * star.mass / radii[indices]**2
accelerations[indices] += (
constants.G * star.mass / radii[indices]**2
)
else:
accelerations += constants.G * star.mass / radii**2

if self.staging_radius is not None:
i_stag = radii < star.radius * self.staging_radius
if i_stag.any():
accelerations[i_stag] += self.staging_accelerations(i_stag, radii[i_stag], star)
accelerations[i_stag] += (
self.staging_accelerations(i_stag, radii[i_stag], star)
)

return accelerations

Expand Down

0 comments on commit c6224cc

Please sign in to comment.