Skip to content

Commit

Permalink
Clarify the roles of two other classes with the Tester suffix, which …
Browse files Browse the repository at this point in the history
…were really incomplete base classes.
  • Loading branch information
rileyjmurray committed Nov 13, 2023
1 parent 6aebf02 commit 75df4d3
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions test/unit/objects/test_objectivefns.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,15 @@ def test_hessian(self):
self.skipTest("Derivatives for RawTVDFunction aren't implemented yet.")


class TimeIndependentMDSObjectiveFunctionTester(ObjectiveFunctionData):
class TimeIndependentMDSObjectiveFunctionTesterBase(ObjectiveFunctionData):
"""
Tests for methods in the TimeIndependentMDSObjectiveFunction class.
"""

@staticmethod
def build_objfns(cls):
raise NotImplementedError()

@classmethod
def setUpClass(cls):
cls.penalty_dicts = [
Expand Down Expand Up @@ -332,7 +336,7 @@ def test_hessian(self):
self.assertArraysAlmostEqual(hessian / norm, fd_hessian / norm, places=3)


class Chi2FunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class Chi2FunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True
enable_hessian_tests = False

Expand All @@ -341,23 +345,23 @@ def build_objfns(self):
for penalties in self.penalty_dicts]


class ChiAlphaFunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class ChiAlphaFunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True
enable_hessian_tests = False

def build_objfns(self):
return [_objfns.ChiAlphaFunction.create_from(self.model, self.dataset, self.circuits, {'fmin': 1e-4}, None, method_names=('terms', 'dterms'))]


class FreqWeightedChi2FunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class FreqWeightedChi2FunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True
enable_hessian_tests = False

def build_objfns(self):
return [_objfns.FreqWeightedChi2Function.create_from(self.model, self.dataset, self.circuits, None, None, method_names=('terms', 'dterms'))]


class PoissonPicDeltaLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class PoissonPicDeltaLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True
enable_hessian_tests = True

Expand All @@ -367,23 +371,23 @@ def build_objfns(self):
for penalties in self.penalty_dicts]


class DeltaLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class DeltaLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = False
enable_hessian_tests = False

def build_objfns(self):
return [_objfns.DeltaLogLFunction.create_from(self.model, self.dataset, self.circuits, None, None, method_names=('terms', 'dterms'))]


class MaxLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class MaxLogLFunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = False
enable_hessian_tests = False

def build_objfns(self):
return [_objfns.MaxLogLFunction.create_from(self.model, self.dataset, self.circuits, None, None, method_names=('terms', 'dterms'))]


class TVDFunctionTester(TimeIndependentMDSObjectiveFunctionTester, BaseCase):
class TVDFunctionTester(TimeIndependentMDSObjectiveFunctionTesterBase, BaseCase):
computes_lsvec = True
enable_hessian_tests = False

Expand All @@ -394,11 +398,15 @@ def test_derivative(self):
self.skipTest("Derivatives for TVDFunction aren't implemented yet.")


class TimeDependentMDSObjectiveFunctionTester(ObjectiveFunctionData):
class TimeDependentMDSObjectiveFunctionTesterBase(ObjectiveFunctionData):
"""
Tests for methods in the TimeDependentMDSObjectiveFunction class.
"""

@staticmethod
def build_objfns(cls):
raise NotImplementedError()

def setUp(self):
super().setUp()
self.model.sim = pygsti.forwardsims.MapForwardSimulator(model=self.model, max_cache_size=0)
Expand All @@ -415,7 +423,7 @@ def test_dlsvec(self):
#TODO: add validation


class TimeDependentChi2FunctionTester(TimeDependentMDSObjectiveFunctionTester, BaseCase):
class TimeDependentChi2FunctionTester(TimeDependentMDSObjectiveFunctionTesterBase, BaseCase):
"""
Tests for methods in the TimeDependentChi2Function class.
"""
Expand All @@ -424,7 +432,7 @@ def build_objfns(self):
return [_objfns.TimeDependentChi2Function.create_from(self.model, self.dataset, self.circuits, method_names=('lsvec', 'dlsvec'))]


class TimeDependentPoissonPicLogLFunctionTester(TimeDependentMDSObjectiveFunctionTester, BaseCase):
class TimeDependentPoissonPicLogLFunctionTester(TimeDependentMDSObjectiveFunctionTesterBase, BaseCase):
"""
Tests for methods in the TimeDependentPoissonPicLogLFunction class.
"""
Expand Down

0 comments on commit 75df4d3

Please sign in to comment.