Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor xintegrator #1779

Merged
merged 4 commits into from
Jul 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions n3fit/src/n3fit/backends/keras_backend/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,18 +196,6 @@ def tensor_ones_like(*args, **kwargs):
return K.ones_like(*args, **kwargs)


@tf.function
def many_replication(grid, replications, axis=0, **kwargs):
"""
Generates a tensor with one extra dimension:
a repetition of "grid" n times along the given axis
from keras documentation:
If x has shape (s1, s2, s3) and axis is 1, the output will have shape (s1, s2 * rep, s3)
see full `docs <https://www.tensorflow.org/api_docs/python/tf/keras/backend/repeat_elements>`_
"""
return K.repeat_elements(grid, rep=replications, axis=axis, **kwargs)


# Property operations
# modify properties of the tensor like the shape or elements it has
@tf.function
Expand Down
21 changes: 10 additions & 11 deletions n3fit/src/n3fit/layers/x_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,23 @@ def get_config(self):

class xIntegrator(MetaLayer):
"""
This layer performs a sum of the input layer/tensor on the first axis
This layer performs a sum of the input layer/tensor on the axis corresponding to the x-grid
weighted by the weights of the grid.

Receives as input a rank-n (n > 1) tensor `x` (batch_dims ..., xpoints, flavours)
and returns a summation on the `xpoints` index (i.e., index -2)
weighted by the weights of the grid
The output shape is the input shape with the x-axis removed.

Parameters
----------
grid_weights: np.array
weights of the grid
x_axis: int (default=1)
axis of the input tensor that corresponds to the x-grid
"""

def __init__(self, grid_weights, output_dim=BASIS_SIZE, **kwargs):
grid_weights_tensor = op.numpy_to_tensor(grid_weights)
# Open up the grid weights
self.grid_weights = op.many_replication(grid_weights_tensor, output_dim, axis=1)
def __init__(self, grid_weights, x_axis=1, **kwargs):
self.x_axis = x_axis
self.grid_weights = op.flatten(op.numpy_to_tensor(grid_weights))
super().__init__(**kwargs)

def call(self, x):
xx = x * self.grid_weights
return op.sum(xx, axis=-2)
def call(self, pdf):
return op.tensor_product(pdf, self.grid_weights, axes=[self.x_axis, 0])
17 changes: 15 additions & 2 deletions n3fit/src/n3fit/tests/test_xops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"""
import numpy as np

from n3fit.layers import xDivide
from n3fit.backends import operations as op
from n3fit.layers import xDivide, xIntegrator


def test_xdivide_default():
Expand All @@ -21,7 +22,7 @@ def test_xdivide_default():


def test_xdivide_indices():
"""Check that the default xDivide works as expected"""
"""Check that xDivide with custom indices works as expected"""
custom_indices = [0, 1, 7]
x_div = xDivide(div_list=custom_indices)
test_input = np.array([1, 2, 3], dtype=np.float32).reshape((1, 3, 1))
Expand All @@ -32,3 +33,15 @@ def test_xdivide_indices():
expected_output[:, :, i] = 1 / test_input[:, :, 0]

np.testing.assert_allclose(test_output, expected_output, rtol=1e-05)


def test_xintegrator():
np.random.seed(42)
weights = np.random.rand(5, 1)
pdf = op.numpy_to_tensor(np.random.rand(1, 5, 8))
xint = xIntegrator(weights)
xint_out = xint(pdf)
xint_out_reference = np.array(
[[0.405455, 0.878931, 0.937715, 0.906214, 1.984154, 1.147975, 1.642387, 1.549858]]
)
np.testing.assert_allclose(xint_out.numpy(), xint_out_reference, rtol=1e-05)