Skip to content

Commit

Permalink
Move packet seeds logic into PacketSource instead of MontecarloTransp…
Browse files Browse the repository at this point in the history
…ort (#2329)

* moved seed and rng to packet_source.py from montecarlotransport

* Remove single_packet_seed usages

* Fix documentation test case

* Fix argument placement

* Make reseed private

* Pass iteration as seed_offset from MontecarloTransport. Don't store an internal counter in packet_source

* Handle dict and hdf conversion

* Handle legacy test case in test_bb_packet_sampling

* Handle legacy test case by adding secondary seed inside packet source

* Changed parameter name from `secondary_seed` to `legacy_second_seed`

* Changed parameter name from `secondary_seed` to `legacy_second_seed`

* Rename seed param to base_seed
  • Loading branch information
xansh authored Jun 20, 2023
1 parent 1cc49a2 commit b5738b3
Show file tree
Hide file tree
Showing 8 changed files with 111 additions and 89 deletions.
16 changes: 8 additions & 8 deletions docs/io/optional/custom_source.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@
" with a truncated Blackbody source.\n",
" \"\"\"\n",
" \n",
" def __init__(self, seed, truncation_wavelength):\n",
" super().__init__(seed)\n",
" self.rng = np.random.default_rng(seed=seed)\n",
" def __init__(self, base_seed, truncation_wavelength):\n",
" super().__init__(base_seed)\n",
" self.rng = np.random.default_rng(seed=base_seed)\n",
" self.truncation_wavelength = truncation_wavelength\n",
" \n",
" def create_packets(self, T, no_of_packets, rng, radius,\n",
" def create_packets(self, T, no_of_packets, radius,\n",
" drawing_sample_size=None):\n",
" \"\"\"\n",
" Packet source that generates a truncated Blackbody source.\n",
Expand All @@ -91,8 +91,8 @@
" radii = np.ones(no_of_packets) * radius\n",
"\n",
" # Use mus and energies from normal blackbody source.\n",
" mus = self.create_zero_limb_darkening_packet_mus(no_of_packets, self.rng)\n",
" energies = self.create_uniform_packet_energies(no_of_packets, self.rng)\n",
" mus = self.create_zero_limb_darkening_packet_mus(no_of_packets)\n",
" energies = self.create_uniform_packet_energies(no_of_packets)\n",
"\n",
" # If not specified, draw 2 times as many packets and reject any beyond no_of_packets.\n",
" if drawing_sample_size is None:\n",
Expand All @@ -104,15 +104,15 @@
" \n",
" # Draw nus from blackbody distribution and reject based on truncation_frequency.\n",
" # If more nus.shape[0] > no_of_packets use only the first no_of_packets.\n",
" nus = self.create_blackbody_packet_nus(T, drawing_sample_size, self.rng)\n",
" nus = self.create_blackbody_packet_nus(T, drawing_sample_size)\n",
" nus = nus[nus<truncation_frequency][:no_of_packets]\n",
" \n",
" \n",
" # Only required if the truncation wavelength is too big compared to the maximum \n",
" # of the blackbody distribution. Keep sampling until nus.shape[0] > no_of_packets.\n",
" while nus.shape[0] < no_of_packets:\n",
" additional_nus = self.create_blackbody_packet_nus(\n",
" T, drawing_sample_size, self.rng\n",
" T, drawing_sample_size\n",
" )\n",
" mask = additional_nus < truncation_frequency\n",
" additional_nus = additional_nus[mask][:no_of_packets]\n",
Expand Down
10 changes: 9 additions & 1 deletion docs/physics/montecarlo/initialization.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "6c0dbe0a",
"metadata": {},
Expand Down Expand Up @@ -87,6 +88,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "4ae02998",
"metadata": {},
Expand All @@ -112,6 +114,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "450faf76",
"metadata": {},
Expand All @@ -138,6 +141,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "516633c5",
"metadata": {},
Expand All @@ -158,7 +162,6 @@
"radii, nus, mus, energies = packet_source.create_packets(\n",
" temperature_inner.value, \n",
" n_packets, \n",
" rng, \n",
" r_boundary_inner)\n",
"\n",
"# Sets the energies in units of ergs\n",
Expand All @@ -172,6 +175,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "16936bce",
"metadata": {},
Expand All @@ -196,6 +200,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "0d839222",
"metadata": {},
Expand Down Expand Up @@ -223,6 +228,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "78230177",
"metadata": {},
Expand Down Expand Up @@ -257,6 +263,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "ad4f0f0e",
"metadata": {},
Expand Down Expand Up @@ -297,6 +304,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2f661592",
"metadata": {},
Expand Down
24 changes: 15 additions & 9 deletions tardis/io/model_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

from tardis.io.config_reader import ConfigurationNameSpace
from tardis.montecarlo.base import MontecarloTransport
from tardis.montecarlo.packet_source import (
BlackBodySimpleSource,
BlackBodySimpleSourceRelativistic,
)
from tardis.util.base import parse_quantity, is_valid_nuclide_or_elem

import warnings
Expand Down Expand Up @@ -578,7 +582,7 @@ def transport_to_dict(transport):
"photo_ion_estimator_statistics": transport.photo_ion_estimator_statistics,
"r_inner": transport.r_inner_cgs,
"r_outer": transport.r_outer_cgs,
"seed": transport.seed,
"packet_source_base_seed": transport.packet_source.base_seed,
"spectrum_frequency_cgs": transport.spectrum_frequency,
"spectrum_method": transport.spectrum_method,
"stim_recomb_cooling_estimator": transport.stim_recomb_cooling_estimator,
Expand All @@ -600,11 +604,6 @@ def transport_to_dict(transport):
"volume_cgs": transport.volume,
}

try:
transport_dict["single_packet_seed"] = transport.single_packet_seed
except AttributeError:
transport_dict["single_packet_seed"] = None

for key, value in transport_dict.items():
if key.endswith("_cgs"):
transport_dict[key] = [value.cgs.value, value.unit.to_string()]
Expand Down Expand Up @@ -682,7 +681,7 @@ def transport_from_hdf(fname):
new_transport : tardis.montecarlo.MontecarloTransport
"""

d = {"single_packet_seed": None}
d = {}

# Loading data from hdf file
with h5py.File(fname, "r") as f:
Expand Down Expand Up @@ -713,6 +712,14 @@ def transport_from_hdf(fname):
if key.endswith("_cgs"):
d[key] = u.Quantity(value[0], unit=u.Unit(value[1].decode("utf-8")))

# Using packet source seed to packet source
if not d["enable_full_relativity"]:
d["packet_source"] = BlackBodySimpleSource(d["packet_source_base_seed"])
else:
d["packet_source"] = BlackBodySimpleSourceRelativistic(
d["packet_source_base_seed"]
)

# Converting virtual spectrum spawn range values to astropy quantities
vssr = d["virtual_spectrum_spawn_range"]
d["virtual_spectrum_spawn_range"] = {
Expand All @@ -729,7 +736,6 @@ def transport_from_hdf(fname):

# Creating a transport object and storing data
new_transport = MontecarloTransport(
seed=d["seed"],
spectrum_frequency=d["spectrum_frequency_cgs"],
virtual_spectrum_spawn_range=d["virtual_spectrum_spawn_range"],
disable_electron_scattering=d["disable_electron_scattering"],
Expand All @@ -740,9 +746,9 @@ def transport_from_hdf(fname):
integrator_settings=d["integrator_settings"],
v_packet_settings=d["v_packet_settings"],
spectrum_method=d["spectrum_method"],
packet_source=d["packet_source"],
nthreads=d["nthreads"],
virtual_packet_logging=d["virt_logging"],
single_packet_seed=d["single_packet_seed"],
use_gpu=d["use_gpu"],
)

Expand Down
19 changes: 7 additions & 12 deletions tardis/io/tests/test_model_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,16 @@ def test_transport_to_dict(simulation_verysimple):

# Check transport dictionary
for key, value in transport_dict.items():
if key == "single_packet_seed":
if value is None:
assert key not in transport_data.keys()
else:
assert value == transport_data[key]
elif isinstance(value, np.ndarray):
if isinstance(value, np.ndarray):
if key + "_cgs" in transport_data.keys():
assert np.array_equal(value, transport_data[key + "_cgs"])
else:
assert np.array_equal(value, transport_data[key])
elif isinstance(value, list):
assert np.array_equal(value[0], transport_data[key[:-4]].value)
assert value[1] == transport_data[key[:-4]].unit.to_string()
elif key == "packet_source_base_seed": # Check packet source base seed
assert value == transport_data["packet_source"].base_seed
else:
assert value == transport_data[key]

Expand Down Expand Up @@ -372,7 +369,10 @@ def test_store_transport_to_hdf(simulation_verysimple, tmp_path):
assert np.array_equal(
f["transport/r_outer"], transport_data["r_outer_cgs"]
)
assert f["transport/seed"][()] == transport_data["seed"]
assert (
f["transport/packet_source_base_seed"][()]
== transport_data["packet_source"].base_seed
)
assert np.array_equal(
f["transport/spectrum_frequency_cgs"],
transport_data["spectrum_frequency"].value,
Expand Down Expand Up @@ -442,8 +442,3 @@ def test_store_transport_to_hdf(simulation_verysimple, tmp_path):
assert np.array_equal(
f["transport/volume_cgs"], transport_data["volume"].value
)
if "transport/single_packet_seed" in f:
assert (
f["transport/single_packet_seed"][()]
== transport_data["single_packet_seed"]
)
50 changes: 21 additions & 29 deletions tardis/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,6 @@
logger = logging.getLogger(__name__)


MAX_SEED_VAL = 2**32 - 1

# MAX_SEED_VAL must be multiple orders of magnitude larger than no_of_packets;
# otherwise, each packet would not have its own seed. Here, we set the max
# seed val to the maximum allowed by numpy.
# TODO: refactor this into more parts
class MontecarloTransport(HDFWriterMixin):
"""
Expand Down Expand Up @@ -86,7 +81,6 @@ class MontecarloTransport(HDFWriterMixin):

def __init__(
self,
seed,
spectrum_frequency,
virtual_spectrum_spawn_range,
disable_electron_scattering,
Expand All @@ -97,24 +91,14 @@ def __init__(
integrator_settings,
v_packet_settings,
spectrum_method,
packet_source,
virtual_packet_logging,
nthreads=1,
packet_source=None,
debug_packets=False,
logger_buffer=1,
tracking_rpacket=False,
use_gpu=False,
):
self.seed = seed
if packet_source is None:
if not enable_full_relativity:
self.packet_source = source.BlackBodySimpleSource(seed)
else:
self.packet_source = source.BlackBodySimpleSourceRelativistic(
seed
)
else:
self.packet_source = packet_source
# inject different packets
self.disable_electron_scattering = disable_electron_scattering
self.spectrum_frequency = spectrum_frequency
Expand All @@ -127,7 +111,6 @@ def __init__(
self.integrator_settings = integrator_settings
self.v_packet_settings = v_packet_settings
self.spectrum_method = spectrum_method
self.seed = seed
self._integrator = None
self._spectrum_integrated = None
self.use_gpu = use_gpu
Expand All @@ -142,6 +125,8 @@ def __init__(
self.virt_packet_initial_rs = np.ones(2) * -1.0
self.virt_packet_initial_mus = np.ones(2) * -1.0

self.packet_source = packet_source

# Setting up the Tracking array for storing all the RPacketTracker instances
self.rpacket_tracker = None

Expand Down Expand Up @@ -214,24 +199,22 @@ def _initialize_geometry_arrays(self, model):
def _initialize_packets(
self, temperature, no_of_packets, iteration, radius, time_explosion
):
# the iteration is added each time to preserve randomness
# the iteration (passed as seed_offset) is added each time to preserve randomness
# across different simulations with the same temperature,
# for example. We seed the random module instead of the numpy module
# because we call random.sample, which references a different internal
# state than in the numpy.random module.
seed = self.seed + iteration
rng = np.random.default_rng(seed=seed)
seeds = rng.choice(MAX_SEED_VAL, no_of_packets, replace=True)
# for example.
mc_config_module.packet_seeds = self.packet_source.create_packet_seeds(
no_of_packets, iteration
)

if not self.enable_full_relativity:
radii, nus, mus, energies = self.packet_source.create_packets(
temperature, no_of_packets, rng, radius
temperature, no_of_packets, radius
)
else:
radii, nus, mus, energies = self.packet_source.create_packets(
temperature, no_of_packets, rng, radius, time_explosion
temperature, no_of_packets, radius, time_explosion
)

mc_config_module.packet_seeds = seeds
self.input_r = radii
self.input_nu = nus
self.input_mu = mus
Expand Down Expand Up @@ -686,8 +669,17 @@ def from_config(
config.montecarlo.tracking.initial_array_length
)

if packet_source is None:
if not config.montecarlo.enable_full_relativity:
packet_source = source.BlackBodySimpleSource(
config.montecarlo.seed
)
else:
packet_source = source.BlackBodySimpleSourceRelativistic(
config.montecarlo.seed
)

return cls(
seed=config.montecarlo.seed,
spectrum_frequency=spectrum_frequency,
virtual_spectrum_spawn_range=config.montecarlo.virtual_spectrum_spawn_range,
enable_reflective_inner_boundary=config.montecarlo.enable_reflective_inner_boundary,
Expand Down
3 changes: 1 addition & 2 deletions tardis/montecarlo/montecarlo_numba/numba_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,6 @@ def __init__(
self.photo_ion_estimator_statistics = photo_ion_estimator_statistics

def increment(self, other):

self.j_estimator += other.j_estimator
self.nu_bar_estimator += other.nu_bar_estimator
self.j_blue_estimator += other.j_blue_estimator
Expand Down Expand Up @@ -597,7 +596,7 @@ def configuration_initialize(transport, number_of_vpackets):
montecarlo_configuration.number_of_vpackets = number_of_vpackets
montecarlo_configuration.temporary_v_packet_bins = number_of_vpackets
montecarlo_configuration.full_relativity = transport.enable_full_relativity
montecarlo_configuration.montecarlo_seed = transport.seed
montecarlo_configuration.montecarlo_seed = transport.packet_source.base_seed
montecarlo_configuration.v_packet_spawn_start_frequency = (
transport.virtual_spectrum_spawn_range.end.to(
u.Hz, equivalencies=u.spectral()
Expand Down
Loading

0 comments on commit b5738b3

Please sign in to comment.