Skip to content

Commit

Permalink
Merge all alphas and all betas into two vector weights
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Jul 7, 2023
1 parent 43d909e commit a44a6b4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 25 deletions.
11 changes: 11 additions & 0 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,17 @@ def op_subtract(inputs, **kwargs):
"""
return keras_subtract(inputs, **kwargs)

def clamp(inputs, minvals, maxvals, **kwargs):
"""
Clip a tensor of inputs into a range defined by a tensor of minima and a tensor of maxima
see full `docs <https://www.tensorflow.org/api_docs/python/tf/clip_by_value>`_
"""
return tf.clip_by_value(
inputs,
clip_value_min=minvals,
clip_value_max=maxvals
)


@tf.function
def backend_function(fun_name, *args, **kwargs):
Expand Down
71 changes: 46 additions & 25 deletions n3fit/src/n3fit/layers/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,17 @@ def __init__(
raise ValueError(
"Trying to instantiate a preprocessing factor with no basis information"
)
self.flav_info = flav_info
self.flav_info = self._format_info(flav_info)
self.num_flavors = len(flav_info)
self.seed = seed
self.initializer = "random_uniform"
self.large_x = large_x

self.alphas = []
self.betas = []
super().__init__(**kwargs)

def generate_weight(self,
name: str,
kind: str,
dictionary: dict,
set_to_zero: bool = False
):
def generate_weight(self, name: str, kind: str, set_to_zero: bool = False):
"""
Generates weights according to the flavour dictionary
Expand All @@ -68,8 +65,6 @@ def generate_weight(self,
name to be given to the generated weight
kind: str
where to find the limits of the weight in the dictionary
dictionary: dict
dictionary defining the weight, usually one entry from `flav_info`
set_to_zero: bool
set the weight to constant 0
"""
Expand All @@ -78,35 +73,36 @@ def generate_weight(self,
initializer = MetaLayer.init_constant(0.0)
trainable = False
else:
minval, maxval = dictionary[kind]
trainable = dictionary.get("trainable", True)
minvals = self.flav_info[kind]['min']
maxvals = self.flav_info[kind]['max']
trainable = self.flav_info['trainable']
# Set the initializer and move the seed one up
initializer = MetaLayer.select_initializer(
self.initializer, minval=minval, maxval=maxval, seed=self.seed
self.initializer, minval=minvals, maxval=maxvals, seed=self.seed
)
self.seed += 1
# If we are training, constrain the weights to be within the limits
if trainable:
constraint = constraints.MinMaxWeight(minval, maxval)
constraint = lambda w: op.clamp(w, minvals, maxvals)

# Generate the new trainable (or not) parameter
kernel_shape=(self.num_flavors,)
newpar = self.builder_helper(
name=name,
kernel_shape=(1,),
kernel_shape=kernel_shape,
initializer=initializer,
trainable=trainable,
constraint=constraint,
)
return newpar

@staticmethod
def create_constraint(self, minvals, maxvals):
return lambda w: tf.reduce_min(tf.reduce_max(w, minvals), maxvals)

def build(self, input_shape):
# Run through the whole basis
for flav_dict in self.flav_info:
flav_name = flav_dict["fl"]
alpha_name = f"alpha_{flav_name}"
self.alphas.append(self.generate_weight(alpha_name, "smallx", flav_dict))
beta_name = f"beta_{flav_name}"
self.betas.append(self.generate_weight(beta_name, "largex", flav_dict, set_to_zero=not self.large_x))
self.alphas = self.generate_weight(name='alphas', kind='smallx')
self.betas = self.generate_weight(name='betas', kind='largex', set_to_zero=not self.large_x)

super(Preprocessing, self).build(input_shape)

Expand All @@ -121,8 +117,33 @@ def call(self, x):
Returns
-------
prefactor: tensor(shape=[1,N,F])
prefactors for the single batch dimension, all gridpoints, all flavors,
"""
alphas = op.stack(self.alphas, axis=1)
betas = op.stack(self.betas, axis=1)

return x ** (1 - alphas) * (1 - x) ** betas
return x ** (1 - self.alphas) * (1 - x) ** self.betas

@staticmethod
def _format_info(flav_info):
"Helper function that ideally becomes obsolete if flav_info format is changed"
smallx_min = []
smallx_max = []
largex_min = []
largex_max = []
# have to restrict to either all trainable or none trainable
trainable = True
for flav_dict in flav_info:
smallx_min.append(flav_dict["smallx"][0])
smallx_max.append(flav_dict["smallx"][1])
largex_min.append(flav_dict["largex"][0])
largex_max.append(flav_dict["largex"][1])
trainable = trainable and flav_dict.get("trainable", True)
return {
"smallx": {
"min": op.numpy_to_tensor(smallx_min),
"max": op.numpy_to_tensor(smallx_max),
},
"largex": {
"min": op.numpy_to_tensor(largex_min),
"max": op.numpy_to_tensor(largex_max),
},
"trainable": trainable,
}

0 comments on commit a44a6b4

Please sign in to comment.