Skip to content
This repository has been archived by the owner on Jan 30, 2023. It is now read-only.

Commit

Permalink
Fix various generator bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Kerl13 committed Jun 21, 2019
1 parent 5303f2a commit b648771
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/sage/combinat/boltzmann_sampling/generator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,10 @@ cdef c_gen(first_rule, rules, int size_min, int size_max, int max_retry, builder
cdef inline identity(x):
return x

cdef inline first(x):
a, __ = x
return a

def UnionBuilder(*builders):
"""Helper for computing the builder of a union out of the builders of its
components in Boltzmann samplers. For instance, in order to compute the
Expand Down Expand Up @@ -317,11 +321,11 @@ cdef make_default_builder(rule):
if isinstance(rule, Ref):
return identity
elif isinstance(rule, Atom):
return identity
elif isinstance(rule, Product):
return first
elif isinstance(rule, Union):
subbuilders = [make_default_builder(component) for component in rule.args]
return UnionBuilder(*subbuilders)
elif isinstance(rule, Union):
elif isinstance(rule, Product):
subbuilders = [make_default_builder(component) for component in rule.args]
return ProductBuilder(subbuilders)

Expand Down Expand Up @@ -386,7 +390,7 @@ class Generator:
def _precompute_oracle_values(self, z):
return [
self.oracle.eval_rule(self.id_to_name[rule_id], z)
for rule_id in range(len(self.flat_rules))
for rule_id in range(len(self.id_to_name))
]

def gen(self, name, window, max_retry=2000, singular=True):
Expand Down Expand Up @@ -422,6 +426,7 @@ class Generator:
self.singularity = values[atom_name]
z = self.singularity
self.oracle_cache[z] = values
z = self.singularity
else:
raise NotImplementedError("Non-singular generation")
if z not in self.oracle_cache:
Expand All @@ -434,7 +439,7 @@ class Generator:
self.grammar.rules
)
# Generate
first_rule = flat_rules[self.name_to_id[name]]
first_rule = (REF, self.oracle_cache[z][name], self.name_to_id[name])
statistics, obj = c_gen(
first_rule,
flat_rules,
Expand Down

0 comments on commit b648771

Please sign in to comment.