Skip to content

Commit

Permalink
probula: Sanitize where the rng state is created
Browse files Browse the repository at this point in the history
It should only be created in top-level executable files (which in this
project means Spec files)
  • Loading branch information
wasowski committed Jan 13, 2024
1 parent 5fac470 commit 06379f2
Show file tree
Hide file tree
Showing 19 changed files with 50 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ case class BdlConcreteExpectedSarsa [
val gamma: Double,
val epsilon0: Probability,
val episodes: Int,
) extends BdlLearn[State, ObservableState, Action, Double, Randomized2],
) (using val rng: probula.RNG)
extends BdlLearn[State, ObservableState, Action, Double, Randomized2],
ConcreteExactRL[State, ObservableState, Action],
NoDecay:

Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/symsim/concrete/BdlConcreteSarsa.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ case class BdlConcreteSarsa [
val gamma: Double,
val epsilon0: Probability,
val episodes: Int,
) extends BdlLearn[State, ObservableState, Action, Double, Randomized2],
) (using val rng: probula.RNG)
extends BdlLearn[State, ObservableState, Action, Double, Randomized2],
ConcreteExactRL[State, ObservableState, Action],
NoDecay:

Expand Down
9 changes: 4 additions & 5 deletions src/main/scala/symsim/concrete/ConcreteExactRL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ package concrete

import cats.kernel.BoundedEnumerable

given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply

trait ConcreteExactRL[State, ObservableState, Action]
extends ExactRL[State, ObservableState, Action, Double, Randomized2]:

Expand All @@ -24,10 +21,12 @@ trait ConcreteExactRL[State, ObservableState, Action]

// TODO: unclear if this is general (if it turns out to be the same im
// symbolic or approximate algos we should promote this to the trait

given rng: probula.RNG

def runQ: (Q, List[Q]) =
val initials = agent.initialize.sample(episodes)
val outcome = learn (vf.initialize, List[VF](), initials).sample()
val initials = agent.initialize.sample (episodes)
val outcome = learn (vf.initialize, List[VF] (), initials).sample ()
(outcome._1, outcome._2)

override def run: Policy =
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/symsim/concrete/ConcreteExpectedSarsa.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ case class ConcreteExpectedSarsa[State, ObservableState, Action] (
val gamma: Double,
val epsilon0: Probability,
val episodes: Int,
) extends ExpectedSarsa[State, ObservableState, Action, Double, Randomized2],
) (using val rng: probula.RNG)
extends ExpectedSarsa[State, ObservableState, Action, Double, Randomized2],
ConcreteExactRL[State, ObservableState, Action],
NoDecay:

Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/symsim/concrete/ConcreteQLearning.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ case class ConcreteQLearning [
val gamma: Double,
val epsilon0: Probability,
val episodes: Int,
) extends QLearning[State, ObservableState, Action, Double, Randomized2],
) (using val rng: probula.RNG)
extends QLearning[State, ObservableState, Action, Double, Randomized2],
ConcreteExactRL[State, ObservableState, Action],
NoDecay:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ case class ConcreteQLearningWithDecay [
val epsilon0: Probability,
val episodes: Int,

) extends QLearning[State, ObservableState, Action, Double, Randomized2],
) (using val rng: probula.RNG)
extends QLearning[State, ObservableState, Action, Double, Randomized2],
ConcreteExactRL[State, ObservableState, Action],
BoundedEpsilonDecay:

Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/symsim/concrete/ConcreteSarsa.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ case class ConcreteSarsa [
val gamma: Double,
val epsilon0: Probability,
val episodes: Int,
) extends Sarsa[State, ObservableState, Action, Double, Randomized2],
) (using val rng: probula.RNG)
extends Sarsa[State, ObservableState, Action, Double, Randomized2],
ConcreteExactRL[State, ObservableState, Action],
NoDecay:

Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/symsim/concrete/Randomized2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ object Randomized2:
/** Perform an imperative operation that depends on one sample from this
* Randomized. This is mostly meant for IO at this point.
*/
def run (f: A => Unit): Unit = f(self.sample ())
def run (f: A => Unit) (using RNG): Unit = f(self.sample ())

def filter (p: A => Boolean): Randomized2[A] =
self.filter (p)
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class BanditInstances (banditReward: List [Randomized2[BanditReward]]) (using pr
extends AgentConstraints[BanditState, BanditState, BanditAction, BanditReward, Randomized2]:

given enumAction: BoundedEnumerable[BanditAction] =
BoundedEnumerableFromList (List.range(0, banditReward.size)*)
BoundedEnumerableFromList (List.range (0, banditReward.size)*)

given enumState: BoundedEnumerable[BanditState] =
BoundedEnumerableFromList (false, true)
Expand Down
7 changes: 3 additions & 4 deletions src/test/scala/symsim/ExperimentSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ import symsim.concrete.Randomized2
import cats.syntax.all.*
import symsim.concrete.Randomized2.*

given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply

trait ExperimentSpec[State, ObservableState, Action]
extends org.scalatest.freespec.AnyFreeSpec,
Expand Down Expand Up @@ -70,12 +68,13 @@ trait ExperimentSpec[State, ObservableState, Action]
policies: List[setup.Policy],
initials: Option[Randomized2[State]] = None,
noOfEpisodes: Int = 5
): EvaluationResults =
) (using probula.RNG): EvaluationResults =
val ss: Randomized2[State] = initials.getOrElse (setup.agent.initialize)
for p <- policies
episodeRewards: Randomized2[Randomized2[Double]] =
setup.evaluate (p, ss)
rewards: Randomized2[Double] = episodeRewards.map { e => e.sample () }
rewards: Randomized2[Double] =
episodeRewards.map { e => e.sample () }
yield rewards.sample (noOfEpisodes).toList


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import symsim.examples.concrete.mountaincar.MountainCar

private val mountainCar =
new MountainCar (using spire.random.rng.SecureJava.apply)
private given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply

import mountainCar.instances.given

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@ package examples.concrete.braking
import laws.AgentLaws
import laws.EpisodicLaws

private given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply
private val car = new Car

class CarIsAgentSpec
extends SymSimSpec:

checkAll ("concrete.braking.Car is an Agent", AgentLaws (new Car).laws)
checkAll ("concrete.braking.Car is Episodic", EpisodicLaws (new Car).laws)
checkAll ("concrete.braking.Car is an Agent", AgentLaws (car).laws)
checkAll ("concrete.braking.Car is Episodic", EpisodicLaws (car).laws)
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package symsim
package examples.concrete.cartpole

// Import evidence that states and actions can be enumerated
private val cartPole =
new CartPole (using spire.random.rng.SecureJava.apply)
private given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply
private val cartPole: CartPole = new CartPole
import cartPole.instances.{enumAction, enumState}

class Experiments extends
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package symsim
package examples.concrete.cliffWalking

private val cliffWalking =
new CliffWalking (using spire.random.rng.SecureJava.apply)

private given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply
private val cliffWalking: CliffWalking = new CliffWalking
import cliffWalking.instances.{enumAction, enumState}

class Experiments
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package symsim
package examples.concrete.mountaincar

private val mountainCar =
new MountainCar (using spire.random.rng.SecureJava.apply)

private given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply
private val mountainCar: MountainCar = new MountainCar
import mountainCar.instances.{enumAction, enumState}

class Experiments
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package symsim
package examples.concrete.pumping

private val pump = new Pump (using spire.random.rng.SecureJava.apply)

private given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply
private val pump: Pump = new Pump

class Experiments
extends ExperimentSpec[PumpState, ObservablePumpState, PumpAction]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package symsim
package examples.concrete.simplebandit

private given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply

class Experiments
extends ExperimentSpec[BanditState,BanditState,BanditAction]:

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package symsim
package examples.concrete.simplemaze

given spire.random.rng.SecureJava =
private given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply

private val maze = new Maze
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ import symsim.concrete.ConcreteSarsa
import symsim.concrete.Randomized2
import symsim.concrete.Randomized2.given

given spire.random.rng.SecureJava = spire.random.rng.SecureJava.apply
private given spire.random.rng.SecureJava =
spire.random.rng.SecureJava.apply

private val windyGrid = new WindyGrid

Expand Down

0 comments on commit 06379f2

Please sign in to comment.