Skip to content

Commit

Permalink
Give explicit shape to scatter_to_one
Browse files Browse the repository at this point in the history
  • Loading branch information
APJansen committed Jan 10, 2024
1 parent 997451b commit 5867fec
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion n3fit/src/n3fit/layers/msr_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __init__(self, mode: str = "ALL", replicas: int = 1, **kwargs):
else:
raise ValueError(f"Mode {mode} not accepted for sum rules")

self.replicas = replicas
indices = []
self.divisor_indices = []
if self._msr_enabled:
Expand Down Expand Up @@ -83,6 +84,7 @@ def call(self, pdf_integrated, photon_integral):
reshape = lambda x: op.transpose(x[0])
y = reshape(pdf_integrated)
photon_integral = reshape(photon_integral)

numerators = []

if self._msr_enabled:
Expand All @@ -96,8 +98,10 @@ def call(self, pdf_integrated, photon_integral):
divisors = op.gather(y, self.divisor_indices, axis=0)

# Fill in the rest of the flavours with 1
# (Note: using y.shape in the output_shape below gives an error in Python 3.11)
num_flavours = y.shape[0]
norm_constants = op.scatter_to_one(
numerators / divisors, indices=self.indices, output_shape=y.shape
numerators / divisors, indices=self.indices, output_shape=(num_flavours, self.replicas)
)

return op.batchit(op.transpose(norm_constants), batch_dimension=1)

0 comments on commit 5867fec

Please sign in to comment.