Skip to content

Commit

Permalink
fix: args for remaining optimisers (#551) (#555)
Browse files Browse the repository at this point in the history
  • Loading branch information
BradyPlanden authored Nov 15, 2024
1 parent 3fe0c1d commit 93da35a
Showing 1 changed file with 162 additions and 18 deletions.
180 changes: 162 additions & 18 deletions pybop/optimisers/pints_optimisers.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,24 @@ class Adam(BasePintsOptimiser):
stacklevel=2,
)

def __init__(self, cost, **optimiser_kwargs):
super().__init__(cost, PintsAdam, **optimiser_kwargs)
def __init__(
self,
cost,
max_iterations: int = None,
min_iterations: int = 2,
max_unchanged_iterations: int = 15,
parallel: bool = False,
**optimiser_kwargs,
):
super().__init__(
cost,
PintsAdam,
max_iterations,
min_iterations,
max_unchanged_iterations,
parallel,
**optimiser_kwargs,
)


class AdamW(BasePintsOptimiser):
Expand Down Expand Up @@ -182,8 +198,24 @@ class AdamW(BasePintsOptimiser):
pybop.AdamWImpl : The PyBOP implementation this class is based on.
"""

def __init__(self, cost, **optimiser_kwargs):
super().__init__(cost, AdamWImpl, **optimiser_kwargs)
def __init__(
self,
cost,
max_iterations: int = None,
min_iterations: int = 2,
max_unchanged_iterations: int = 15,
parallel: bool = False,
**optimiser_kwargs,
):
super().__init__(
cost,
AdamWImpl,
max_iterations,
min_iterations,
max_unchanged_iterations,
parallel,
**optimiser_kwargs,
)


class IRPropMin(BasePintsOptimiser):
Expand Down Expand Up @@ -231,8 +263,24 @@ class IRPropMin(BasePintsOptimiser):
pints.IRPropMin : The PINTS implementation this class is based on.
"""

def __init__(self, cost, **optimiser_kwargs):
super().__init__(cost, PintsIRPropMin, **optimiser_kwargs)
def __init__(
self,
cost,
max_iterations: int = None,
min_iterations: int = 2,
max_unchanged_iterations: int = 15,
parallel: bool = False,
**optimiser_kwargs,
):
super().__init__(
cost,
PintsIRPropMin,
max_iterations,
min_iterations,
max_unchanged_iterations,
parallel,
**optimiser_kwargs,
)


class PSO(BasePintsOptimiser):
Expand Down Expand Up @@ -280,8 +328,24 @@ class PSO(BasePintsOptimiser):
pints.PSO : The PINTS implementation this class is based on.
"""

def __init__(self, cost, **optimiser_kwargs):
super().__init__(cost, PintsPSO, **optimiser_kwargs)
def __init__(
self,
cost,
max_iterations: int = None,
min_iterations: int = 2,
max_unchanged_iterations: int = 15,
parallel: bool = False,
**optimiser_kwargs,
):
super().__init__(
cost,
PintsPSO,
max_iterations,
min_iterations,
max_unchanged_iterations,
parallel,
**optimiser_kwargs,
)


class SNES(BasePintsOptimiser):
Expand Down Expand Up @@ -329,8 +393,24 @@ class SNES(BasePintsOptimiser):
pints.SNES : The PINTS implementation this class is based on.
"""

def __init__(self, cost, **optimiser_kwargs):
super().__init__(cost, PintsSNES, **optimiser_kwargs)
def __init__(
self,
cost,
max_iterations: int = None,
min_iterations: int = 2,
max_unchanged_iterations: int = 15,
parallel: bool = False,
**optimiser_kwargs,
):
super().__init__(
cost,
PintsSNES,
max_iterations,
min_iterations,
max_unchanged_iterations,
parallel,
**optimiser_kwargs,
)


class XNES(BasePintsOptimiser):
Expand Down Expand Up @@ -378,8 +458,24 @@ class XNES(BasePintsOptimiser):
pints.XNES : PINTS implementation of XNES algorithm.
"""

def __init__(self, cost, **optimiser_kwargs):
super().__init__(cost, PintsXNES, **optimiser_kwargs)
def __init__(
self,
cost,
max_iterations: int = None,
min_iterations: int = 2,
max_unchanged_iterations: int = 15,
parallel: bool = False,
**optimiser_kwargs,
):
super().__init__(
cost,
PintsXNES,
max_iterations,
min_iterations,
max_unchanged_iterations,
parallel,
**optimiser_kwargs,
)


class NelderMead(BasePintsOptimiser):
Expand Down Expand Up @@ -429,8 +525,24 @@ class NelderMead(BasePintsOptimiser):
pints.NelderMead : PINTS implementation of Nelder-Mead algorithm.
"""

def __init__(self, cost, **optimiser_kwargs):
super().__init__(cost, PintsNelderMead, **optimiser_kwargs)
def __init__(
self,
cost,
max_iterations: int = None,
min_iterations: int = 2,
max_unchanged_iterations: int = 15,
parallel: bool = False,
**optimiser_kwargs,
):
super().__init__(
cost,
PintsNelderMead,
max_iterations,
min_iterations,
max_unchanged_iterations,
parallel,
**optimiser_kwargs,
)


class CMAES(BasePintsOptimiser):
Expand Down Expand Up @@ -478,14 +590,30 @@ class CMAES(BasePintsOptimiser):
pints.CMAES : PINTS implementation of CMA-ES algorithm.
"""

def __init__(self, cost, **optimiser_kwargs):
def __init__(
self,
cost,
max_iterations: int = None,
min_iterations: int = 2,
max_unchanged_iterations: int = 15,
parallel: bool = False,
**optimiser_kwargs,
):
x0 = optimiser_kwargs.get("x0", cost.parameters.initial_value())
if len(x0) == 1 or len(cost.parameters) == 1:
raise ValueError(
"CMAES requires optimisation of >= 2 parameters at once. "
"Please choose another optimiser."
)
super().__init__(cost, PintsCMAES, **optimiser_kwargs)
super().__init__(
cost,
PintsCMAES,
max_iterations,
min_iterations,
max_unchanged_iterations,
parallel,
**optimiser_kwargs,
)


class CuckooSearch(BasePintsOptimiser):
Expand Down Expand Up @@ -532,5 +660,21 @@ class CuckooSearch(BasePintsOptimiser):
pybop.CuckooSearchImpl : PyBOP implementation of Cuckoo Search algorithm.
"""

def __init__(self, cost, **optimiser_kwargs):
super().__init__(cost, CuckooSearchImpl, **optimiser_kwargs)
def __init__(
self,
cost,
max_iterations: int = None,
min_iterations: int = 2,
max_unchanged_iterations: int = 15,
parallel: bool = False,
**optimiser_kwargs,
):
super().__init__(
cost,
CuckooSearchImpl,
max_iterations,
min_iterations,
max_unchanged_iterations,
parallel,
**optimiser_kwargs,
)

0 comments on commit 93da35a

Please sign in to comment.