Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return all members when using MC PDFs #1522

Merged
merged 13 commits into from
Mar 9, 2022
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions n3fit/src/n3fit/io/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from reportengine.compat import yaml
import validphys
import n3fit
from n3fit import vpinterface


class WriterWrapper:
Expand Down Expand Up @@ -57,9 +58,9 @@ def write_data(self, replica_path_set, fitname, tr_chi2, vl_chi2, true_chi2):
chi2 of the replica to the central experimental data
"""
# Compute the arclengths
arc_lengths = self.pdf_object.compute_arclength()
arc_lengths = vpinterface.compute_arclength(self.pdf_object)
# Compute the integrability numbers
integrability_numbers = self.pdf_object.integrability_numbers()
integrability_numbers = vpinterface.integrability_numbers(self.pdf_object)
# Construct the chi2exp file
allchi2_lines = self.stopping_object.chi2exps_str()
# Construct the preproc file (the information is only in the json file)
Expand Down Expand Up @@ -140,8 +141,8 @@ def jsonfit(replica_status, pdf_object, tr_chi2, vl_chi2, true_chi2, stop_epoch,
all_info["erf_vl"] = vl_chi2
all_info["chi2"] = true_chi2
all_info["pos_state"] = replica_status.positivity_status
all_info["arc_lengths"] = pdf_object.compute_arclength().tolist()
all_info["integrability"] = pdf_object.integrability_numbers().tolist()
all_info["arc_lengths"] = vpinterface.compute_arclength(pdf_object).tolist()
all_info["integrability"] = vpinterface.integrability_numbers(pdf_object).tolist()
all_info["timing"] = timing
# Versioning info
all_info["version"] = version()
Expand Down
5 changes: 3 additions & 2 deletions n3fit/src/n3fit/performfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,14 +264,15 @@ def performfit(
replica_path_set = replica_path / f"replica_{replica_number}"

# Create a pdf instance
pdf_instance = N3PDF(pdf_model, fit_basis=basis)
q0 = theoryid.get_description().get("Q0")
pdf_instance = N3PDF(pdf_model, fit_basis=basis, Q=q0)

# Generate the writer wrapper
writer_wrapper = WriterWrapper(
replica_number,
pdf_instance,
stopping_object,
theoryid.get_description().get("Q0") ** 2,
q0**2,
final_time,
)

Expand Down
11 changes: 4 additions & 7 deletions n3fit/src/n3fit/tests/test_vpinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from hypothesis import given, settings, example
from hypothesis.strategies import integers
from validphys.pdfgrids import xplotting_grid, distance_grids
from validphys.core import MCStats
from n3fit.vpinterface import N3PDF
from n3fit.vpinterface import N3PDF, integrability_numbers, compute_arclength
from n3fit.model_gen import pdfNN_layer_generator


Expand Down Expand Up @@ -40,20 +39,18 @@ def test_N3PDF(members, layers):
xsize = np.random.randint(2, 20)
xx = np.random.rand(xsize)
n3pdf = generate_n3pdf(layers=layers, members=members)
assert len(n3pdf) == members + 1
assert n3pdf.stats_class == MCStats
assert n3pdf.load() is n3pdf
assert len(n3pdf) == members
w = n3pdf.get_nn_weights()
assert len(w) == members
assert len(w[0]) == 16 + (layers + 1) * 2 # 16=8*2 preprocessing
ret = n3pdf(xx)
assert ret.shape == (members, xsize, 14)
int_numbers = n3pdf.integrability_numbers()
int_numbers = integrability_numbers(n3pdf)
if members == 1:
assert int_numbers.shape == (5,)
else:
assert int_numbers.shape == (members, 5)
assert n3pdf.compute_arclength().shape == (5,)
assert compute_arclength(n3pdf).shape == (5,)
# Try to get a plotting grid
res = xplotting_grid(n3pdf, 1.6, xx)
assert res.grid_values.data.shape == (members, 8, xsize)
Expand Down
Loading