-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from imaginationtech/splitfiles
Split constrainedrandom/__init__.py
- Loading branch information
Showing
8 changed files
with
762 additions
and
666 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
# SPDX-License-Identifier: MIT | ||
# Copyright (c) 2023 Imagination Technologies Ltd. All Rights Reserved | ||
|
||
import constraint | ||
from collections import defaultdict | ||
from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING, Union | ||
|
||
from constrainedrandom import utils | ||
|
||
if TYPE_CHECKING: | ||
from constrainedrandom.randobj import RandObj | ||
from constrainedrandom.internal.randvar import RandVar | ||
|
||
|
||
class MultiVarProblem: | ||
''' | ||
Multi-variable problem. Used internally by RandObj. | ||
Represents one problem concerning multiple random variables, | ||
where those variables all share dependencies on one another. | ||
:param parent: The :class:`RandObj` instance that owns this instance. | ||
:param vars: The dictionary of names and :class:`RandVar` instances this problem consists of. | ||
:param constraints: The list or tuple of constraints associated with | ||
the random variables. | ||
:param max_iterations: The maximum number of failed attempts to solve the randomization | ||
problem before giving up. | ||
:param max_domain_size: The maximum size of domain that a constraint satisfaction problem | ||
may take. This is used to avoid poor performance. When a problem exceeds this domain | ||
size, we don't use the ``constraint`` package, but just use ``random`` instead. | ||
For :class:`MultiVarProblem`, we also use this to determine the maximum size of a | ||
solution group. | ||
''' | ||
|
||
def __init__( | ||
self, | ||
parent: 'RandObj', | ||
vars: Dict[str, 'RandVar'], | ||
constraints: Iterable[utils.Constraint], | ||
max_iterations: int, | ||
max_domain_size: int, | ||
) -> None: | ||
self.parent = parent | ||
self.random = self.parent._random | ||
self.vars = vars | ||
self.constraints = constraints | ||
self.max_iterations = max_iterations | ||
self.max_domain_size = max_domain_size | ||
|
||
def determine_order(self) -> List[List['RandVar']]: | ||
''' | ||
Chooses an order in which to resolve the values of the variables. | ||
Used internally. | ||
:return: A list of lists denoting the order in which to solve the problem. | ||
Each inner list is a group of variables that can be solved at the same | ||
time. Each inner list will be considered separately. | ||
''' | ||
# Aim to build a list of lists, each inner list denoting a group of variables | ||
# to solve at the same time. | ||
# The best case is to simply solve them all at once, if possible, however it is | ||
# likely that the domain will be too large. | ||
|
||
# Use order hints first, remaining variables can be placed anywhere the domain | ||
# isn't too large. | ||
sorted_vars = sorted(self.vars.values(), key=lambda x: x.order) | ||
|
||
# Currently this is just a flat list. Group into as large groups as possible. | ||
result = [[sorted_vars[0]]] | ||
index = 0 | ||
domain_size = len(sorted_vars[0].domain) if sorted_vars[0].domain is not None else 1 | ||
for var in sorted_vars[1:]: | ||
if var.domain is not None: | ||
domain_size = domain_size * len(var.domain) | ||
if var.order == result[index][0].order and domain_size < self.max_domain_size: | ||
# Put it in the same group as the previous one, carry on | ||
result[index].append(var) | ||
else: | ||
# Make a new group | ||
index += 1 | ||
domain_size = len(var.domain) if var.domain is not None else 1 | ||
result.append([var]) | ||
|
||
return result | ||
|
||
def solve_groups( | ||
self, | ||
groups: List[List['RandVar']], | ||
max_iterations:int, | ||
solutions_per_group: Optional[int]=None, | ||
) -> Union[Dict[str, Any], None]: | ||
''' | ||
Constraint solving algorithm (internally used by :class:`MultiVarProblem`). | ||
:param groups: The list of lists denoting the order in which to resolve the random variables. | ||
See :func:`determine_order`. | ||
:param max_iterations: The maximum number of failed attempts to solve the randomization | ||
problem before giving up. | ||
:param solutions_per_group: If ``solutions_per_group`` is not ``None``, | ||
solve each constraint group problem 'sparsely', | ||
i.e. maintain only a subset of potential solutions between groups. | ||
Fast but prone to failure. | ||
``solutions_per_group = 1`` is effectively a depth-first search through the state space | ||
and comes with greater benefits of considering each multi-variable constraint at | ||
most once. | ||
If ``solutions_per_group`` is ``None``, Solve constraint problem 'thoroughly', | ||
i.e. keep all possible results between iterations. | ||
Slow, but will usually converge. | ||
:returns: A valid solution to the problem, in the form of a dictionary with the | ||
names of the random variables as keys and the valid solution as the values. | ||
Returns ``None`` if no solution is found within the allotted ``max_iterations``. | ||
''' | ||
constraints = self.constraints | ||
sparse_solver = solutions_per_group is not None | ||
|
||
if sparse_solver: | ||
solved_vars = defaultdict(set) | ||
else: | ||
solved_vars = [] | ||
problem = constraint.Problem() | ||
|
||
for idx, group in enumerate(groups): | ||
# Construct a constraint problem where possible. A variable must have a domain | ||
# in order to be part of the problem. If it doesn't have one, it must just be | ||
# randomized. | ||
if sparse_solver: | ||
# Construct one problem per iteration, add solved variables from previous groups | ||
problem = constraint.Problem() | ||
for name, values in solved_vars.items(): | ||
problem.addVariable(name, list(values)) | ||
group_vars = [] | ||
rand_vars = [] | ||
for var in group: | ||
group_vars.append(var.name) | ||
if var.domain is not None and not isinstance(var.domain, dict): | ||
problem.addVariable(var.name, var.domain) | ||
# If variable has its own constraints, these must be added to the problem, | ||
# regardless of whether var.check_constraints is true, as the var's value will | ||
# depend on the value of the other constrained variables in the problem. | ||
for con in var.constraints: | ||
problem.addConstraint(con, (var.name,)) | ||
else: | ||
rand_vars.append(var) | ||
# Add all pertinent constraints | ||
skipped_constraints = [] | ||
for (con, vars) in constraints: | ||
skip = False | ||
for var in vars: | ||
if var not in group_vars and var not in solved_vars: | ||
# Skip this constraint | ||
skip = True | ||
break | ||
if skip: | ||
skipped_constraints.append((con, vars)) | ||
continue | ||
problem.addConstraint(con, vars) | ||
# Problem is ready to solve, apart from any new random variables | ||
solutions = [] | ||
attempts = 0 | ||
while True: | ||
if attempts >= max_iterations: | ||
# We have failed, give up | ||
return None | ||
for var in rand_vars: | ||
# Add random variables in with a concrete value | ||
if solutions_per_group > 1: | ||
var_domain = set() | ||
for _ in range(solutions_per_group): | ||
var_domain.add(var.randomize()) | ||
problem.addVariable(var.name, list(var_domain)) | ||
else: | ||
problem.addVariable(var.name, (var.randomize(),)) | ||
solutions = problem.getSolutions() | ||
if len(solutions) > 0: | ||
break | ||
else: | ||
attempts += 1 | ||
for var in rand_vars: | ||
# Remove from problem, they will be re-added with different concrete values | ||
del problem._variables[var.name] | ||
# This group is solved, move on to the next group. | ||
if sparse_solver: | ||
if idx != len(groups) - 1: | ||
# Store a small number of concrete solutions to avoid bloating the state space. | ||
if solutions_per_group < len(solutions): | ||
solution_subset = self.random.choices(solutions, k=solutions_per_group) | ||
else: | ||
solution_subset = solutions | ||
solved_vars = defaultdict(set) | ||
for soln in solution_subset: | ||
for name, value in soln.items(): | ||
solved_vars[name].add(value) | ||
if solutions_per_group == 1: | ||
# This means we have exactly one solution for the variables considered so far, | ||
# meaning we don't need to re-apply solved constraints for future groups. | ||
constraints = skipped_constraints | ||
else: | ||
solved_vars += group_vars | ||
|
||
return self.random.choice(solutions) | ||
|
||
def solve(self) -> Union[Dict[str, Any], None]: | ||
''' | ||
Attempt to solve the variables with respect to the constraints. | ||
:return: One valid solution for the randomization problem, represented as | ||
a dictionary with keys referring to the named variables. | ||
:raises RuntimeError: When the problem cannot be solved in fewer than | ||
the allowed number of iterations. | ||
''' | ||
groups = self.determine_order() | ||
|
||
solution = None | ||
|
||
# Try to solve sparsely first | ||
sparsities = [1, 10, 100, 1000] | ||
# The worst-case value of the number of iterations for one sparsity level is: | ||
# iterations_per_sparsity * iterations_per_attempt | ||
# because of the call to solve_groups hitting iterations_per_attempt. | ||
# Failing individual solution attempts speeds up some problems greatly, | ||
# this can be thought of as pruning explorations of the state tree. | ||
# So, reduce iterations_per_attempt by an order of magnitude. | ||
iterations_per_sparsity = self.max_iterations | ||
iterations_per_attempt = self.max_iterations // 10 | ||
for sparsity in sparsities: | ||
for _ in range(iterations_per_sparsity): | ||
solution = self.solve_groups(groups, iterations_per_attempt, sparsity) | ||
if solution is not None and len(solution) > 0: | ||
return solution | ||
|
||
# Try 'thorough' method - no backup plan if this fails | ||
solution = self.solve_groups(groups, self.max_iterations) | ||
if solution is None: | ||
raise RuntimeError("Could not solve problem.") | ||
return solution | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# SPDX-License-Identifier: MIT | ||
# Copyright (c) 2023 Imagination Technologies Ltd. All Rights Reserved | ||
|
||
import constraint | ||
from functools import partial | ||
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING | ||
|
||
from constrainedrandom import utils | ||
|
||
if TYPE_CHECKING: | ||
from constrainedrandom.randobj import RandObj | ||
|
||
|
||
class RandVar: | ||
''' | ||
Randomizable variable. For internal use with RandObj. | ||
:param parent: The :class:`RandObj` instance that owns this instance. | ||
:param name: The name of this random variable. | ||
:param order: The solution order for this variable with respect to other variables. | ||
:param domain: The possible values for this random variable, expressed either | ||
as a ``range``, or as an iterable (e.g. ``list``, ``tuple``) of possible values. | ||
Mutually exclusive with ``bits`` and ``fn``. | ||
:param bits: Specifies the possible values of this variable in terms of a width | ||
in bits. E.g. ``bits=32`` signifies this variable can be ``0 <= x < 1 << 32``. | ||
Mutually exclusive with ``domain`` and ``fn``. | ||
:param fn: Specifies a function to call that will provide the value of this random | ||
variable. | ||
Mutually exclusive with ``domain`` and ``bits``. | ||
:param args: Arguments to pass to the function specified in ``fn``. | ||
If ``fn`` is not used, ``args`` must not be used. | ||
:param constraints: List or tuple of constraints that apply to this random variable. | ||
:param max_iterations: The maximum number of failed attempts to solve the randomization | ||
problem before giving up. | ||
:param max_domain_size: The maximum size of domain that a constraint satisfaction problem | ||
may take. This is used to avoid poor performance. When a problem exceeds this domain | ||
size, we don't use the ``constraint`` package, but just use ``random`` instead. | ||
''' | ||
|
||
def __init__(self, | ||
parent: 'RandObj', | ||
name: str, | ||
order: int, | ||
*, | ||
domain: Optional[utils.Domain]=None, | ||
bits: Optional[int]=None, | ||
fn: Optional[Callable]=None, | ||
args: Optional[tuple]=None, | ||
constraints: Optional[Iterable[utils.Constraint]]=None, | ||
max_iterations: int, | ||
max_domain_size:int, | ||
) -> None: | ||
self.parent = parent | ||
self.random = self.parent._random | ||
self.name = name | ||
self.order = order | ||
self.max_iterations = max_iterations | ||
self.max_domain_size = max_domain_size | ||
assert ((domain is not None) != (fn is not None)) != (bits is not None), "Must specify exactly one of fn, domain or bits" | ||
if fn is None: | ||
assert args is None, "args has no effect without fn" | ||
self.domain = domain | ||
self.bits = bits | ||
self.fn = fn | ||
self.args = args | ||
self.constraints = constraints if constraints is not None else [] | ||
if not (isinstance(self.constraints, tuple) or isinstance(self.constraints, list)): | ||
self.constraints = (self.constraints,) | ||
# Default strategy is to randomize and check the constraints. | ||
self.check_constraints = len(self.constraints) > 0 | ||
# Create a function, self.randomizer, that returns the appropriate random value | ||
if self.fn is not None: | ||
if self.args is not None: | ||
self.randomizer = partial(self.fn, *self.args) | ||
else: | ||
self.randomizer = self.fn | ||
elif self.bits is not None: | ||
self.randomizer = partial(self.random.getrandbits, self.bits) | ||
self.domain = range(0, 1 << self.bits) | ||
else: | ||
# If we are provided a sufficiently small domain and we have constraints, simply construct a | ||
# constraint solution problem instead. | ||
is_range = isinstance(self.domain, range) | ||
is_list = isinstance(self.domain, list) or isinstance(self.domain, tuple) | ||
is_dict = isinstance(self.domain, dict) | ||
if self.check_constraints and len(self.domain) < self.max_domain_size and (is_range or is_list): | ||
problem = constraint.Problem() | ||
problem.addVariable(self.name, self.domain) | ||
for con in self.constraints: | ||
problem.addConstraint(con, (self.name,)) | ||
# Produces a list of dictionaries | ||
solutions = problem.getSolutions() | ||
def solution_picker(solns): | ||
return self.random.choice(solns)[self.name] | ||
self.randomizer = partial(solution_picker, solutions) | ||
self.check_constraints = False | ||
elif is_range: | ||
self.randomizer = partial(self.random.randrange, self.domain.start, self.domain.stop) | ||
elif is_list: | ||
self.randomizer = partial(self.random.choice, self.domain) | ||
elif is_dict: | ||
self.randomizer = partial(self.random.dist, self.domain) | ||
else: | ||
raise TypeError(f'RandVar was passed a domain of a bad type - {self.domain}. Domain should be a range, list, tuple or dictionary.') | ||
|
||
def randomize(self) -> Any: | ||
''' | ||
Returns a random value based on the definition of this random variable. | ||
Does not modify the state of the :class:`RandVar` instance. | ||
:return: A randomly generated value, conforming to the definition of | ||
this random variable, its constraints, etc. | ||
:raises RuntimeError: When the problem cannot be solved in fewer than | ||
the allowed number of iterations. | ||
''' | ||
value = self.randomizer() | ||
value_valid = not self.check_constraints | ||
iterations = 0 | ||
while not value_valid: | ||
if iterations == self.max_iterations: | ||
raise RuntimeError("Too many iterations, can't solve problem") | ||
problem = constraint.Problem() | ||
problem.addVariable(self.name, (value,)) | ||
for con in self.constraints: | ||
problem.addConstraint(con, (self.name,)) | ||
value_valid = problem.getSolution() is not None | ||
if not value_valid: | ||
value = self.randomizer() | ||
iterations += 1 | ||
return value |
Oops, something went wrong.