Skip to content

Commit

Permalink
Merge pull request #362 from pyGSTio/help-pytest-find-tests
Browse files Browse the repository at this point in the history
Help pytest find tests
  • Loading branch information
sserita authored Nov 17, 2023
2 parents 144b6d5 + 75df4d3 commit 5c0bc65
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 95 deletions.
4 changes: 3 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,6 @@ filterwarnings =
ignore:Would have scaled dProd:UserWarning
ignore:Scaled dProd small in order to keep prod managable:UserWarning
ignore:hProd is small:UserWarning
ignore:Scaled hProd small in order to keep prod managable.:UserWarning
ignore:Scaled hProd small in order to keep prod managable.:UserWarning

python_classes = *Tester
18 changes: 10 additions & 8 deletions test/unit/extras/interpygate/test_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pygsti
import pygsti.extras.interpygate as interp
from pygsti.extras.interpygate.core import use_csaps as USE_CSAPS
from pygsti.tools.basistools import change_basis
from pygsti.modelpacks import smq1Q_XY
from pathlib import Path
Expand Down Expand Up @@ -88,6 +89,7 @@ def create_aux_infos(self, v, grouped_v, comm=None):


class InterpygateConstructionTester(BaseCase):

@classmethod
def setUpClass(cls):
super(InterpygateConstructionTester, cls).setUpClass()
Expand All @@ -102,7 +104,6 @@ def setUpClass(cls):

cls.gate_process = SingleQubitGate(num_params = 3,num_params_evaluated_as_group = 1)


def test_target(self):
test = self.target_op.create_target_gate([0,np.pi/4])
self.assertArraysAlmostEqual(test, self.static_target)
Expand All @@ -120,13 +121,14 @@ def test_create_opfactory(self):
op.from_vector([1])
self.assertArraysAlmostEqual(op, self.static_target)

opfactory_spline = interp.InterpolatedOpFactory.create_by_interpolating_physical_process(
self.target_op, self.gate_process, argument_ranges=self.arg_ranges,
parameter_ranges=self.param_ranges, argument_indices=self.arg_indices,
interpolator_and_args='spline')
op = opfactory_spline.create_op([0,np.pi/4])
op.from_vector([1])
self.assertArraysAlmostEqual(op, self.static_target)
if USE_CSAPS:
opfactory_spline = interp.InterpolatedOpFactory.create_by_interpolating_physical_process(
self.target_op, self.gate_process, argument_ranges=self.arg_ranges,
parameter_ranges=self.param_ranges, argument_indices=self.arg_indices,
interpolator_and_args='spline')
op = opfactory_spline.create_op([0,np.pi/4])
op.from_vector([1])
self.assertArraysAlmostEqual(op, self.static_target)

interpolator_and_args = (_linND, {'rescale': True})
opfactory_custom = opfactory_spline = interp.InterpolatedOpFactory.create_by_interpolating_physical_process(
Expand Down
54 changes: 11 additions & 43 deletions test/unit/modelmembers/test_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,19 +95,26 @@ def test_hessian_wrt_params(self):
pass # ok if some classes don't implement this


class LinearOpTester(OpBase):
class LinearOpTester(BaseCase):
n_params = 0

@staticmethod
def build_gate():
dim = 4
evotype = Evotype.cast('default')
state_space = statespace.default_space_for_dim(dim)
rep = evotype.create_dense_superop_rep(np.identity(dim, 'd'), state_space)
# rep = evotype.create_dense_superop_rep(np.identity(dim, 'd'), state_space)
# ^ Original, failing line. My fix below.
rep = evotype.create_dense_superop_rep(None, np.identity(dim, 'd'), state_space)
return op.LinearOperator(rep, evotype)

def setUp(self):
ExplicitOpModel._strict = False
self.gate = self.build_gate()

def test_raise_on_invalid_method(self):
T = FullGaugeGroupElement(np.array([[0, 1], [1, 0]], 'd'))
mat = np.kron(np.array([[0, 1], [1, 0]], 'd'), np.eye(2))
T = FullGaugeGroupElement(mat)
with self.assertRaises(NotImplementedError):
self.gate.transform_inplace(T)
with self.assertRaises(NotImplementedError):
Expand Down Expand Up @@ -600,6 +607,7 @@ def test_include_off_diags_in_degen_blocks(self):
[(1j, (1, 0)), (-1j, (3, 2))]] # Im part of 1,0 and 3,2 els (lower triangle); (1,0) and (3,2) must be conjugates
)


class LindbladErrorgenTester(BaseCase):

def test_errgen_construction(self):
Expand Down Expand Up @@ -642,46 +650,6 @@ def test_errgen_construction_from_op(self):
errgen_copy.transform_inplace(T)
self.assertTrue(np.allclose(errgen_copy.to_dense(), eg.to_dense()))

#TODO - maybe update this to a test of ExpErrorgenOp, which can have dense/sparse versions?
#class LindbladOpBase(object):
# def test_has_nonzero_hessian(self):
# self.assertTrue(self.gate.has_nonzero_hessian())
#
#class LindbladErrorgenBase(LindbladOpBase, MutableDenseOpBase):
# def test_transform(self):
# gate_copy = self.gate.copy()
# T = UnitaryGaugeGroupElement(np.identity(4, 'd'))
# gate_copy.transform_inplace(T)
# self.assertArraysAlmostEqual(gate_copy, self.gate)
# # TODO test a non-trivial case
#
# def test_element_accessors(self):
# e1 = self.gate[1, 1]
# e2 = self.gate[1][1]
# self.assertAlmostEqual(e1, e2)
#
# s1 = self.gate[1, :]
# s2 = self.gate[1]
# s3 = self.gate[1][:]
# a1 = self.gate[:]
# self.assertArraysAlmostEqual(s1, s2)
# self.assertArraysAlmostEqual(s1, s3)
#
# s4 = self.gate[2:4, 1]
#
# result = len(self.gate)
# # TODO assert correctness
#
# def test_convert(self):
# g = op.convert(self.gate, "CPTP", Basis.cast("pp", 4))
# # TODO assert correctness
#
#
#class LindbladSparseOpBase(LindbladOpBase, OpBase):
# def assertArraysEqual(self, a, b):
# # Sparse LindbladOp does not support equality natively, so compare errorgen matrices
# self.assertEqual((a.errorgen.to_sparse() != b.errorgen.to_sparse()).nnz, 0)


class LindbladErrorgenBase(OpBase):
def test_has_nonzero_hessian(self):
Expand Down
75 changes: 32 additions & 43 deletions test/unit/objects/test_objectivefns.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pygsti.objectivefns.wildcardbudget import PrimitiveOpsWildcardBudget as _PrimitiveOpsWildcardBudget
from . import smqfixtures
from ..util import BaseCase
import unittest


class ObjectiveFunctionData(object):
Expand Down Expand Up @@ -82,35 +83,15 @@ def test_simple_builds(self):
self.assertTrue(isinstance(fn, builder.cls_to_build))


#BASE CLASS - no testing
#class ObjectiveFunctionTester(BaseCase):
# """
# Tests for methods in the ObjectiveFunction class.
# """
#
# @classmethod
# def setUpClass(cls):
# pass #TODO
#
# @classmethod
# def tearDownClass(cls):
# pass #TODO
#
# def setUp(self):
# pass #TODO
#
# def tearDown(self):
# pass #TODO
#
# def test_get_chi2k_distributed_qty(self):
# raise NotImplementedError() #TODO: test chi2k_distributed_qty


class RawObjectiveFunctionTester(object):
class RawObjectiveFunctionTesterBase(object):
"""
Tests for methods in the RawObjectiveFunction class.
"""

@staticmethod
def build_objfns(cls):
raise NotImplementedError()

@classmethod
def setUpClass(cls):
cls.objfns = cls.build_objfns(cls)
Expand Down Expand Up @@ -187,7 +168,7 @@ def test_hessian(self):
# h(terms) = 2 * (dsvec**2 + lsvec * hlsvec)


class RawChi2FunctionTester(RawObjectiveFunctionTester, BaseCase):
class RawChi2FunctionTester(RawObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True

@staticmethod
Expand All @@ -196,7 +177,7 @@ def build_objfns(cls):
return [_objfns.RawChi2Function({'min_prob_clip_for_weighting': 1e-6}, resource_alloc)]


class RawChiAlphaFunctionTester(RawObjectiveFunctionTester, BaseCase):
class RawChiAlphaFunctionTester(RawObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True

@staticmethod
Expand All @@ -211,7 +192,7 @@ def test_hessian(self):
self.skipTest("Hessian for RawChiAlphaFunction isn't implemented yet.")


class RawFreqWeightedChi2FunctionTester(RawObjectiveFunctionTester, BaseCase):
class RawFreqWeightedChi2FunctionTester(RawObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True

@staticmethod
Expand All @@ -220,7 +201,7 @@ def build_objfns(cls):
return [_objfns.RawFreqWeightedChi2Function({'min_freq_clip_for_weighting': 1e-4}, resource_alloc)]


class RawPoissonPicDeltaLogLFunctionTester(RawObjectiveFunctionTester, BaseCase):
class RawPoissonPicDeltaLogLFunctionTester(RawObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True

@staticmethod
Expand All @@ -231,7 +212,7 @@ def build_objfns(cls):
'pfratio_derivpt': 0.1, 'fmin': 1e-4}, resource_alloc)]


class RawDeltaLogLFunctionTester(RawObjectiveFunctionTester, BaseCase):
class RawDeltaLogLFunctionTester(RawObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = False

@staticmethod
Expand All @@ -242,7 +223,7 @@ def build_objfns(cls):
resource_alloc)]


class RawMaxLogLFunctionTester(RawObjectiveFunctionTester, BaseCase):
class RawMaxLogLFunctionTester(RawObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = False

@staticmethod
Expand All @@ -251,7 +232,7 @@ def build_objfns(cls):
return [_objfns.RawMaxLogLFunction({}, resource_alloc)]


class RawTVDFunctionTester(RawObjectiveFunctionTester, BaseCase):
class RawTVDFunctionTester(RawObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True

@staticmethod
Expand All @@ -266,11 +247,15 @@ def test_hessian(self):
self.skipTest("Derivatives for RawTVDFunction aren't implemented yet.")


class TimeIndependentMDSObjectiveFunctionTester(ObjectiveFunctionData):
class TimeIndependentMDSObjectiveFunctionTesterBase(ObjectiveFunctionData):
"""
Tests for methods in the TimeIndependentMDSObjectiveFunction class.
"""

@staticmethod
def build_objfns(cls):
raise NotImplementedError()

@classmethod
def setUpClass(cls):
cls.penalty_dicts = [
Expand Down Expand Up @@ -351,7 +336,7 @@ def test_hessian(self):
self.assertArraysAlmostEqual(hessian / norm, fd_hessian / norm, places=3)


class Chi2FunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class Chi2FunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True
enable_hessian_tests = False

Expand All @@ -360,23 +345,23 @@ def build_objfns(self):
for penalties in self.penalty_dicts]


class ChiAlphaFunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class ChiAlphaFunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True
enable_hessian_tests = False

def build_objfns(self):
return [_objfns.ChiAlphaFunction.create_from(self.model, self.dataset, self.circuits, {'fmin': 1e-4}, None, method_names=('terms', 'dterms'))]


class FreqWeightedChi2FunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class FreqWeightedChi2FunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True
enable_hessian_tests = False

def build_objfns(self):
return [_objfns.FreqWeightedChi2Function.create_from(self.model, self.dataset, self.circuits, None, None, method_names=('terms', 'dterms'))]


class PoissonPicDeltaLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class PoissonPicDeltaLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True
enable_hessian_tests = True

Expand All @@ -386,23 +371,23 @@ def build_objfns(self):
for penalties in self.penalty_dicts]


class DeltaLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class DeltaLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = False
enable_hessian_tests = False

def build_objfns(self):
return [_objfns.DeltaLogLFunction.create_from(self.model, self.dataset, self.circuits, None, None, method_names=('terms', 'dterms'))]


class MaxLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class MaxLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = False
enable_hessian_tests = False

def build_objfns(self):
return [_objfns.MaxLogLFunction.create_from(self.model, self.dataset, self.circuits, None, None, method_names=('terms', 'dterms'))]


class TVDFunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class TVDFunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True
enable_hessian_tests = False

Expand All @@ -413,11 +398,15 @@ def test_derivative(self):
self.skipTest("Derivatives for TVDFunction aren't implemented yet.")


class TimeDependentMDSObjectiveFunctionTester(ObjectiveFunctionData):
class TimeDependentMDSObjectiveFunctionTesterBase(ObjectiveFunctionData):
"""
Tests for methods in the TimeDependentMDSObjectiveFunction class.
"""

@staticmethod
def build_objfns(cls):
raise NotImplementedError()

def setUp(self):
super().setUp()
self.model.sim = pygsti.forwardsims.MapForwardSimulator(model=self.model, max_cache_size=0)
Expand All @@ -434,7 +423,7 @@ def test_dlsvec(self):
#TODO: add validation


class TimeDependentChi2FunctionTester(TimeDependentMDSObjectiveFunctionTester, BaseCase):
class TimeDependentChi2FunctionTester(TimeDependentMDSObjectiveFunctionTesterBase, BaseCase):
"""
Tests for methods in the TimeDependentChi2Function class.
"""
Expand All @@ -443,7 +432,7 @@ def build_objfns(self):
return [_objfns.TimeDependentChi2Function.create_from(self.model, self.dataset, self.circuits, method_names=('lsvec', 'dlsvec'))]


class TimeDependentPoissonPicLogLFunctionTester(TimeDependentMDSObjectiveFunctionTester, BaseCase):
class TimeDependentPoissonPicLogLFunctionTester(TimeDependentMDSObjectiveFunctionTesterBase, BaseCase):
"""
Tests for methods in the TimeDependentPoissonPicLogLFunction class.
"""
Expand Down

0 comments on commit 5c0bc65

Please sign in to comment.