Skip to content
This repository has been archived by the owner on Jan 30, 2023. It is now read-only.

Commit

Permalink
Use sage's randstate API for random number generation
Browse files Browse the repository at this point in the history
  • Loading branch information
Kerl13 committed Jun 19, 2019
1 parent 693df81 commit 75ec37c
Showing 1 changed file with 17 additions and 62 deletions.
79 changes: 17 additions & 62 deletions src/sage/combinat/boltzmann_sampling/generator.pyx
Original file line number Diff line number Diff line change
@@ -1,44 +1,10 @@
# coding: utf-8
# cython: profile=True

from sage.libs.gmp.random cimport gmp_randinit_set, gmp_randinit_default
from sage.libs.gmp.types cimport gmp_randstate_t
from sage.misc.randstate cimport randstate, current_randstate
from .grammar import Atom, Product, Ref, Union
from random import randint # for initialization purpose only
from random import Random


# ---
# Simple random number generator
# ---

# https://en.wikipedia.org/wiki/Linear_congruential_generator#Parameters_in_common_use
# Line 3: the gcc version
#
# FIXME: this is not a great PRNG


cdef long int _lcg_mask = ((1 << 31) - 1)
cdef long int _lcg_a = 1103515245
cdef long int _lcg_c = 12345

cdef class LCG:
cdef long int _state

cpdef long int _rand(self):
self._state = (_lcg_a * self._state + _lcg_c) & _lcg_mask
return self._state & (_lcg_mask >> 1)

cpdef float random(self):
return float(self._rand()) / float(_lcg_mask >> 1)

cpdef long int getstate(self):
return self._state

cpdef setstate(self, long int state):
self._state = state

def seed(self, seed):
self.setstate(seed & _lcg_mask)


# ---
# Preprocessing
Expand Down Expand Up @@ -84,10 +50,11 @@ cdef map_all_names_to_ids(rules):
# Simulation phase
# ---

cdef int c_simulate(int id, float weight, int size_max, flat_rules, rand_float):
cdef int c_simulate(int id, float weight, int size_max, flat_rules, randstate rstate):
cdef int size = 0
cdef list todo = [(REF, weight, id)]
cdef float r = 0.

while todo:
type, weight, args = todo.pop()
if type == REF:
Expand All @@ -99,7 +66,7 @@ cdef int c_simulate(int id, float weight, int size_max, flat_rules, rand_float):
if size > size_max:
return size
elif type == UNION:
r = rand_float() * weight
r = rstate.c_rand_double() * weight
for arg in args:
__, arg_weight, __ = arg
r -= arg_weight
Expand All @@ -115,8 +82,7 @@ cdef int c_simulate(int id, float weight, int size_max, flat_rules, rand_float):
# Actual generation
# ---


cdef c_generate(int id, float weight, flat_rules, builders, rand_float):
cdef c_generate(int id, float weight, flat_rules, builders, randstate rstate):
generated = []
cdef list todo = [(REF, weight, id)]
cdef float r = 0.
Expand All @@ -130,7 +96,7 @@ cdef c_generate(int id, float weight, flat_rules, builders, rand_float):
atom_name, __ = args
generated.append(atom_name)
elif type == UNION:
r = rand_float() * weight
r = rstate.c_rand_double() * weight
for i in range(len(args)):
arg = args[i]
__, arg_weight, __ = arg
Expand Down Expand Up @@ -160,16 +126,17 @@ cdef c_generate(int id, float weight, flat_rules, builders, rand_float):
return obj


cdef c_gen(int id, float weight, flat_rules, int size_min, int size_max, int max_retry, builders, rnd):
cdef c_gen(int id, float weight, flat_rules, int size_min, int size_max, int max_retry, builders):
cdef int nb_rejections = 0
cdef int cumulative_rejected_size = 0
cdef int size = -1
cdef last_state = rnd.getstate()
rand_float = rnd.random
cdef randstate rstate = current_randstate()
cdef gmp_randstate_t gmp_state
gmp_randinit_default(gmp_state)

while nb_rejections < max_retry:
last_state = rnd.getstate()
size = c_simulate(id, weight, size_max, flat_rules, rand_float)
gmp_randinit_set(gmp_state, rstate.gmp_state)
size = c_simulate(id, weight, size_max, flat_rules, rstate)
if size <= size_max and size >= size_min:
break
else:
Expand All @@ -179,13 +146,12 @@ cdef c_gen(int id, float weight, flat_rules, int size_min, int size_max, int max
if not(size <= size_max and size >= size_min):
return None

rnd.setstate(last_state)
obj = c_generate(id, weight, flat_rules, builders, rand_float)
gmp_randinit_set(rstate.gmp_state, gmp_state)
obj = c_generate(id, weight, flat_rules, builders, rstate)
statistics = {
"size": size,
"nb_rejections": nb_rejections,
"cumulative_rejected_size": cumulative_rejected_size,
"rnd_state": last_state
}
return statistics, obj

Expand Down Expand Up @@ -220,17 +186,12 @@ cdef make_default_builder(rule):


class Generator:
def __init__(self, grammar, oracle=None, rnd=None):
def __init__(self, grammar, oracle=None):
# Load the default oracle if none is supplied
if oracle is None:
from .oracle import OracleSimple
oracle = OracleSimple(grammar, e1=1e-6, e2=1e-6)
self.oracle = oracle
# Load the defaut RNG is none is supplied
if rnd is None:
rnd = LCG()
rnd.seed(randint(0, _lcg_mask))
self.rnd = rnd
# flatten the grammar for faster access to rules
self.grammar = grammar
self.grammar.annotate(oracle)
Expand All @@ -251,11 +212,6 @@ class Generator:
def get_builder(self, non_terminal):
symbol_id = self.name_to_id[non_terminal]
return self.builders[symbol_id]
#
# def simulate(self, name, size_max):
# id = self.name_to_id[name]
# weight = self.grammar[name].weight
# return c_simulate(id, weight, size_max, self.flat_rules, self.rnd)

def gen(self, name, window, max_retry=2000):
id = self.name_to_id[name]
Expand All @@ -269,6 +225,5 @@ class Generator:
size_max,
max_retry,
self.builders,
self.rnd
)
return obj, statistics

0 comments on commit 75ec37c

Please sign in to comment.