Skip to content

Commit

Permalink
Always return a single NNs model for all replicas, adjust weight getting and setting accordingly
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Jan 9, 2024
1 parent 6f79368 commit c650cf3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 34 deletions.
59 changes: 42 additions & 17 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
}

NN_PREFIX = "NN"
PREPROCESSING_PREFIX = "preprocessing_factor"
NN_LAYER = "NNs"
PREPROESSING_LAYER = "preprocessing_factor"

# Some keys need to work for everyone
for k, v in optimizers.items():
Expand Down Expand Up @@ -353,14 +354,12 @@ def get_replica_weights(self, i_replica):
dict
dictionary with the weights of the replica
"""
NN_weights = [
tf.Variable(w, name=w.name) for w in self.get_layer(f"{NN_PREFIX}_{i_replica}").weights
]
prepro_weights = [
tf.Variable(w, name=w.name)
for w in get_layer_replica_weights(self.get_layer(PREPROCESSING_PREFIX), i_replica)
]
weights = {NN_PREFIX: NN_weights, PREPROCESSING_PREFIX: prepro_weights}
weights = {}
for layer_type in [NN_LAYER, PREPROESSING_LAYER]:
weights[layer_type] = [
tf.Variable(w, name=w.name)
for w in get_layer_replica_weights(self.get_layer(layer_type), i_replica)
]

return weights

Expand All @@ -378,12 +377,10 @@ def set_replica_weights(self, weights, i_replica=0):
i_replica: int
the replica number to set, defaulting to 0
"""
self.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights[NN_PREFIX])
set_layer_replica_weights(
layer=self.get_layer(PREPROCESSING_PREFIX),
weights=weights[PREPROCESSING_PREFIX],
i_replica=i_replica,
)
for layer_type in [NN_LAYER, PREPROESSING_LAYER]:
set_layer_replica_weights(
layer=self.get_layer(layer_type), weights=weights[layer_type], i_replica=i_replica
)

def split_replicas(self):
"""
Expand Down Expand Up @@ -427,6 +424,25 @@ def load_identical_replicas(self, model_file):
self.set_replica_weights(weights, i_replica)


def stacked_single_replicas(layer):
    """
    Determine whether ``layer`` is a stack of single-replica submodels
    (this only ever happens for NN layers).
    Parameters
    ----------
    layer: MetaLayer
        the layer to check
    Returns
    -------
    bool
        True if the layer consists of stacked single replicas
    """
    if isinstance(layer, MetaModel):
        # A stacked-replica model contains a sublayer named "NN_0"
        sublayer_names = (sublayer.name for sublayer in layer.layers)
        return f"{NN_PREFIX}_0" in sublayer_names
    return False


def get_layer_replica_weights(layer, i_replica: int):
"""
Get the weights for the given single replica `i_replica`,
Expand All @@ -444,13 +460,18 @@ def get_layer_replica_weights(layer, i_replica: int):
weights: list
list of weights for the replica
"""
return [tf.Variable(w[i_replica : i_replica + 1], name=w.name) for w in layer.weights]
if stacked_single_replicas(layer):
weights = layer.get_layer(f"{NN_PREFIX}_{i_replica}").weights
else:
weights = [tf.Variable(w[i_replica : i_replica + 1], name=w.name) for w in layer.weights]

return weights


def set_layer_replica_weights(layer, weights, i_replica: int):
"""
Set the weights for the given single replica `i_replica`,
from a `layer` that has weights for all replicas.
for a `layer` that has weights for all replicas.
Parameters
----------
Expand All @@ -461,6 +482,10 @@ def set_layer_replica_weights(layer, weights, i_replica: int):
i_replica: int
the replica number
"""
if stacked_single_replicas(layer):
layer.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights)
return

full_weights = [w.numpy() for w in layer.weights]
for w_old, w_new in zip(full_weights, weights):
w_old[i_replica : i_replica + 1] = w_new
Expand Down
31 changes: 15 additions & 16 deletions n3fit/src/n3fit/model_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,15 +596,11 @@ def pdfNN_layer_generator(

# Apply NN layers for all replicas to a given input grid
def neural_network_replicas(x, postfix=""):
NNs_x = Lambda(lambda nns: op.stack(nns, axis=1), name=f"NNs{postfix}")(
[nn(x) for nn in nn_replicas]
)
NNs_x = nn_replicas(x)

if subtract_one:
x_eq_1_processed = process_input(layer_x_eq_1)
NNs_x_1 = Lambda(lambda nns: op.stack(nns, axis=1), name=f"NNs{postfix}_x_1")(
[nn(x_eq_1_processed) for nn in nn_replicas]
)
NNs_x_1 = nn_replicas(x_eq_1_processed)
NNs_x = subtract_one_layer([NNs_x, NNs_x_1])

return NNs_x
Expand Down Expand Up @@ -660,11 +656,10 @@ def compute_unnormalized_pdf(x, postfix=""):
if photons:
PDFs = layer_photon(PDFs)

if replica_axis:
pdf_model = MetaModel(model_input, PDFs, name=f"PDFs", scaler=scaler)
else:
pdf_model = MetaModel(model_input, PDFs[:, 0], name=f"PDFs", scaler=scaler)
if not replica_axis:
PDFs = Lambda(lambda pdfs: pdfs[:, 0], name="remove_replica_axis")(PDFs)

pdf_model = MetaModel(model_input, PDFs, name=f"PDFs", scaler=scaler)
return pdf_model


Expand Down Expand Up @@ -709,8 +704,8 @@ def generate_nn(
Returns
-------
nn_replicas: List[MetaModel]
List of MetaModel objects, one for each replica.
nn_replicas: MetaModel
Single model containing all replicas.
"""
nodes_list = list(nodes) # so we can modify it
x_input = Input(shape=(None, nodes_in), batch_size=1, name='xgrids_processed')
Expand All @@ -734,7 +729,7 @@ def initializer_generator(seed, i_layer):
]
return initializers

elif layer_type == "dense":
else: # "dense"
reg = regularizer_selector(regularizer, **regularizer_args)
custom_args['regularizer'] = reg

Expand Down Expand Up @@ -772,16 +767,20 @@ def initializer_generator(seed, i_layer):

# Apply all layers to the input to create the models
pdfs = [layer(x_input) for layer in list_of_pdf_layers[0]]

for layers in list_of_pdf_layers[1:]:
# Since some layers (dropout) are shared, we have to treat them separately
if type(layers) is list:
pdfs = [layer(x) for layer, x in zip(layers, pdfs)]
else:
pdfs = [layers(x) for x in pdfs]

models = [
MetaModel({'NN_input': x_input}, pdf, name=f"NN_{i_replica}")
# Wrap the pdfs in a MetaModel to enable getting/setting of weights later
pdfs = [
MetaModel({'NN_input': x_input}, pdf, name=f"NN_{i_replica}")(x_input)
for i_replica, pdf in enumerate(pdfs)
]
pdfs = Lambda(lambda nns: op.stack(nns, axis=1), name=f"stack_replicas")(pdfs)
model = MetaModel({'NN_input': x_input}, pdfs, name=f"NNs")

return models
return model
2 changes: 1 addition & 1 deletion n3fit/src/n3fit/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def _model_generation(self, xinput, pdf_model, partition, partition_idx):
training.summary()
pdf_model = training.get_layer("PDFs")
pdf_model.summary()
nn_model = pdf_model.get_layer("NN_0")
nn_model = pdf_model.get_layer("NNs")
nn_model.summary()
# We may have fits without sumrules imposed
try:
Expand Down

0 comments on commit c650cf3

Please sign in to comment.