Merge prefactors into single layer #1881

Merged · 1 commit merged on Jan 26, 2024
Binary file modified n3fit/runcards/examples/developing_weights.h5
Binary file not shown.
25 changes: 14 additions & 11 deletions n3fit/src/n3fit/backends/__init__.py
@@ -1,20 +1,23 @@
-from n3fit.backends.keras_backend.internal_state import (
-    set_initial_state,
-    clear_backend_state,
-    set_eager
-)
+from n3fit.backends.keras_backend import callbacks, constraints, operations
 from n3fit.backends.keras_backend.MetaLayer import MetaLayer
-from n3fit.backends.keras_backend.MetaModel import MetaModel
+from n3fit.backends.keras_backend.MetaModel import (
+    NN_LAYER_ALL_REPLICAS,
+    NN_PREFIX,
+    PREPROCESSING_LAYER_ALL_REPLICAS,
+    MetaModel,
+)
 from n3fit.backends.keras_backend.base_layers import (
+    Concatenate,
     Input,
-    concatenate,
     Lambda,
     base_layer_selector,
+    concatenate,
     regularizer_selector,
-    Concatenate,
 )
-from n3fit.backends.keras_backend import operations
-from n3fit.backends.keras_backend import constraints
-from n3fit.backends.keras_backend import callbacks
+from n3fit.backends.keras_backend.internal_state import (
+    clear_backend_state,
+    set_eager,
+    set_initial_state,
+)
 
 print("Using Keras backend")
140 changes: 85 additions & 55 deletions n3fit/src/n3fit/backends/keras_backend/MetaModel.py
@@ -46,7 +46,8 @@
 }
 
 NN_PREFIX = "NN"
-PREPROCESSING_PREFIX = "preprocessing_factor"
+NN_LAYER_ALL_REPLICAS = "all_NNs"
+PREPROCESSING_LAYER_ALL_REPLICAS = "preprocessing_factor"
 
 # Some keys need to work for everyone
 for k, v in optimizers.items():
@@ -156,7 +157,7 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs):
         of the model (the loss functions) to the partial losses.
 
         If the model was compiled with input and output data, they will not be passed through.
-        In this case by default the number of `epochs` will be set to 1
+        In this case by default the number of ``epochs`` will be set to 1
 
         ex:
             {'loss': [100], 'dataset_a_loss1' : [67], 'dataset_2_loss': [33]}
@@ -228,7 +229,7 @@ def compile(
     ):
         """
         Compile the model given an optimizer and a list of loss functions.
-        The optimizer must be one of those implemented in the `optimizer` attribute of this class.
+        The optimizer must be one of those implemented in the ``optimizer`` attribute of this class.
 
         Options:
             - A learning rate and a list of target output can be defined.
@@ -353,14 +354,10 @@ def get_replica_weights(self, i_replica):
             dict
                 dictionary with the weights of the replica
         """
-        NN_weights = [
-            tf.Variable(w, name=w.name) for w in self.get_layer(f"{NN_PREFIX}_{i_replica}").weights
-        ]
-        prepro_weights = [
-            tf.Variable(w, name=w.name)
-            for w in self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").weights
-        ]
-        weights = {NN_PREFIX: NN_weights, PREPROCESSING_PREFIX: prepro_weights}
+        weights = {}
+        for layer_type in [NN_LAYER_ALL_REPLICAS, PREPROCESSING_LAYER_ALL_REPLICAS]:
+            layer = self.get_layer(layer_type)
+            weights[layer_type] = get_layer_replica_weights(layer, i_replica)
 
         return weights
 
@@ -378,10 +375,9 @@ def set_replica_weights(self, weights, i_replica=0):
         i_replica: int
             the replica number to set, defaulting to 0
         """
-        self.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights[NN_PREFIX])
-        self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").set_weights(
-            weights[PREPROCESSING_PREFIX]
-        )
+        for layer_type in [NN_LAYER_ALL_REPLICAS, PREPROCESSING_LAYER_ALL_REPLICAS]:
+            layer = self.get_layer(layer_type)
+            set_layer_replica_weights(layer=layer, weights=weights[layer_type], i_replica=i_replica)
 
     def split_replicas(self):
         """
@@ -411,51 +407,85 @@ def load_identical_replicas(self, model_file):
         """
         From a single replica model, load the same weights into all replicas.
         """
-        weights = self._format_weights_from_file(model_file)
+        single_replica = self.single_replica_generator()
+        single_replica.load_weights(model_file)
+        weights = single_replica.get_replica_weights(0)
 
         for i_replica in range(self.num_replicas):
             self.set_replica_weights(weights, i_replica)
 
-    def _format_weights_from_file(self, model_file):
-        """Read weights from a .h5 file and format into a dictionary of tf.Variables"""
-        weights = {}
-
-        with h5py.File(model_file, 'r') as f:
-            # look at layers of the form NN_i and take the lowest i
-            i_replica = 0
-            while f"{NN_PREFIX}_{i_replica}" not in f:
-                i_replica += 1
-
-            weights[NN_PREFIX] = self._extract_weights(
-                f[f"{NN_PREFIX}_{i_replica}"], NN_PREFIX, i_replica
-            )
-            weights[PREPROCESSING_PREFIX] = self._extract_weights(
-                f[f"{PREPROCESSING_PREFIX}_{i_replica}"], PREPROCESSING_PREFIX, i_replica
-            )
-
-        return weights
-
-    def _extract_weights(self, h5_group, weights_key, i_replica):
-        """Extract weights from a h5py group, turning them into Tensorflow variables"""
-        weights = []
-
-        def append_weights(name, node):
-            if isinstance(node, h5py.Dataset):
-                weight_name = node.name.split("/", 2)[-1]
-                weight_name = weight_name.replace(f"{NN_PREFIX}_{i_replica}", f"{NN_PREFIX}_0")
-                weight_name = weight_name.replace(
-                    f"{PREPROCESSING_PREFIX}_{i_replica}", f"{PREPROCESSING_PREFIX}_0"
-                )
-                weights.append(tf.Variable(node[()], name=weight_name))
-
-        h5_group.visititems(append_weights)
-
-        # have to put them in the same order
-        weights_ordered = []
-        weights_model_order = [w.name for w in self.get_replica_weights(0)[weights_key]]
-        for w in weights_model_order:
-            for w_h5 in weights:
-                if w_h5.name == w:
-                    weights_ordered.append(w_h5)
-
-        return weights_ordered
+
+def is_stacked_single_replicas(layer):
+    """
+    Check if the layer consists of stacked single replicas (only happens for NN layers),
+    to determine how to extract single replica weights.
+
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to check
+
+    Returns
+    -------
+        bool
+            True if the layer consists of stacked single replicas
+    """
+    if not isinstance(layer, MetaModel):
+        return False
+    return f"{NN_PREFIX}_0" in [sublayer.name for sublayer in layer.layers]
+
+
+def get_layer_replica_weights(layer, i_replica: int):
+    """
+    Get the weights for the given single replica ``i_replica``,
+    from a ``layer`` that contains the weights of all the replicas.
+
+    Note that the layer could be a complete NN with many separated sub_layers,
+    each of which contains weights for all replicas together.
+    This function separates the per-replica weights and returns the list of weights as if the
+    input ``layer`` were made of _only_ replica ``i_replica``.
+
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to get the weights from
+        i_replica: int
+            the replica number
+
+    Returns
+    -------
+        weights: list
+            list of weights for the replica
+    """
+    if is_stacked_single_replicas(layer):
+        weights = layer.get_layer(f"{NN_PREFIX}_{i_replica}").weights
+    else:
+        weights = [tf.Variable(w[i_replica : i_replica + 1], name=w.name) for w in layer.weights]
+
+    return weights
+
+
+def set_layer_replica_weights(layer, weights, i_replica: int):
+    """
+    Set the weights for the given single replica ``i_replica``.
+    When the input ``layer`` contains weights for many replicas, ensures that
+    only those corresponding to replica ``i_replica`` are updated.
+
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to set the weights for
+        weights: list
+            list of weights for the replica
+        i_replica: int
+            the replica number
+    """
+    if is_stacked_single_replicas(layer):
+        layer.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights)
+        return
+
+    full_weights = [w.numpy() for w in layer.weights]
+    for w_old, w_new in zip(full_weights, weights):
+        w_old[i_replica : i_replica + 1] = w_new
+
+    layer.set_weights(full_weights)
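The helpers above handle two weight layouts: either the layer is itself a MetaModel stacking one sub-model per replica (NN_0, NN_1, ...), or a single layer holds all replicas stacked along the first axis of each weight tensor. A minimal TensorFlow sketch of the second case, with hypothetical shapes and variable names, may help:

import numpy as np
import tensorflow as tf

# Hypothetical stacked kernel: 3 replicas, each with a (4, 2) kernel.
stacked = tf.Variable(np.arange(24, dtype=np.float32).reshape(3, 4, 2))

# Read, as in get_layer_replica_weights: replica 1 is the slice [1:2],
# which keeps the replica axis so shapes stay consistent.
replica_weights = stacked[1:2]  # shape (1, 4, 2)

# Write, as in set_layer_replica_weights: copy to NumPy, overwrite only
# replica 2's slice, then assign the whole tensor back.
full = stacked.numpy()
full[2:3] = replica_weights.numpy()
stacked.assign(full)

The NumPy round-trip mirrors the diff's full_weights copy: variables do not support in-place slice assignment, so the full tensor is rebuilt and assigned back, leaving the other replicas untouched.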
5 changes: 4 additions & 1 deletion n3fit/src/n3fit/layers/msr_normalization.py
@@ -40,6 +40,7 @@ def __init__(self, mode: str = "ALL", replicas: int = 1, **kwargs):
         else:
             raise ValueError(f"Mode {mode} not accepted for sum rules")
 
+        self.replicas = replicas
         indices = []
         self.divisor_indices = []
         if self._msr_enabled:
@@ -83,6 +84,7 @@ def call(self, pdf_integrated, photon_integral):
         reshape = lambda x: op.transpose(x[0])
         y = reshape(pdf_integrated)
         photon_integral = reshape(photon_integral)
+
         numerators = []
 
         if self._msr_enabled:
@@ -96,8 +98,9 @@
             divisors = op.gather(y, self.divisor_indices, axis=0)
 
         # Fill in the rest of the flavours with 1
+        num_flavours = y.shape[0]
         norm_constants = op.scatter_to_one(
-            numerators / divisors, indices=self.indices, output_shape=y.shape
+            numerators / divisors, indices=self.indices, output_shape=(num_flavours, self.replicas)
         )
 
         return op.batchit(op.transpose(norm_constants), batch_dimension=1)
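For intuition, here is a rough NumPy stand-in for how op.scatter_to_one appears to be used above (the helper below and its exact semantics are assumptions read off this diff, not the library's documented API): starting from ones of shape (num_flavours, replicas), the computed ratios are written into the rows selected by indices, so each replica now carries its own normalization constant per flavour.

import numpy as np

def scatter_to_one_sketch(values, indices, output_shape):
    # Fill with ones, then overwrite the selected flavour rows --
    # a NumPy stand-in for the op.scatter_to_one call in the diff.
    out = np.ones(output_shape, dtype=values.dtype)
    out[indices] = values
    return out

num_flavours, replicas = 5, 3  # hypothetical sizes
ratios = np.array([[1.1, 0.9, 1.0], [0.8, 1.2, 1.05]])  # one row per sum rule, one column per replica
norm_constants = scatter_to_one_sketch(ratios, indices=[1, 3], output_shape=(num_flavours, replicas))
# Rows 1 and 3 now hold the ratios; every other flavour keeps a factor of 1.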
16 changes: 12 additions & 4 deletions n3fit/src/n3fit/layers/preprocessing.py
@@ -33,13 +33,16 @@ class Preprocessing(MetaLayer):
         Whether large x preprocessing factor should be active
     seed: int
         seed for the initializer of the random alpha and beta values
+    num_replicas: int (default 1)
+        The number of replicas
     """
 
     def __init__(
         self,
         flav_info: Optional[list] = None,
         seed: int = 0,
         large_x: bool = True,
+        num_replicas: int = 1,
         **kwargs,
     ):
         if flav_info is None:
@@ -49,6 +52,8 @@ def __init__(
         self.flav_info = flav_info
         self.seed = seed
         self.large_x = large_x
+        self.num_replicas = num_replicas
+
         self.alphas = []
         self.betas = []
         super().__init__(**kwargs)
@@ -87,7 +92,7 @@ def generate_weight(self, name: str, kind: str, dictionary: dict, set_to_zero: b
         # Generate the new trainable (or not) parameter
         newpar = self.builder_helper(
             name=name,
-            kernel_shape=(1,),
+            kernel_shape=(self.num_replicas, 1),
             initializer=initializer,
             trainable=trainable,
             constraint=constraint,
@@ -117,9 +122,12 @@ def call(self, x):
 
         Returns
         -------
-            prefactor: tensor(shape=[1,N,F])
+            prefactor: tensor(shape=[1,R,N,F])
         """
-        alphas = op.stack(self.alphas, axis=1)
-        betas = op.stack(self.betas, axis=1)
+        # weight tensors of shape (R, 1, F)
+        alphas = op.stack(self.alphas, axis=-1)
+        betas = op.stack(self.betas, axis=-1)
 
+        x = op.batchit(x, batch_dimension=0)
+
         return x ** (1 - alphas) * (1 - x) ** betas
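The shape bookkeeping in call is the crux of the change: each of the F per-flavour kernels now has shape (num_replicas, 1), and stacking them on the last axis gives weights of shape (R, 1, F) that broadcast against the x grid. A NumPy sketch under those assumed shapes (op.stack replaced by its NumPy equivalent, and the leading batch dimension added by op.batchit omitted):

import numpy as np

R, N, F = 3, 7, 8  # replicas, x-grid points, flavours -- hypothetical sizes

# F kernels of shape (R, 1), stacked on the last axis -> (R, 1, F),
# mirroring op.stack(self.alphas, axis=-1).
alphas = np.stack([np.random.rand(R, 1) for _ in range(F)], axis=-1)
betas = np.stack([np.random.rand(R, 1) for _ in range(F)], axis=-1)

x = np.random.rand(1, N, 1)  # the x grid, with trailing axes for broadcasting

prefactor = x ** (1 - alphas) * (1 - x) ** betas
assert prefactor.shape == (R, N, F)  # one prefactor per replica, grid point, flavour

With the extra batch dimension restored by op.batchit, this matches the documented output shape [1, R, N, F].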