Skip to content

Commit

Permalink
Try a new approach to the Gen monad that makes it a primitive.
Browse files Browse the repository at this point in the history
This would allow us to eventually slot in someting like Hedgehog
instead to drive the random testing infrastructure, and allow
us to use it's existing shrinking infrastructure, etc.
  • Loading branch information
robdockins committed Jan 4, 2021
1 parent 87efe84 commit bf25423
Show file tree
Hide file tree
Showing 20 changed files with 615 additions and 312 deletions.
1 change: 1 addition & 0 deletions cryptol.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ library

Cryptol.Backend,
Cryptol.Backend.Arch,
Cryptol.Backend.SeqMap,
Cryptol.Backend.Concrete,
Cryptol.Backend.FloatHelpers,
Cryptol.Backend.Monad,
Expand Down
100 changes: 59 additions & 41 deletions lib/Testing.cry
Original file line number Diff line number Diff line change
@@ -1,44 +1,73 @@
module Testing where

primitive type RandGen : *

infixl 1 >>=
infixl 3 <|>
infixl 4 <$>
infixl 4 <*>

primitive seedGen : [8] -> [256] -> RandGen
primitive genSize : RandGen -> [8]
primitive genResize : [8] -> RandGen -> RandGen
primitive splitGen : {n} (fin n, 2 <= n, width n <= 32) => RandGen -> [n]RandGen
primitive type Gen : * -> *

type Gen a = RandGen -> (a, RandGen)
primitive runGen : {a} [8] -> [256] -> Gen a -> a

return : {a} a -> Gen a
return a g = (a, g)
primitive return : {a} a -> Gen a
primitive (>>=) : {a,b} Gen a -> (a -> Gen b) -> Gen b
primitive (<$>) : {a,b} (a -> b) -> Gen a -> Gen b
primitive (<*>) : {a,b} Gen (a -> b) -> Gen a -> Gen b
primitive genStream : {a} Gen a -> Gen ([inf]a)

(<$>) : {a,b} (a -> b) -> Gen a -> Gen b
(<$>) f m g = (f a, g')
where
(a, g') = m g
primitive withSize : {a} [8] -> Gen a -> Gen a

(<*>) : {a,b} Gen (a -> b) -> Gen a -> Gen b
(<*>) mf m g = (f x, g2)
where
(f,g1) = mf g
(x,g2) = m g1
// primitive type RandGen : *
// primitive seedGen : [8] -> [256] -> RandGen
// primitive genSize : RandGen -> [8]
// primitive genResize : [8] -> RandGen -> RandGen
// primitive splitGen : {n} (fin n, 2 <= n, width n <= 32) => RandGen -> [n]RandGen

//type Gen a = RandGen -> (a, RandGen)

///runGen : {a} [8] -> [256] -> Gen a -> a
//runGen sz seed m = (m (seedGen sz seed)).0

//return : {a} a -> Gen a
//return a = \g -> (a,g)

//(>>=) : {a,b} Gen a -> (a -> Gen b) -> Gen b
//(>>=) m f = \g -> uncurry f (m g)

//(<$>) : {a,b} (a -> b) -> Gen a -> Gen b
//(<$>) f m = \g -> ((f x,g') where (x,g') = m g)

//(<*>) : {a,b} Gen (a -> b) -> Gen a -> Gen b
//(<*>) mf m = \g ->
// ((f x,g2) where
// (f,g1) = mf g
// (x,g2) = m g1)

// private
// mkStream : {a} Gen a -> RandGen -> [inf]a
// mkStream m g = xs.0
// where
// xs = [ m g ] # [ m g' | (_,g') <- xs ]

// genStreams : {n,a} (fin n, n>=1, width (n+1) <= 32) => Gen a -> Gen ([n][inf]a)
// genStreams m = \g0 -> ( (map (mkStream m) gs, g') where gs#[g'] = splitGen`{n+1} g0 )

// genStream : {a} Gen a -> Gen ([inf]a)
// genStream m = \g0 -> ((mkStream m g1, g2) where [g1,g2] = splitGen g0)

(>>=) : {a,b} Gen a -> (a -> Gen b) -> Gen b
(>>=) m f g = f x g'
where
(x,g') = m g

primitive generate : {a} Generate a => Gen a

unboundedInteger : Gen Integer
unboundedInteger = generate`{Integer}

primitive boundedInteger : (Integer, Integer) -> Gen Integer
primitive boundedBelowInteger : Integer -> Gen Integer
primitive boundedAboveInteger : Integer -> Gen Integer

unboundedWord : {n} (fin n) => Gen [n]
unboundedWord = generate`{[n]}

primitive boundedWord : {n} (fin n) => ([n],[n]) -> Gen [n]
primitive boundedSignedWord : {n} (fin n) => ([n],[n]) -> Gen [n]

Expand All @@ -53,24 +82,6 @@ choose gs = boundedInteger (0,`(n-1)) >>= \i -> gs@i
oneOf : {n,a} (fin n, n >= 1) => [n]a -> Gen a
oneOf xs = choose (map return xs)

private
mkStream : {a} Gen a -> RandGen -> [inf]a
mkStream m g = xs.0
where
xs = [ m g ] # [ m g' | (_,g') <- xs ]

genStreams : {n,a} (fin n, n>=1, width (n+1) <= 32) => Gen a -> Gen ([n][inf]a)
genStreams m g0 = (xss, g')
where
gs#[g'] = splitGen`{n+1} g0
xss = map (mkStream m) gs

genStream : {a} Gen a -> Gen ([inf]a)
genStream m g0 = ( xs, g2 )
where
[g1,g2] = splitGen g0
xs = mkStream m g1

genSequence : {n,a} (fin n) => Gen a -> Gen ([n]a)
genSequence m = take`{n} <$> genStream m

Expand Down Expand Up @@ -98,6 +109,13 @@ property addOddProp = evenSum <$> oddInteger <*> oddInteger
where
evenSum x y = isEven (x + y)

runTests : Integer -> Gen Bit -> [256] -> Bit
runTests num m seed = runGen 100 seed (loop 0)
where
loop n = if n >= num then return True else go n
go n = withSize (sz n) m >>= \b -> if b then loop (n+1) else return False
sz n = fromInteger (1 + ((n * 100) / num))

withCounterexample : {a,n} (fin n) => String n -> a -> Bit -> Bit
withCounterexample msg vals test = test \/ traceError ("counterexample " # msg) vals

Expand All @@ -114,7 +132,7 @@ property addNoOverflow = noOverflow <$> boundedBelowInteger 0 <*> boundedBelowIn
noOverflow x y = withCounterexample "noOverflow" (x,y) (x <= (x+y) /\ y <= (x+y))

addNoWordOverflow : Gen Bit
property addNoWordOverflow = noOverflow <$> boundedWord (0,0x7FFFFFFF) <*> boundedWord (0,0x7FFFFFFF)
property addNoWordOverflow = noOverflow <$> boundedWord (0,0x8FFFFFFF) <*> boundedWord (0,0x7FFFFFFF)
where
noOverflow : [32] -> [32] -> Bit
noOverflow x y = withCounterexample "noOverflow" (x,y) (x <= (x+y) /\ y <= (x+y))
38 changes: 37 additions & 1 deletion src/Cryptol/Backend.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Cryptol.Backend
, cryUserError
, cryNoPrimError
, FPArith2
, SeqMap(..)

-- * Rationals
, SRational(..)
Expand All @@ -31,6 +32,7 @@ module Cryptol.Backend

import Control.Monad.IO.Class
import Data.Kind (Type)
import Data.Map.Strict (Map)

import Cryptol.Backend.FloatHelpers (BF)
import Cryptol.Backend.Monad
Expand Down Expand Up @@ -191,15 +193,23 @@ iteRational :: Backend sym => sym -> SBit sym -> SRational sym -> SRational sym
iteRational sym p (SRational a b) (SRational c d) =
SRational <$> iteInteger sym p a c <*> iteInteger sym p b d

-- | A sequence map represents a mapping from nonnegative integer indices
-- to values. These are used to represent both finite and infinite sequences.
data SeqMap sym a
= IndexSeqMap !(Integer -> SEval sym a)
| UpdateSeqMap !(Map Integer (SEval sym a))
!(Integer -> SEval sym a)

-- | This type class defines a collection of operations on bits, words and integers that
-- are necessary to define generic evaluator primitives that operate on both concrete
-- and symbolic values uniformly.
class MonadIO (SEval sym) => Backend sym where
class (MonadIO (SEval sym), Monad (SGen sym)) => Backend sym where
type SBit sym :: Type
type SWord sym :: Type
type SInteger sym :: Type
type SFloat sym :: Type
type SEval sym :: Type -> Type
type SGen sym :: Type -> Type

-- ==== Evaluation monad operations ====

Expand Down Expand Up @@ -254,6 +264,32 @@ class MonadIO (SEval sym) => Backend sym where
-- | Indiciate that an error condition exists
raiseError :: sym -> EvalError -> SEval sym a

-- ==== Value generation operations =====

-- | Lifts evaluation into the generator monad
sGenLift :: sym -> SEval sym a -> SGen sym a

-- | Given a 8-bit size value and a 256-bit random seed, run the generator action
sRunGen :: sym -> SWord sym -> SWord sym -> SGen sym a -> SEval sym a

sGenGetSize :: sym -> SGen sym (SWord sym)
sGenWithSize :: sym -> SWord sym -> SGen sym a -> SGen sym a

sGenerateBit :: sym -> SGen sym (SBit sym)

sUnboundedWord :: sym -> Integer -> SGen sym (SWord sym)
sBoundedWord :: sym -> (SWord sym, SWord sym) -> SGen sym (SWord sym)
sBoundedSignedWord :: sym -> (SWord sym, SWord sym) -> SGen sym (SWord sym)

sUnboundedInteger :: sym -> SGen sym (SInteger sym)
sBoundedInteger :: sym -> (SInteger sym, SInteger sym) -> SGen sym (SInteger sym)
sBoundedBelowInteger :: sym -> SInteger sym -> SGen sym (SInteger sym)
sBoundedAboveInteger :: sym -> SInteger sym -> SGen sym (SInteger sym)

sSuchThat :: sym -> SGen sym a -> (a -> SEval sym (SBit sym)) -> SGen sym a

sGenStream :: sym -> SGen sym a -> SGen sym (SeqMap sym a)

-- ==== Identifying literal values ====

-- | Determine if this symbolic bit is a boolean literal
Expand Down
87 changes: 85 additions & 2 deletions src/Cryptol/Backend/Concrete.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
module Cryptol.Backend.Concrete
Expand All @@ -36,15 +37,21 @@ module Cryptol.Backend.Concrete
) where

import qualified Control.Exception as X
import Data.Bits
import Numeric (showIntAtBase)
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bits
import Data.Word
import Numeric (showIntAtBase)
import qualified LibBF as FP
import qualified GHC.Integer.GMP.Internals as Integer
import System.Random (random, randomR, split)
import System.Random.TF.Gen (TFGen, seedTFGen)

import qualified Cryptol.Backend.Arch as Arch
import qualified Cryptol.Backend.FloatHelpers as FP
import Cryptol.Backend
import Cryptol.Backend.Monad
import Cryptol.Backend.SeqMap
import Cryptol.TypeCheck.Solver.InfNat (genLog)
import Cryptol.Utils.Panic (panic)
import Cryptol.Utils.PP
Expand Down Expand Up @@ -129,12 +136,19 @@ mask ::
mask w i | w >= Arch.maxBigIntWidth = wordTooWide w
| otherwise = i .&. (bit (fromInteger w) - 1)

randomSize :: Word8 -> Word8 -> TFGen -> (Word8, TFGen)
randomSize k n g
| p == 1 = (n, g')
| otherwise = randomSize k (n + 1) g'
where (p, g') = randomR (1, k) g

instance Backend Concrete where
type SBit Concrete = Bool
type SWord Concrete = BV
type SInteger Concrete = Integer
type SFloat Concrete = FP.BF
type SEval Concrete = Eval
type SGen Concrete = ReaderT Word8 (StateT TFGen Eval)

raiseError _ err =
do stk <- getCallStack
Expand Down Expand Up @@ -165,6 +179,75 @@ instance Backend Concrete where
sModifyCallStack _ f m = modifyCallStack f m
sGetCallStack _ = getCallStack

sGenLift _sym m = lift (lift m)
sRunGen _sym sz seed m =
do let mask64 = 0xFFFFFFFFFFFFFFFF
unpack s = fromInteger (s .&. mask64) : unpack (s `shiftR` 64)
[a, b, c, d] = take 4 (unpack (bvVal seed))
g0 = seedTFGen (a,b,c,d)
fst <$> runStateT (runReaderT m (fromInteger (bvVal sz))) g0

sGenGetSize _sym = reader (mkBv 8 . toInteger)
sGenWithSize _sym sz m = local (\_ -> fromInteger (bvVal sz)) m

sGenerateBit _sym = state random
sUnboundedWord _sym w = mkBv w <$> state (randomR (0, 2^w - 1))
sBoundedWord _sym (BV w b1,BV _ b2) = mkBv w <$> state (randomR (lo, hi))
where
lo = min b1 b2
hi = max b1 b2

sBoundedSignedWord _sym (BV w x1,BV _ x2) = mkBv w <$> state (randomR (lo, hi))
where
b1 = signedValue w x1
b2 = signedValue w x2

lo = min b1 b2
hi = max b1 b2

sBoundedInteger _sym (x1,x2) = state (randomR (lo, hi))
where
lo = min x1 x2
hi = max x1 x2

sUnboundedInteger _sym =
do sz <- ask
n <- if sz < 100 then pure sz else state (randomSize 8 100)
state (randomR (- 256^n, 256^n ))

sBoundedBelowInteger _sym lo =
do sz <- ask
n <- if sz < 100 then pure sz else state (randomSize 8 100)
x <- state (randomR (0, 256^n))
pure (lo + x)

sBoundedAboveInteger _sym hi =
do sz <- ask
n <- if sz < 100 then pure sz else state (randomSize 8 100)
x <- state (randomR (0, 256^n))
pure (hi - x)

sSuchThat sym m p =
do x <- m
b <- lift (lift (p x))
if b then pure x else sSuchThat sym m p

sGenStream sym m =
do sz <- ask
(x, fill) <- sGenLift sym (blackhole "sGenStream")
g0 <- state split
let mkElem = runStateT (runReaderT m sz)
let mkMap = IndexSeqMap @Concrete \i ->
if i <= 0 then
mkElem g0
else
do x' <- x
(_,g) <- lookupSeqMap @Concrete x' (i-1)
mkElem g
sGenLift sym (fill (memoMap sym mkMap))
pure (IndexSeqMap \i -> x >>= \x' -> fst <$> lookupSeqMap @Concrete x' i)


bitLit _ b = b
bitAsLit _ b = Just b

Expand Down
3 changes: 2 additions & 1 deletion src/Cryptol/Backend/SBV.hs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ module Cryptol.Backend.SBV

import qualified Control.Exception as X
import Control.Concurrent.MVar
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Reader
import Data.Bits (bit, complement)
import Data.List (foldl')

Expand Down Expand Up @@ -153,6 +153,7 @@ instance Backend SBV where
type SInteger SBV = SVal
type SFloat SBV = () -- XXX: not implemented
type SEval SBV = SBVEval
type SGen SBV = ReaderT SVal SBVEval

raiseError _ err = SBVEval $
do stk <- getCallStack
Expand Down
Loading

0 comments on commit bf25423

Please sign in to comment.