Skip to content

Commit

Permalink
Adds support for galsim rng
Browse files Browse the repository at this point in the history
  • Loading branch information
EiffL committed Apr 27, 2020
1 parent 94418fd commit 196f39a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
36 changes: 30 additions & 6 deletions galsim_hub/generative_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self, file_name=None):
self.module = None

self.quantities = []
self.random_variables = []
self.sample_req_params = {}
self.sample_opt_params = {}
self.sample_single_params = []
Expand All @@ -59,30 +60,53 @@ def __init__(self, file_name=None):
self.stamp_size = module.get_attached_message("stamp_size", tf.train.Int64List).value[0]
self.pixel_size = module.get_attached_message("pixel_size", tf.train.FloatList).value[0]
for k in module.get_input_info_dict():
# Check for random variables
if 'random_normal' in k:
self.random_variables.append(k)
continue
# Otherwise add the rest of the conditional variables to the input
self.quantities.append(k)
self.sample_req_params[k] = float

def sample(self, cat, noise=None, rng=None, x_interpolant=None, k_interpolant=None,
pad_factor=4, noise_pad_size=0, gsparams=None,
session_config=None):
pad_factor=4, noise_pad_size=0, gsparams=None, session_config=None):
"""
Samples galaxy images from the model
"""
# If we are sampling for the first time
if self.module is None:
self.module = hub.Module(self.file_name)

self.sess = tf.Session(session_config)
self.sess.run(tf.global_variables_initializer())

self.inputs = {}
for k in self.quantities:
for k in self.quantities+self.random_variables:
tensor_info = self.module.get_input_info_dict()[k]
self.inputs[k] = tf.placeholder(tensor_info.dtype, shape=[None], name=k)
self.inputs[k] = tf.placeholder(tensor_info.dtype, shape=tensor_info.get_shape(), name=k)

self.generated_images = self.module(self.inputs)

# Populate feed dictionary with input data
feed_dict={self.inputs[k]: cat[k] for k in self.quantities}

# If not provided, create a RNG
if rng is None:
rng = galsim.BaseDeviate(rng)
orig_rng = rng.duplicate()

# Look for requested random_variables
if 'random_normal' in self.random_variables:
# Draw a random normal from the galsim RNG
noise_shape = self.module.get_input_info_dict()['random_normal'].get_shape()
noise_shape = [len(cat)] + [noise_shape[i+1].value for i in range(len(noise_shape)-1)]
noise_array = np.empty(np.prod(noise_shape), dtype=float)
gd = galsim.random.GaussianDeviate(rng, sigma=1)
gd.generate(noise_array)
feed_dict[self.inputs['random_normal']] = noise_array.reshape(noise_shape).astype('float32')

# Run the graph
x = self.sess.run(self.generated_images,
feed_dict={self.inputs[k]: cat[k] for k in self.quantities })
x = self.sess.run(self.generated_images, feed_dict=feed_dict)

# Now, we build an InterpolatedImage for each of these
ims = []
Expand Down
24 changes: 24 additions & 0 deletions specifications.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Specifications for GalSim Hub Modules

GalSim Hub modules are based on [TensorFlow Hub](https://www.tensorflow.org/hub) models
with a set of specific inputs and attributes.

GalSim Hub expects the models to produce as an output an **unconvolved** light profile
as a postage stamp of a given size. This postage stamp will then be wrapped as an
InterpolatedImage object to be used within GalSim.

Inputs are optional, if some named inputs are declared in the module they are interpreted as input parameters for the GalSim light profile.

More presicely:
- **Inputs**: Optional tensors of size at most 1d, designated by a keyword.
The following keywords are reserved:
- `random_normal`: tensor. To ensure full reproducibility, we recommend extracting
all random number generation out of the TensorFlow Hub module. GalSim Hub will
automatically recognize the keyword `random_normal` and will use GalSim to
generate an appropriate random number.

- **Outputs**: A single default output is expected, in the form of a ([None, stamp_size, stamp_size]) float32 tensor.

- **Attributes**: The following module attributes are expected
- `stamp_size`: Size of the light profile postage stamp in pixels
- `pixel_size`: Pixel resolution of the model

0 comments on commit 196f39a

Please sign in to comment.