From 8810e1168d3797aa15cdad14972007b34c13c975 Mon Sep 17 00:00:00 2001
From: Aron
Date: Fri, 8 Dec 2023 12:36:04 +0100
Subject: [PATCH] Merge replicas in Prefactor layer

Simplify handling of dropout
Factor out layer_generator in generate_dense_network
Refactor dense_per_flavor_network
Move setting of last nodes to generate_nn
Add constant arguments
Add constant arguments
Move dropout to generate_nn
Move concatenation of per_flavor layers into generate_nn
Make the two layer generators almost equal
Remove separate dense and dense_per_flavor functions
Add documentation
Simplify per_flavor layer concatenation
Reverse order of loops over replicas and layers
Fixes for dropout
Fixes for per_flavour
Fix issue with copying over nodes for per_flavour layer
Fix seeds in per_flavour layer
Add error for combination of dropout with per_flavour layers
Add basis_size argument to per_flavour layer
Fix model_gen tests to use new generate_nn in favor of now removed
generate_dense and generate_dense_per_flavour
Allow for nodes to be a tuple
Move dropout, per_flavour check to checks
Clarify layer type check
Co-authored-by: Juan M. Cruz-Martinez
Clarify naming in nn_generator
Remove initializer_name argument
Clarify comment
Co-authored-by: Juan M. Cruz-Martinez
Add comment on shared layers
Rewrite comprehension over replica seeds
Add check on layer type
Merge prefactors into single layer
Add replica dimension to preprocessing factor in test
Update preprocessing layer in vpinterface
Remove assigning of weight slices
Simplify loading weights from file
Update regression data
Always return a single NNs model for all replicas, adjust weight
getting and setting accordingly
Revert "Update regression data"
This reverts commit 6f793687f2608c5af041ba989202460dff615220.
Change structure of regression weights
Remove now unused postfix
Update regression weights
Give explicit shape to scatter_to_one
Update developing weights structure
Fix prefix typo
Add double ticks
Rename layer name constants
Use constants defined in metamodel.py for layer names
Explain need for is_stacked_single_replicas
Shorten line
Fix constant loading
Simplify get_replica_weights
NNs -> all_NNs
Clarify get_layer_replica_weights
Co-authored-by: Juan M. Cruz-Martinez
Clarify set_layer_replica_weights
Remove comment about python 3.11
Co-authored-by: Juan M. Cruz-Martinez
Fix typo in comment
Co-authored-by: Tanjona Rabemananjara
Fix formatting in docstring
Co-authored-by: Tanjona Rabemananjara
Rewording docstring
Co-authored-by: Tanjona Rabemananjara
---
 n3fit/runcards/examples/developing_weights.h5 | Bin 41420 -> 41348 bytes
 n3fit/src/n3fit/backends/__init__.py          |  25 ++--
 .../n3fit/backends/keras_backend/MetaModel.py | 140 +++++++++++-------
 n3fit/src/n3fit/layers/msr_normalization.py   |   5 +-
 n3fit/src/n3fit/layers/preprocessing.py       |  16 +-
 n3fit/src/n3fit/model_gen.py                  |  80 +++++-----
 n3fit/src/n3fit/model_trainer.py              |   3 +-
 .../src/n3fit/tests/regressions/weights_1.h5  | Bin 29064 -> 29232 bytes
 .../src/n3fit/tests/regressions/weights_2.h5  | Bin 29064 -> 29232 bytes
 n3fit/src/n3fit/tests/test_modelgen.py        |   5 +-
 n3fit/src/n3fit/tests/test_preprocessing.py   |  80 +++++-----
 n3fit/src/n3fit/vpinterface.py                |   9 +-
 12 files changed, 202 insertions(+), 161 deletions(-)

diff --git a/n3fit/runcards/examples/developing_weights.h5 b/n3fit/runcards/examples/developing_weights.h5
index 542fca06f960ec56e2db2a210f06c3c058c1bec0..385ea55a758dc81092dad551d29cdb1f94fab8d9 100644
GIT binary patch
delta 3754
[base85-encoded binary delta omitted]

delta 4346
[base85-encoded binary delta omitted]
diff --git a/n3fit/src/n3fit/backends/__init__.py b/n3fit/src/n3fit/backends/__init__.py
index f49bbd0f53..3c48317a86 100644
--- a/n3fit/src/n3fit/backends/__init__.py
+++ b/n3fit/src/n3fit/backends/__init__.py
@@ -1,20 +1,23 @@
-from n3fit.backends.keras_backend.internal_state import (
-    set_initial_state,
-    clear_backend_state,
-    set_eager
-)
+from n3fit.backends.keras_backend import callbacks, constraints, operations
 from n3fit.backends.keras_backend.MetaLayer import MetaLayer
-from n3fit.backends.keras_backend.MetaModel import MetaModel
+from n3fit.backends.keras_backend.MetaModel import (
+    NN_LAYER_ALL_REPLICAS,
+    NN_PREFIX,
+    PREPROCESSING_LAYER_ALL_REPLICAS,
+    MetaModel,
+)
 from n3fit.backends.keras_backend.base_layers import (
+    Concatenate,
     Input,
-    concatenate,
     Lambda,
     base_layer_selector,
+    concatenate,
     regularizer_selector,
-    Concatenate,
 )
-from n3fit.backends.keras_backend import operations
-from n3fit.backends.keras_backend import constraints
-from n3fit.backends.keras_backend import callbacks
+from n3fit.backends.keras_backend.internal_state import (
+    clear_backend_state,
+    set_eager,
+    set_initial_state,
+)
 
 print("Using Keras backend")
diff --git a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py
index 1b0990bb03..fde7c4f987 100644
--- a/n3fit/src/n3fit/backends/keras_backend/MetaModel.py
+++ b/n3fit/src/n3fit/backends/keras_backend/MetaModel.py
@@ -46,7 +46,8 @@
 }
 
 NN_PREFIX = "NN"
-PREPROCESSING_PREFIX = "preprocessing_factor"
+NN_LAYER_ALL_REPLICAS = "all_NNs"
+PREPROCESSING_LAYER_ALL_REPLICAS = "preprocessing_factor"
 
 # Some keys need to work for everyone
 for k, v in optimizers.items():
@@ -156,7 +157,7 @@ def perform_fit(self, x=None, y=None, epochs=1, **kwargs):
         of the model (the loss functions) to the partial losses.
 
         If the model was compiled with input and output data, they will not be passed through.
-        In this case by default the number of `epochs` will be set to 1
+        In this case by default the number of ``epochs`` will be set to 1
 
         ex:
             {'loss': [100], 'dataset_a_loss1' : [67], 'dataset_2_loss': [33]}
@@ -228,7 +229,7 @@ def compile(
     ):
         """
         Compile the model given an optimizer and a list of loss functions.
-        The optimizer must be one of those implemented in the `optimizer` attribute of this class.
+        The optimizer must be one of those implemented in the ``optimizer`` attribute of this class.
 
         Options:
             - A learning rate and a list of target output can be defined.
@@ -353,14 +354,10 @@ def get_replica_weights(self, i_replica):
         dict
             dictionary with the weights of the replica
         """
-        NN_weights = [
-            tf.Variable(w, name=w.name) for w in self.get_layer(f"{NN_PREFIX}_{i_replica}").weights
-        ]
-        prepro_weights = [
-            tf.Variable(w, name=w.name)
-            for w in self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").weights
-        ]
-        weights = {NN_PREFIX: NN_weights, PREPROCESSING_PREFIX: prepro_weights}
+        weights = {}
+        for layer_type in [NN_LAYER_ALL_REPLICAS, PREPROCESSING_LAYER_ALL_REPLICAS]:
+            layer = self.get_layer(layer_type)
+            weights[layer_type] = get_layer_replica_weights(layer, i_replica)
 
         return weights
 
@@ -378,10 +375,9 @@ def set_replica_weights(self, weights, i_replica=0):
         i_replica: int
             the replica number to set, defaulting to 0
         """
-        self.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights[NN_PREFIX])
-        self.get_layer(f"{PREPROCESSING_PREFIX}_{i_replica}").set_weights(
-            weights[PREPROCESSING_PREFIX]
-        )
+        for layer_type in [NN_LAYER_ALL_REPLICAS, PREPROCESSING_LAYER_ALL_REPLICAS]:
+            layer = self.get_layer(layer_type)
+            set_layer_replica_weights(layer=layer, weights=weights[layer_type], i_replica=i_replica)
 
     def split_replicas(self):
         """
@@ -411,51 +407,85 @@ def load_identical_replicas(self, model_file):
         """
         From a single replica model, load the same weights into all replicas.
         """
-        weights = self._format_weights_from_file(model_file)
+        single_replica = self.single_replica_generator()
+        single_replica.load_weights(model_file)
+        weights = single_replica.get_replica_weights(0)
+
         for i_replica in range(self.num_replicas):
             self.set_replica_weights(weights, i_replica)
 
-    def _format_weights_from_file(self, model_file):
-        """Read weights from a .h5 file and format into a dictionary of tf.Variables"""
-        weights = {}
-        with h5py.File(model_file, 'r') as f:
-            # look at layers of the form NN_i and take the lowest i
-            i_replica = 0
-            while f"{NN_PREFIX}_{i_replica}" not in f:
-                i_replica += 1
+def is_stacked_single_replicas(layer):
+    """
+    Check if the layer consists of stacked single replicas (only happens for NN layers),
+    to determine how to extract single replica weights.
 
-            weights[NN_PREFIX] = self._extract_weights(
-                f[f"{NN_PREFIX}_{i_replica}"], NN_PREFIX, i_replica
-            )
-            weights[PREPROCESSING_PREFIX] = self._extract_weights(
-                f[f"{PREPROCESSING_PREFIX}_{i_replica}"], PREPROCESSING_PREFIX, i_replica
-            )
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to check
 
-        return weights
+    Returns
+    -------
+        bool
+            True if the layer consists of stacked single replicas
+    """
+    if not isinstance(layer, MetaModel):
+        return False
+    return f"{NN_PREFIX}_0" in [sublayer.name for sublayer in layer.layers]
+
+
+def get_layer_replica_weights(layer, i_replica: int):
+    """
+    Get the weights for the given single replica ``i_replica``,
+    from a ``layer`` that contains the weights of all the replicas.
+
+    Note that the layer could be a complete NN with many separated sub_layers,
+    each of which contains weights for all replicas together.
+    This function separates the per-replica weights and returns the list of weights as if the
+    input ``layer`` were made of _only_ replica ``i_replica``.
+
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to get the weights from
+        i_replica: int
+            the replica number
+
+    Returns
+    -------
+        weights: list
+            list of weights for the replica
+    """
+    if is_stacked_single_replicas(layer):
+        weights = layer.get_layer(f"{NN_PREFIX}_{i_replica}").weights
+    else:
+        weights = [tf.Variable(w[i_replica : i_replica + 1], name=w.name) for w in layer.weights]
+
+    return weights
+
+
+def set_layer_replica_weights(layer, weights, i_replica: int):
+    """
+    Set the weights for the given single replica ``i_replica``.
+    When the input ``layer`` contains weights for many replicas, ensures that
+    only those corresponding to replica ``i_replica`` are updated.
+
+    Parameters
+    ----------
+        layer: MetaLayer
+            the layer to set the weights for
+        weights: list
+            list of weights for the replica
+        i_replica: int
+            the replica number
+    """
+    if is_stacked_single_replicas(layer):
+        layer.get_layer(f"{NN_PREFIX}_{i_replica}").set_weights(weights)
+        return
+
+    full_weights = [w.numpy() for w in layer.weights]
+    for w_old, w_new in zip(full_weights, weights):
+        w_old[i_replica : i_replica + 1] = w_new
 
-    def _extract_weights(self, h5_group, weights_key, i_replica):
-        """Extract weights from a h5py group, turning them into Tensorflow variables"""
-        weights = []
-
-        def append_weights(name, node):
-            if isinstance(node, h5py.Dataset):
-                weight_name = node.name.split("/", 2)[-1]
-                weight_name = weight_name.replace(f"{NN_PREFIX}_{i_replica}", f"{NN_PREFIX}_0")
-                weight_name = weight_name.replace(
-                    f"{PREPROCESSING_PREFIX}_{i_replica}", f"{PREPROCESSING_PREFIX}_0"
-                )
-                weights.append(tf.Variable(node[()], name=weight_name))
-
-        h5_group.visititems(append_weights)
-
-        # have to put them in the same order
-        weights_ordered = []
-        weights_model_order = [w.name for w in self.get_replica_weights(0)[weights_key]]
-        for w in weights_model_order:
-            for w_h5 in weights:
-                if w_h5.name == w:
-                    weights_ordered.append(w_h5)
-
-        return weights_ordered
+    layer.set_weights(full_weights)
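Note for readers: the slicing convention used by ``get_layer_replica_weights`` and ``set_layer_replica_weights`` is easy to check in isolation. Below is a minimal NumPy sketch (illustrative shapes and names, not n3fit code): stacked layers keep a leading replica axis, so replica ``i`` is the length-1 slice ``[i : i + 1]``, and assigning to that slice updates only that replica.

    # Illustrative sketch; the (replicas, in, out) shape is an assumption.
    import numpy as np

    num_replicas, nodes_in, nodes_out = 3, 2, 4
    stacked = np.random.rand(num_replicas, nodes_in, nodes_out)

    def get_replica(w, i):
        # keep the length-1 replica axis so the slice can be written back as-is
        return w[i : i + 1]

    def set_replica(w, w_new, i):
        # in-place slice assignment touches only replica i
        w[i : i + 1] = w_new

    set_replica(stacked, np.zeros_like(get_replica(stacked, 1)), 1)
    assert (stacked[1] == 0).all() and (stacked[0] != 0).any()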
diff --git a/n3fit/src/n3fit/layers/msr_normalization.py b/n3fit/src/n3fit/layers/msr_normalization.py
index 0755cad0b4..01a3648cb7 100644
--- a/n3fit/src/n3fit/layers/msr_normalization.py
+++ b/n3fit/src/n3fit/layers/msr_normalization.py
@@ -40,6 +40,7 @@ def __init__(self, mode: str = "ALL", replicas: int = 1, **kwargs):
         else:
             raise ValueError(f"Mode {mode} not accepted for sum rules")
 
+        self.replicas = replicas
         indices = []
         self.divisor_indices = []
         if self._msr_enabled:
@@ -83,6 +84,7 @@ def call(self, pdf_integrated, photon_integral):
         reshape = lambda x: op.transpose(x[0])
         y = reshape(pdf_integrated)
         photon_integral = reshape(photon_integral)
+
         numerators = []
 
         if self._msr_enabled:
@@ -96,8 +98,9 @@ def call(self, pdf_integrated, photon_integral):
         divisors = op.gather(y, self.divisor_indices, axis=0)
 
         # Fill in the rest of the flavours with 1
+        num_flavours = y.shape[0]
         norm_constants = op.scatter_to_one(
-            numerators / divisors, indices=self.indices, output_shape=y.shape
+            numerators / divisors, indices=self.indices, output_shape=(num_flavours, self.replicas)
         )
 
         return op.batchit(op.transpose(norm_constants), batch_dimension=1)
diff --git a/n3fit/src/n3fit/layers/preprocessing.py b/n3fit/src/n3fit/layers/preprocessing.py
index 77ea760607..f8ab1f8f55 100644
--- a/n3fit/src/n3fit/layers/preprocessing.py
+++ b/n3fit/src/n3fit/layers/preprocessing.py
@@ -33,6 +33,8 @@ class Preprocessing(MetaLayer):
             Whether large x preprocessing factor should be active
         seed: int
             seed for the initializer of the random alpha and beta values
+        num_replicas: int (default 1)
+            The number of replicas
     """
 
     def __init__(
@@ -40,6 +42,7 @@
         flav_info: Optional[list] = None,
         seed: int = 0,
         large_x: bool = True,
+        num_replicas: int = 1,
         **kwargs,
     ):
         if flav_info is None:
@@ -49,6 +52,8 @@ def __init__(
         self.flav_info = flav_info
         self.seed = seed
         self.large_x = large_x
+        self.num_replicas = num_replicas
+
         self.alphas = []
         self.betas = []
         super().__init__(**kwargs)
@@ -87,7 +92,7 @@ def generate_weight(self, name: str, kind: str, dictionary: dict, set_to_zero: b
         # Generate the new trainable (or not) parameter
         newpar = self.builder_helper(
             name=name,
-            kernel_shape=(1,),
+            kernel_shape=(self.num_replicas, 1),
             initializer=initializer,
             trainable=trainable,
             constraint=constraint,
@@ -117,9 +122,12 @@ def call(self, x):
 
         Returns
         -------
-            prefactor: tensor(shape=[1,N,F])
+            prefactor: tensor(shape=[1,R,N,F])
         """
-        alphas = op.stack(self.alphas, axis=1)
-        betas = op.stack(self.betas, axis=1)
+        # weight tensors of shape (R, 1, F)
+        alphas = op.stack(self.alphas, axis=-1)
+        betas = op.stack(self.betas, axis=-1)
+
+        x = op.batchit(x, batch_dimension=0)
 
         return x ** (1 - alphas) * (1 - x) ** betas
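The shape bookkeeping in ``Preprocessing.call`` can be verified with plain broadcasting. A hedged NumPy sketch follows (assumed shapes; ``op.stack``/``op.batchit`` replaced by NumPy analogues): weights stacked to shape (R, 1, F) broadcast against an x grid carrying a leading batch-like axis, giving one prefactor grid per replica.

    # Sketch with assumed shapes, not the n3fit operations module.
    import numpy as np

    R, N, F = 3, 5, 8                              # replicas, x points, flavours
    alphas = np.random.rand(R, 1, F)               # stacked weights, shape (R, 1, F)
    betas = np.random.rand(R, 1, F)
    x = np.linspace(0.1, 0.9, N).reshape(1, N, 1)  # leading axis as added by batchit

    prefactor = x ** (1 - alphas) * (1 - x) ** betas
    assert prefactor.shape == (R, N, F)            # one prefactor grid per replica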
diff --git a/n3fit/src/n3fit/model_gen.py b/n3fit/src/n3fit/model_gen.py
index 09f337dbca..fc180f392f 100644
--- a/n3fit/src/n3fit/model_gen.py
+++ b/n3fit/src/n3fit/model_gen.py
@@ -14,7 +14,16 @@
 
 import numpy as np
 
-from n3fit.backends import Input, Lambda, MetaLayer, MetaModel, base_layer_selector
+from n3fit.backends import (
+    NN_LAYER_ALL_REPLICAS,
+    NN_PREFIX,
+    PREPROCESSING_LAYER_ALL_REPLICAS,
+    Input,
+    Lambda,
+    MetaLayer,
+    MetaModel,
+    base_layer_selector,
+)
 from n3fit.backends import operations as op
 from n3fit.backends import regularizer_selector
 from n3fit.layers import (
@@ -572,18 +581,14 @@ def pdfNN_layer_generator(
     else:
         sumrule_layer = lambda x: x
 
-    # Only these layers change from replica to replica:
-    preprocessing_factor_replicas = []
-    for i_replica, replica_seed in enumerate(seed):
-        preprocessing_factor_replicas.append(
-            Preprocessing(
-                flav_info=flav_info,
-                input_shape=(1,),
-                name=f"preprocessing_factor_{i_replica}",
-                seed=replica_seed + number_of_layers,
-                large_x=not subtract_one,
-            )
-        )
+    compute_preprocessing_factor = Preprocessing(
+        flav_info=flav_info,
+        input_shape=(1,),
+        name=PREPROCESSING_LAYER_ALL_REPLICAS,
+        seed=seed[0] + number_of_layers,
+        large_x=not subtract_one,
+        num_replicas=num_replicas,
+    )
 
     nn_replicas = generate_nn(
         layer_type=layer_type,
@@ -598,38 +603,28 @@ def pdfNN_layer_generator(
         last_layer_nodes=last_layer_nodes,
     )
 
-    # Apply NN layers for all replicas to a given input grid
-    def neural_network_replicas(x, postfix=""):
-        NNs_x = Lambda(lambda nns: op.stack(nns, axis=1), name=f"NNs{postfix}")(
-            [nn(x) for nn in nn_replicas]
-        )
+    # The NN subtracted by NN(1), if applicable
+    def nn_subtracted(x):
+        NNs_x = nn_replicas(x)
 
         if subtract_one:
             x_eq_1_processed = process_input(layer_x_eq_1)
-            NNs_x_1 = Lambda(lambda nns: op.stack(nns, axis=1), name=f"NNs{postfix}_x_1")(
-                [nn(x_eq_1_processed) for nn in nn_replicas]
-            )
+            NNs_x_1 = nn_replicas(x_eq_1_processed)
             NNs_x = subtract_one_layer([NNs_x, NNs_x_1])
 
         return NNs_x
 
-    # Apply preprocessing factors for all replicas to a given input grid
-    def preprocessing_replicas(x, postfix=""):
-        return Lambda(lambda pfs: op.stack(pfs, axis=1), name=f"prefactors{postfix}")(
-            [pf(x) for pf in preprocessing_factor_replicas]
-        )
-
-    def compute_unnormalized_pdf(x, postfix=""):
+    def compute_unnormalized_pdf(x):
         # Preprocess the input grid
         x_nn_input = extract_nn_input(x)
         x_processed = process_input(x_nn_input)
         x_original = extract_original(x)
 
         # Compute the neural network output
-        NNs_x = neural_network_replicas(x_processed, postfix=postfix)
+        NNs_x = nn_subtracted(x_processed)
 
         # Compute the preprocessing factor
-        preprocessing_factors_x = preprocessing_replicas(x_original, postfix=postfix)
+        preprocessing_factors_x = compute_preprocessing_factor(x_original)
 
         # Apply the preprocessing factor
         pref_NNs_x = apply_preprocessing_factor([preprocessing_factors_x, NNs_x])
@@ -646,7 +641,7 @@ def compute_unnormalized_pdf(x, postfix=""):
     PDFs_unnormalized = compute_unnormalized_pdf(pdf_input)
 
     if impose_sumrule:
-        PDFs_integration_grid = compute_unnormalized_pdf(integrator_input, postfix="_x_integ")
+        PDFs_integration_grid = compute_unnormalized_pdf(integrator_input)
 
     if photons:
         # add batch and flavor dimensions
@@ -670,11 +665,10 @@ def compute_unnormalized_pdf(x, postfix=""):
     if photons:
         PDFs = layer_photon(PDFs)
 
-    if replica_axis:
-        pdf_model = MetaModel(model_input, PDFs, name=f"PDFs", scaler=scaler)
-    else:
-        pdf_model = MetaModel(model_input, PDFs[:, 0], name=f"PDFs", scaler=scaler)
+    if not replica_axis:
+        PDFs = Lambda(lambda pdfs: pdfs[:, 0], name="remove_replica_axis")(PDFs)
 
+    pdf_model = MetaModel(model_input, PDFs, name=f"PDFs", scaler=scaler)
     return pdf_model
 
@@ -719,8 +713,8 @@ def generate_nn(
 
     Returns
     -------
-        nn_replicas: List[MetaModel]
-            List of MetaModel objects, one for each replica.
+        nn_replicas: MetaModel
+            Single model containing all replicas.
     """
     nodes_list = list(nodes)  # so we can modify it
     x_input = Input(shape=(None, nodes_in), batch_size=1, name='xgrids_processed')
@@ -744,7 +738,7 @@ def initializer_generator(seed, i_layer):
             ]
             return initializers
 
-    elif layer_type == "dense":
+    else:  # "dense"
         reg = regularizer_selector(regularizer, **regularizer_args)
         custom_args['regularizer'] = reg
 
@@ -782,6 +776,7 @@ def initializer_generator(seed, i_layer):
 
     # Apply all layers to the input to create the models
    pdfs = [layer(x_input) for layer in list_of_pdf_layers[0]]
+
     for layers in list_of_pdf_layers[1:]:
         # Since some layers (dropout) are shared, we have to treat them separately
         if type(layers) is list:
@@ -789,9 +784,12 @@ def initializer_generator(seed, i_layer):
         else:
             pdfs = [layers(x) for x in pdfs]
 
-    models = [
-        MetaModel({'NN_input': x_input}, pdf, name=f"NN_{i_replica}")
+    # Wrap the pdfs in a MetaModel to enable getting/setting of weights later
+    pdfs = [
+        MetaModel({'NN_input': x_input}, pdf, name=f"{NN_PREFIX}_{i_replica}")(x_input)
         for i_replica, pdf in enumerate(pdfs)
     ]
+    pdfs = Lambda(lambda nns: op.stack(nns, axis=1), name=f"stack_replicas")(pdfs)
+    model = MetaModel({'NN_input': x_input}, pdfs, name=NN_LAYER_ALL_REPLICAS)
 
-    return models
+    return model
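The ``generate_nn`` change wraps every replica network in a named sub-model and stacks the outputs along a new replica axis. A self-contained Keras sketch of that pattern follows (hypothetical layer sizes; only the ``NN_i``/``all_NNs``/``stack_replicas`` names come from the patch):

    # Hedged sketch of the stacking pattern; sizes are invented for illustration.
    import tensorflow as tf

    x_input = tf.keras.Input(shape=(None, 2), batch_size=1, name="NN_input")

    replica_outputs = []
    for i_replica in range(3):
        hidden = tf.keras.layers.Dense(8, activation="tanh")(x_input)
        output = tf.keras.layers.Dense(4)(hidden)
        # a named sub-model per replica keeps its weights addressable as NN_i
        sub_model = tf.keras.Model(x_input, output, name=f"NN_{i_replica}")
        replica_outputs.append(sub_model(x_input))

    # stack on a new replica axis: shape (1, replicas, None, 4)
    stacked = tf.keras.layers.Lambda(
        lambda nns: tf.stack(nns, axis=1), name="stack_replicas"
    )(replica_outputs)
    all_nns = tf.keras.Model(x_input, stacked, name="all_NNs")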
diff --git a/n3fit/src/n3fit/model_trainer.py b/n3fit/src/n3fit/model_trainer.py
index 99e9f013db..acbcc8b3cd 100644
--- a/n3fit/src/n3fit/model_trainer.py
+++ b/n3fit/src/n3fit/model_trainer.py
@@ -17,6 +17,7 @@
 from n3fit import model_gen
 from n3fit.backends import MetaModel, callbacks, clear_backend_state
 from n3fit.backends import operations as op
+from n3fit.backends import NN_LAYER_ALL_REPLICAS
 import n3fit.hyper_optimization.penalties
 import n3fit.hyper_optimization.rewards
 from n3fit.scaler import generate_scaler
@@ -454,7 +455,7 @@ def _model_generation(self, xinput, pdf_model, partition, partition_idx):
         training.summary()
         pdf_model = training.get_layer("PDFs")
         pdf_model.summary()
-        nn_model = pdf_model.get_layer("NN_0")
+        nn_model = pdf_model.get_layer(NN_LAYER_ALL_REPLICAS)
         nn_model.summary()
         # We may have fits without sumrules imposed
         try:
diff --git a/n3fit/src/n3fit/tests/regressions/weights_1.h5 b/n3fit/src/n3fit/tests/regressions/weights_1.h5
index 2c5b02e71eed2ab94b30aa51f0cfb9264fc4ba55..7f9f9301844822e6cb8ab0e0271f2bc8f538119f 100644
GIT binary patch
literal 29232
[base85-encoded binary data omitted]

delta 3321
[base85-encoded binary delta omitted]
diff --git a/n3fit/src/n3fit/tests/regressions/weights_2.h5 b/n3fit/src/n3fit/tests/regressions/weights_2.h5
index e5a8adea8692c6ef3fab1d8947de89455ce607ca..51061a63f24a1d41d1d392f048a5ac31d231212f 100644
GIT binary patch
literal 29232
[base85-encoded binary data omitted]

delta 3321
[base85-encoded binary delta omitted]
zzcsL`mi>S6bwc2h`UoA>asE;VbRM*T&G)kTMP=Soh;5>zkD-ke@5Gsz=6bIAG=QP5 zaD_y^hX+-thnX`^U5tT>WPv0j$~-Ht8A|QS!!f@XL*Y<8Ig5LGHgTmYd1o<7Ca&Zc zgpA}@m;8xx{N|D?{OO_{zwbo)o`HDZ-dF9mMIQKQ(|h?f7hmVMgmmQRtzU!IT~5L~ zBS(`BKdq@~Ey9dd_&m9?qz8XQ_$fH z-}ju$wZZ!OpOV7X`;AKm&mcYghv2-c5m?{UYBVhVxag&;E%0{-9_AxSJRUk>I;sD} zB9yha9uAs%4Hw0|U_ABlMC0K;Q_1Jg`I53z1x3q-e^2&A912gpI2X-WQi3A;1@Z0o zzJd%37UMtH%qP9)%pq$VTtsIMC*z!l-a@PUwZ=#NR-y7|o0&WpHRH`QPxpM~t!u`Y zBLYy@IccPHQYJZ`KFFRuZY`Sk($+>VhY>pUH1u*v!7&a~bOBThHh< z`DxVq?A>_FZ_lGCLG4UAUfaSqp6kRv+5IlOZ}QtDEUGns@{iv3gby~6!h*eI{g20q zU*9Y_KGqzDlm|5)$OzJ8Q(V{7v5Wf{XXA-?tX10 zS$^$`y<3OVw8OIqe_cKh z=bV_2GxB{*?^tbk<2T*yb4EAiXYCn>mKL|+Ble8J#}*GGd$$FnV?p`$H(Pv;r*CXh zG;{Q&!oPO@X!M@4%J|rxKK4nM9}VA^(2HDKyCZz)i#p?%qu(Rnv`obR8aWB^^Is(i z(?21{lA_Qq{ZV3W*#YmBT1(w#i(pe2yWJ>1f4C6 zBz|qP$;^|>N&Ka=#zPBB&|Z&l^y=o9P`4FJ_{~eNq6UWo`INVpkQ7@8%DMj`+$#X_ zp9X%4&iTJf+TQns{gAE%#YdeZ^=H^{qt%X{EC0;558K%YHJDdyeA(WLzx2RhJZicTeYomN z6n|+MD#;y)%mKmti`nHUc&FL;(vEq?1|K%U&zyXh931{2o|^X{E;HnkBVMPCX!}`W zp0Ez}nY?ru5ZC<288svzreJZ zbEQgh?0yl`mkQb7e1c!-#7`h}!k75fS+CD#BN#P^>cP0%2}x<0JmlL7d1<*QmzzxW*c33*_2D|^+FTacEsIA=J62TJvVg$#?|u9w4qJD^@KmGK ze-CsK-UErzUVdcBGc}&3cBL0~7lUIAOSq;KEsi_!PuF}_s)vLvl9jlo`A}Ru$~B*r z?t;V-3oCI=9#C3L`jspJQC2G%Q1y57WHkA|9o;`cJ z<@KvQJQ{0w2sT;~|9#yO5BA(y2L9AIPQ|gHI235v;L7FvfL&TeK-da$UyQC?Un#K9 z$5bwt1?&^UE0@a!cGQ!V%S8j*Z)oLmwdZGkZfRExE#($@vMu%`5!sf0QeNWIE&F8r zOmWM<3VubFqxGdN2W5O|OF=0Qd}$9xv4C!A=*U8Sx(B3#Uy)^eM+K>Vh9e7nx(B19 zenr+XPs!{^l(`v4?b%x^x3ss~!^5EAAsBSUeXhNRhZ7$D8Xit~bky*0!lR2@;-Rf; z5H6B-REYrFM_svpyK=vDrbyf`UcL{!+WzsvdtaCP&#S!;R@)z*jR)Ai?inemmPBi& zHPMokGFJGd`?94wp3@evhSoCG{I8ZTfSaFQZC>t3f^)-f_PiXOFYprnn+rXnF{SK0 z%5``#Ah*lISx&hwiTUYyg`9Go2J_D&uug!TSGitsE1WkE<|&z+UkYV}!aQO=$mPO& zVJ4S1o-&F7&jOIs!@4CV$0bsR0p1sbSe1mWLs6cmX4hwF-;Y&!q3%!twfJIE*g5lq nsGEcn^9!Y4X?}5&0M^v}Ed7fK6yBXU$?lRL1NpHNVy{C>Orj*t55)vp5<*aZ3USnyMoz$Xi1YvB5{i6+ zAfYO%dWd-vLZ}p1irNVE;MA9>>Q5~7kV>dBw}=x`IiRYCAXRNjd+3{;S?|VVrSZP^ zy>H%oGxPR&63@k@-^67%S^6kedy&Km-0wYN59SHlGr?Yhrpv8i&! zR>xT0`dI3QwjOU4L)cmRR9$Mq3pe}P@zf>by-~|5WLP&Co zd2mZ|x_G|qRFHS!LO@#@KQN&To=BBd6w^BRasR<3JH>s_A(x8<5R*&y zH8qY*h8Ef|Y$2?n-rP7HzB(74Y4a3m_U4SenA+=_GIo6%wKK+SM#FOp;S5_4_T&7L zk_?>}3ld#4&C)eU2~{GD^i@7 zJ;#VszIzTDYaBDZq!hrBYYigo=O-{O!?znUP{@yTsQKs(!H=Xqrssg z%2JbjEc%4`ImJ_DDk=RI?yv#qyMQ~Nr$ca-djxCvLEG9Kx zZsfsP*`T7857sVwH^Wce9d+4XxRo_5A@Kp5y{S79x?`#Hv7k*}tYMHAn#BFd?;6X9 z96iE`bT3Gt34ZK8qo{8&u-!}ZhWq(teYQ3Rd3nlwmaw14_0n#D&mi}w=M%li@Cm(@ zX!(3QbApJq&~|cMb~C$Ln|w6)H4*EzjWZesCh&>8%};JI+0l8Lkb54DWb_)16g|c~ zane=`?~y(R3!A*BWD?0lW&|(iS@}wIUiQl3oVUZ{={#f87(r`2T_|txCZ-dO8e&AA zhmg){x<>WTu1TxUXxG|`(|oVlip3!Ob36cPzZHUGwgSDc(M|}%ijAJ_7h{#jM#lo- z>+>@oUz?nt3yg#oCKt3UB8Z>sGl0q2$oxV$FuQP-n&GdpYH>A|dI*uDDv?fWwL=C^ z!K1p%uyM#b<>SSruQ9t^tZb=5u!jjcZ;9)xeGJ!|%~JnvUV`iCfmrHS@HDI7cKC^} z9d0{Th!0EL0g)Do?7Y2}DiV`HbxI2&=8VSXX^?C?bYn!0dc$aQV~r+D;NB0K+$|Ppc|)c z9=8{xZst?2*|_yy>J84&6ep02I^p zN)*jAX+KwS=yuxs%yxPqQgP^ZdLc5~>HnFEgWJytr6`*T^v_Ghp*tu?(QK!+qvFu* zcDU5rqDX@j5(KxO-Z*VL!%XdFFCj9{1c>>ZYw2G^*m&1QHYa1L%Xe&f@XmRq;@O9c zA@Ui9!?lzH%fqcg;zFMUUz+QRBgg%I+ z6PeZhXEGS^1lr#4p=&l!iwUJWA%WE|k!XajOK-Blptc&nB(_7$@625Fy))W+5g&ws zX{Qop0R&$d@jF8jS)V{b=4Xo6;V 1: - # We really don't want to fail at this point, but print a warning at least... 
- log.warning("More than one preprocessing layer found within the model!") - elif len(preprocessing_layers) < 1: - log.warning("No preprocessing layer found within the model!") - preprocessing_layer = preprocessing_layers[0] + preprocessing_layer = self._models[replica - 1].get_layer(PREPROCESSING_LAYER_ALL_REPLICAS) alphas_and_betas = None if self.fit_basis is not None:
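Taken together, the weight-handling changes make ``load_identical_replicas`` equivalent to broadcasting one replica's weights into every replica slot. A self-contained NumPy sketch (dictionary keys from this patch; the shapes are stand-ins for the real weight tensors):

    # Stand-in shapes; in n3fit the single-replica weights come from a .h5 file.
    import numpy as np

    num_replicas = 3
    stacked = {
        "all_NNs": np.random.rand(num_replicas, 8, 4),
        "preprocessing_factor": np.random.rand(num_replicas, 1, 14),
    }
    single = {key: w[0:1].copy() for key, w in stacked.items()}  # one replica's weights

    for i_replica in range(num_replicas):  # write them into every replica slot
        for key, w in stacked.items():
            w[i_replica : i_replica + 1] = single[key]

    assert all((w == w[0]).all() for w in stacked.values())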