Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parametric input in single CNN model #163

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion ctlearn/default_models/cnn_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def cnn_rnn_model(data, model_params):
if pretrained_weights:
loaded_model = tf.keras.models.load_model(pretrained_weights)
for layer in loaded_model.layers:
if layer.name.endswith("_block"):
if layer.name.endswith("_backbone"):
model = loaded_model.get_layer(layer.name)
model.trainable = trainable_backbone
else:
Expand Down
52 changes: 0 additions & 52 deletions ctlearn/default_models/single_cnn.py

This file was deleted.

79 changes: 79 additions & 0 deletions ctlearn/default_models/single_nn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import importlib
import sys

import tensorflow as tf


def single_nn_model(data, model_params):

# Load neural network model
network_input, network_output = [], []
if data.img_shape != None:
network_input_img = tf.keras.Input(shape=data.img_shape, name=f"images")
network_input.append(network_input_img)

if data.prm_shape != None and data.mode == "train":
network_input_prm = tf.keras.Input(shape=data.prm_shape, name=f"parameters")
network_input.append(network_input_prm)

backbone_name = model_params.get("name", "SingleNN") + "_backbone"
trainable_backbone = model_params.get("trainable_backbone", True)
pretrained_weights = model_params.get("pretrained_weights", None)
if pretrained_weights:
loaded_model = tf.keras.models.load_model(pretrained_weights)
for layer in loaded_model.layers:
if "_backbone" in layer.name:
singlenn_model = loaded_model.get_layer(layer.name)
singlenn_model.trainable = trainable_backbone
else:
sys.path.append(model_params["model_directory"])
if data.img_shape != None:
engine_cnn_module = importlib.import_module(
model_params["engine_cnn"]["module"]
)
engine_cnn = getattr(
engine_cnn_module, model_params["engine_cnn"]["function"]
)
# The original ResNet implementation use this padding, but we pad the images in the ImageMapper.
# x = tf.pad(telescope_data, tf.constant([[3, 3], [3, 3]]), name='conv1_pad')
init_layer = model_params.get("init_layer", False)
if init_layer:
network_input_img = tf.keras.layers.Conv2D(
filters=init_layer["filters"],
kernel_size=init_layer["kernel_size"],
strides=init_layer["strides"],
name=backbone_name + "_conv1_conv",
)(network_input_img)
# x = tf.pad(x, tf.constant([[1, 1], [1, 1]]), name='pool1_pad')
init_max_pool = model_params.get("init_max_pool", False)
if init_max_pool:
network_input_img = tf.keras.layers.MaxPool2D(
pool_size=init_max_pool["size"],
strides=init_max_pool["strides"],
name=backbone_name + "_pool1_pool",
)(network_input_img)
engine_output_cnn = engine_cnn(
network_input_img, params=model_params, name=backbone_name
)
network_output = output_cnn = tf.keras.layers.GlobalAveragePooling2D(
name=backbone_name + "_global_avgpool"
)(engine_output_cnn)

if data.prm_shape != None and data.mode == "train":
engine_mlp_module = importlib.import_module(
model_params["engine_mlp"]["module"]
)
engine_mlp = getattr(
engine_mlp_module, model_params["engine_mlp"]["function"]
)
engine_output_mlp = engine_mlp(
network_input_prm, params=model_params, name=backbone_name
)
network_output = output_mlp = tf.keras.layers.Flatten()(engine_output_mlp)
if data.img_shape != None:
network_output = tf.keras.layers.Concatenate()([output_cnn, output_mlp])

singlenn_model = tf.keras.Model(
network_input, network_output, name=backbone_name
)
return singlenn_model, network_input