Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scatter table refactoring #1319

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 73 additions & 35 deletions docs/tutorials/curve_analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,12 @@ different sets of experiment results. A single experiment can define sub-experim
consisting of multiple circuits which are tagged with common metadata,
and curve analysis sorts the experiment results based on the circuit metadata.

This is an example of showing the abstract data structure of a typical curve analysis experiment:
This is an example of showing the abstract data flow of a typical curve analysis experiment:
nkanazawa1989 marked this conversation as resolved.
Show resolved Hide resolved

.. jupyter-input::
:emphasize-lines: 1,10,19

"experiment"
- circuits[0] (x=x1_A, "series_A")
- circuits[1] (x=x1_B, "series_B")
- circuits[2] (x=x2_A, "series_A")
- circuits[3] (x=x2_B, "series_B")
- circuits[4] (x=x3_A, "series_A")
- circuits[5] (x=x3_B, "series_B")
- ...

"experiment data"
- data[0] (y1_A, "series_A")
- data[1] (y1_B, "series_B")
- data[2] (y2_A, "series_A")
- data[3] (y2_B, "series_B")
- data[4] (y3_A, "series_A")
- data[5] (y3_B, "series_B")
- ...

"analysis"
- "series_A": y_A = f_A(x_A; p0, p1, p2)
- "series_B": y_B = f_B(x_B; p0, p1, p2)
- fixed parameters {p1: v}
.. figure:: images/curve_analysis_structure.png
:width: 600
:align: center
:class: no-scaled-link

Here the experiment runs two subsets of experiments, namely, series A and series B.
The analysis defines corresponding fit models :math:`f_A(x_A)` and :math:`f_B(x_B)`.
Expand Down Expand Up @@ -289,21 +268,75 @@ A developer can override this method to perform initialization of analysis-speci

Curve analysis calls the :meth:`_run_data_processing` method, where
the data processor in the analysis option is internally called.
This consumes input experiment results and creates the :class:`.CurveData` dataclass.
Then the :meth:`_format_data` method is called with the processed dataset to format it.
This consumes input experiment results and creates the :class:`.ScatterTable` dataframe.
This table may look like:

.. code-block::

xval yval yerr name class_id category shots
0 0.1 0.153659 0.011258 A 0 raw 1024
1 0.1 0.590732 0.015351 B 1 raw 1024
2 0.1 0.315610 0.014510 A 0 raw 1024
3 0.1 0.376098 0.015123 B 1 raw 1024
4 0.2 0.937073 0.007581 A 0 raw 1024
5 0.2 0.323415 0.014604 B 1 raw 1024
6 0.2 0.538049 0.015565 A 0 raw 1024
7 0.2 0.530244 0.015581 B 1 raw 1024
8 0.3 0.143902 0.010958 A 0 raw 1024
9 0.3 0.261951 0.013727 B 1 raw 1024
10 0.3 0.830732 0.011707 A 0 raw 1024
11 0.3 0.874634 0.010338 B 1 raw 1024

where the experiment consists of two subset series A and B, and the experiment parameter (xval)
is scanned from 0.1 to 0.3 in each subset. For each condition, the experiment is run twice
for some reason. Each column represents following quantity.
nkanazawa1989 marked this conversation as resolved.
Show resolved Hide resolved

- ``xval``: Parameter scanned in the experiment. This value must be defined in the circuit metadata.
- ``yval``: Nominal part of the outcome. The outcome is something like expectation value, which is computed from the experiment result with the data processor.
- ``yerr``: Standard error of the outcome, which is mainly due to sampling error.
- ``name``: Unique identifier of the result class. This is defined by the ``data_subfit_map`` option.
- ``class_id``: Numerical index corresponding to the result class. This number is automatically assigned.
- ``category``: The attribute of data set. The "raw" category indicates an output from the data processing.
- ``shots``: Number of measurement shot used to acquire this result.
nkanazawa1989 marked this conversation as resolved.
Show resolved Hide resolved

3. Formatting
^^^^^^^^^^^^^

Next, the processed dataset is converted into another format suited for the fitting and
every valid result class is assigned to a fit model.
nkanazawa1989 marked this conversation as resolved.
Show resolved Hide resolved
By default, the formatter takes average of the outcomes in the processed dataset
over the same x values, followed by the sorting in the ascending order of x values.
This allows the analysis to easily estimate the slope of the curves to
create algorithmic initial guess of fit parameters.
A developer can inject extra data processing, for example, filtering, smoothing,
or elimination of outliers for better fitting.
The new class_id is given here so that its value corresponds to the fit model object index
in this analysis class. This index mapping is done based upon the correspondence of
the data name and the fit model name.

This is done by calling :meth:`_format_data` method.
This may return new scatter table object with addition of following rows like below.
nkanazawa1989 marked this conversation as resolved.
Show resolved Hide resolved

.. code-block::

12 0.1 0.234634 0.009183 A 0 formatted 2048
13 0.2 0.737561 0.008656 A 0 formatted 2048
14 0.3 0.487317 0.008018 A 0 formatted 2048
15 0.1 0.483415 0.010774 B 1 formatted 2048
16 0.2 0.426829 0.010678 B 1 formatted 2048
17 0.3 0.568293 0.008592 B 1 formatted 2048

The new data is added under the category "formatted". This category name must be also specified in
the analysis option ``fit_category``. The following fit routine filters the scatter table
nkanazawa1989 marked this conversation as resolved.
Show resolved Hide resolved
by the category name. The (x, y) value in each row is passed to the corresponding fit model object
to compute residual values for the least square optimization.

3. Fitting
^^^^^^^^^^

Curve analysis calls the :meth:`_run_curve_fit` method, which is the core functionality of the fitting.
Another method :meth:`_generate_fit_guesses` is internally called to
prepare the initial guess and parameter boundary with respect to the formatted data.
Curve analysis calls the :meth:`_run_curve_fit` method with the formatted subset of the scatter table.
This internally calls :meth:`_generate_fit_guesses` to prepare
the initial guess and parameter boundary with respect to the formatted dataset.
Developers usually override this method to provide better initial guesses
tailored to the defined fit model or type of the associated experiment.
See :ref:`curve_analysis_init_guess` for more details.
Expand All @@ -314,13 +347,18 @@ custom fitting algorithms. This method must return a :class:`.CurveFitResult` da
^^^^^^^^^^^^^^^^^^

Curve analysis runs several postprocessing against the fit outcome.
It calls :meth:`._create_analysis_results` to create the :class:`.AnalysisResultData` class
When the fit is successful, it calls :meth:`._create_analysis_results` to create the :class:`.AnalysisResultData` objects
for the fitting parameters of interest. A developer can inject custom code to
compute custom quantities based on the raw fit parameters.
See :ref:`curve_analysis_results` for details.
Afterwards, figure plotting is handed over to the :doc:`Visualization </tutorials/visualization>` module via
the :attr:`~.CurveAnalysis.plotter` attribute, and a list of created analysis results and the figure are returned.

Afterwards, fit curves are computed with the fit models and optimal parameters, and the scatter table is
updated with the computed (x, y) values. This dataset is stored under the "fitted" category.

Finally, :meth:`._create_figures` method is called with the entire scatter table data
nkanazawa1989 marked this conversation as resolved.
Show resolved Hide resolved
to initialize the curve plotter instance accessible via the :attr:`~.CurveAnalysis.plotter` attribute.
The visualization is handed over to the :doc:`Visualization </tutorials/visualization>` module,
which provides a standardized image format for curve fit results.
A developer can overwrite this method to draw custom images.

.. _curve_analysis_init_guess:

Expand Down
nkanazawa1989 marked this conversation as resolved.
Show resolved Hide resolved
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 8 additions & 2 deletions qiskit_experiments/curve_analysis/base_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def _default_options(cls) -> Options:
lmfit_options (Dict[str, Any]): Options that are passed to the
LMFIT minimizer. Acceptable options depend on fit_method.
x_key (str): Circuit metadata key representing a scanned value.
fit_category (str): Name of dataset in the scatter table to fit.
result_parameters (List[Union[str, ParameterRepr]): Parameters reported in the
database as a dedicated entry. This is a list of parameter representation
which is either string or ParameterRepr object. If you provide more
Expand Down Expand Up @@ -219,6 +220,7 @@ def _default_options(cls) -> Options:
options.normalization = False
options.average_method = "shots_weighted"
options.x_key = "xval"
options.fit_category = "formatted"
options.result_parameters = []
options.extra = {}
options.fit_method = "least_squares"
Expand Down Expand Up @@ -282,11 +284,13 @@ def set_options(self, **fields):
def _run_data_processing(
self,
raw_data: List[Dict],
category: str = "raw",
) -> ScatterTable:
"""Perform data processing from the experiment result payload.

Args:
raw_data: Payload in the experiment data.
category: Category string of the output dataset.

Returns:
Processed data that will be sent to the formatter method.
Expand All @@ -296,14 +300,16 @@ def _run_data_processing(
def _format_data(
self,
curve_data: ScatterTable,
category: str = "formatted",
) -> ScatterTable:
"""Postprocessing for the processed dataset.
"""Postprocessing for preparing the fitting data.

Args:
curve_data: Processed dataset created from experiment results.
category: Category string of the output dataset.

Returns:
Formatted data.
New scatter table instance including fit data.
"""

@abstractmethod
Expand Down
60 changes: 27 additions & 33 deletions qiskit_experiments/curve_analysis/composite_curve_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,32 +230,32 @@ def _create_figures(
A list of figures.
"""
for analysis in self.analyses():
sub_data = curve_data[curve_data.model_name.str.endswith(f"_{analysis.name}")]
for model_id, data in list(sub_data.groupby("model_id")):
model_name = analysis._models[model_id]._name
sub_data = curve_data[curve_data.group == analysis.name]
for name, data in list(sub_data.groupby("name")):
full_name = f"{name}_{analysis.name}"
# Plot raw data scatters
if analysis.options.plot_raw_data:
raw_data = data.filter(like="processed", axis="index")
raw_data = data[data.category == "raw"]
self.plotter.set_series_data(
series_name=model_name,
series_name=full_name,
x=raw_data.xval.to_numpy(),
y=raw_data.yval.to_numpy(),
)
# Plot formatted data scatters
formatted_data = data.filter(like="formatted", axis="index")
formatted_data = data[data.category == analysis.options.fit_category]
self.plotter.set_series_data(
series_name=model_name,
series_name=full_name,
x_formatted=formatted_data.xval.to_numpy(),
y_formatted=formatted_data.yval.to_numpy(),
y_formatted_err=formatted_data.yerr.to_numpy(),
)
# Plot fit lines
line_data = data.filter(like="fitted", axis="index")
line_data = data[data.category == "fitted"]
if len(line_data) == 0:
continue
fit_stdev = line_data.yerr.to_numpy()
self.plotter.set_series_data(
series_name=model_name,
series_name=full_name,
x_interp=line_data.xval.to_numpy(),
y_interp=line_data.yval.to_numpy(),
y_interp_err=fit_stdev if np.isfinite(fit_stdev).all() else None,
Expand Down Expand Up @@ -353,21 +353,16 @@ def _run_analysis(
metadata = analysis.options.extra.copy()
metadata["group"] = analysis.name

curve_data = analysis._format_data(
analysis._run_data_processing(experiment_data.data())
)
fit_data = analysis._run_curve_fit(curve_data.filter(like="formatted", axis="index"))
table = analysis._format_data(analysis._run_data_processing(experiment_data.data()))
formatted_subset = table[table.category == analysis.options.fit_category]
fit_data = analysis._run_curve_fit(formatted_subset)
fit_dataset[analysis.name] = fit_data

if fit_data.success:
quality = analysis._evaluate_quality(fit_data)
else:
quality = "bad"

# After the quality is determined, plot can become a boolean flag for whether
# to generate the figure
plot_bool = plot == "always" or (plot == "selective" and quality == "bad")

if self.options.return_fit_parameters:
# Store fit status overview entry regardless of success.
# This is sometime useful when debugging the fitting code.
Expand All @@ -382,10 +377,9 @@ def _run_analysis(
if fit_data.success:
# Add fit data to curve data table
fit_curves = []
formatted = curve_data.filter(like="formatted", axis="index")
columns = list(curve_data.columns)
for i, sub_data in list(formatted.groupby("model_id")):
name = analysis._models[i]._name
columns = list(table.columns)
model_names = analysis.model_names()
for i, sub_data in list(formatted_subset.groupby("class_id")):
xval = sub_data.xval.to_numpy()
if len(xval) == 0:
# If data is empty, skip drawing this model.
Expand All @@ -404,12 +398,10 @@ def _run_analysis(
model_fit[:, columns.index("yval")] = unp.nominal_values(yval_fit)
if fit_data.covar is not None:
model_fit[:, columns.index("yerr")] = unp.std_devs(yval_fit)
model_fit[:, columns.index("model_name")] = name
model_fit[:, columns.index("model_id")] = i
curve_data = curve_data.append_list_values(
other=np.vstack(fit_curves),
prefix="fitted",
)
model_fit[:, columns.index("name")] = model_names[i]
model_fit[:, columns.index("class_id")] = i
model_fit[:, columns.index("category")] = "fitted"
table = table.append_list_values(other=np.vstack(fit_curves))
analysis_results.extend(
analysis._create_analysis_results(
fit_data=fit_data,
Expand All @@ -421,18 +413,20 @@ def _run_analysis(
if self.options.return_data_points:
# Add raw data points
analysis_results.extend(
analysis._create_curve_data(
curve_data=curve_data.filter(like="formatted", axis="index"),
**metadata,
)
analysis._create_curve_data(curve_data=formatted_subset, **metadata)
)

curve_data.model_name += f"_{analysis.name}"
curve_data_set.append(curve_data)
# Add extra column to identify the fit model
table["group"] = analysis.name
curve_data_set.append(table)

combined_curve_data = pd.concat(curve_data_set)
total_quality = self._evaluate_quality(fit_dataset)

# After the quality is determined, plot can become a boolean flag for whether
# to generate the figure
plot_bool = plot == "always" or (plot == "selective" and total_quality == "bad")

# Create analysis results by combining all fit data
if all(fit_data.success for fit_data in fit_dataset.values()):
composite_results = self._create_analysis_results(
Expand Down
Loading