Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start support for muxing function handles #615

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crucible/crucible.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ library
Lang.Crucible.Utils.CoreRewrite
Lang.Crucible.Utils.MonadVerbosity
Lang.Crucible.Utils.MuxTree
Lang.Crucible.Utils.PartitioningMuxTree
Lang.Crucible.Utils.PrettyPrint
Lang.Crucible.Utils.RegRewrite
Lang.Crucible.Utils.StateContT
Expand Down
8 changes: 5 additions & 3 deletions crucible/src/Lang/Crucible/Simulator/Evaluation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ import Lang.Crucible.Simulator.RegMap
import Lang.Crucible.Simulator.SimError
import Lang.Crucible.Types
import Lang.Crucible.Utils.MuxTree
import qualified Lang.Crucible.Utils.PartitioningMuxTree as PMT

------------------------------------------------------------------------
-- Utilities
Expand Down Expand Up @@ -476,12 +477,13 @@ evalApp sym itefns _logFn evalExt (evalSub :: forall tp. f tp -> IO (RegValue sy
----------------------------------------------------------------------
-- Handle

HandleLit h -> return (HandleFnVal h)
HandleLit h -> return (PMT.toPartitioningMuxTree sym (HandleFnVal h))

Closure _ _ h_expr tp v_expr -> do
Closure argReprs retRepr h_expr tp v_expr -> do
h <- evalSub h_expr
v <- evalSub v_expr
return $! ClosureFnVal h tp v
let closure = ClosureFnVal (argReprs Ctx.:> tp) retRepr h v
return $! PMT.toPartitioningMuxTree sym closure

----------------------------------------------------------------------
-- RealVal
Expand Down
33 changes: 33 additions & 0 deletions crucible/src/Lang/Crucible/Simulator/ExecutionTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ module Lang.Crucible.Simulator.ExecutionTree
, FrameRetType

-- ** ReturnHandler
, SuspendedCallees(..)
, ReturnHandler(..)

-- * ActiveTree
Expand All @@ -94,6 +95,7 @@ module Lang.Crucible.Simulator.ExecutionTree
, activeFrames
, actContext
, actFrame
, actResult

-- * Simulator context
-- ** Function bindings
Expand Down Expand Up @@ -141,6 +143,7 @@ module Lang.Crucible.Simulator.ExecutionTree
import Control.Lens
import Control.Monad.Reader
import Data.Kind
import qualified Data.List.NonEmpty as DLN
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Parameterized.Ctx
Expand Down Expand Up @@ -843,6 +846,22 @@ vfvParents c0 =
------------------------------------------------------------------------
-- ReturnHandler

-- | Internal states for the 'CallNext' 'ReturnHandler'
--
-- We either have a non-empty list of call targets to execute *or* we have some
-- call targets and a defined "previous" state from executing alternative call
-- targets.
--
-- By tracking these states together, we rule out an impossible state where we
-- have an empty call list and no previously collected return states
data SuspendedCallees p sym ext f args ret where
InitialSuspendedCall :: DLN.NonEmpty (ResolvedCall p sym ext ret, Pred sym)
-> SuspendedCallees p sym ext f args ret
SuspendedCallees :: [(ResolvedCall p sym ext ret, Pred sym)]
-> PartialResult sym ext (SimFrame sym ext f args)
-> RegEntry sym ret
-> SuspendedCallees p sym ext f args ret

{- | A 'ReturnHandler' indicates what actions to take to resume
executing in a caller's context once a function call has completed and
the return value is avaliable.
Expand Down Expand Up @@ -891,6 +910,20 @@ data ReturnHandler (ret :: CrucibleType) p sym ext root f args where
(ret ~ r) =>
ReturnHandler ret p sym ext root (CrucibleLang blocks r) ctx

{- | The 'CallNext' constructor indicates that the simulator needs to pause and
call the next alternative callee for the call site. To do so, it should
use the saved 'SimState', which was the same 'SimState' used by the first
potential callee (so that all callees see the same initial state).

-}
CallNext ::
ProgramLoc {- ^ The location of the call site -} ->
ReturnHandler ret p sym ext root f args {- ^ The original return handler -} ->
SimState p sym ext rtp f a {- ^ The sim state to use for each alternative call -} ->
SuspendedCallees p sym ext f args ret {- ^ The remaining callees and their muxed results -} ->
Pred sym {- ^ The predicate under which the current target is valid -} ->
ReturnHandler ret p sym ext root f args


------------------------------------------------------------------------
-- ActiveTree
Expand Down
164 changes: 144 additions & 20 deletions crucible/src/Lang/Crucible/Simulator/Operations.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
Expand Down Expand Up @@ -78,16 +79,19 @@ import Prelude hiding (pred)
import qualified Control.Exception as Ex
import Control.Lens
import Control.Monad.Reader
import Data.Maybe (fromMaybe)
import qualified Data.Foldable as F
import Data.List (isPrefixOf)
import qualified Data.List.NonEmpty as DLN
import Data.Maybe (fromMaybe)
import qualified Data.Parameterized.Context as Ctx
import Data.Parameterized.Some
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq
import qualified Data.Vector as V
import qualified Data.Traversable as T
import Data.Type.Equality hiding (sym)
import System.IO
import qualified Data.Vector as V
import qualified Prettyprinter as PP
import System.IO

import What4.Config
import What4.Interface
Expand All @@ -105,6 +109,7 @@ import Lang.Crucible.Simulator.GlobalState
import Lang.Crucible.Simulator.Intrinsics
import Lang.Crucible.Simulator.RegMap
import Lang.Crucible.Simulator.SimError
import qualified Lang.Crucible.Utils.PartitioningMuxTree as PMT

---------------------------------------------------------------------
-- Intermediate state branching/merging
Expand Down Expand Up @@ -359,19 +364,47 @@ packVarargs = go mempty
-- the underlying function handle is not found in the
-- 'FunctionBindings' map.
resolveCall ::
(IsSymInterface sym) =>
sym ->
FunctionBindings p sym ext {- ^ Map from function handles to semantics -} ->
PMT.PartitioningMuxTree sym (FnVal sym args ret) {- ^ Function handle and any closure variables -} ->
RegMap sym args {- ^ Arguments to the function -} ->
ProgramLoc {- ^ Location of the call -} ->
[SomeFrame (SimFrame sym ext)] {-^ current call stack (for exceptions) -} ->
IO (DLN.NonEmpty (ResolvedCall p sym ext ret, Pred sym))
resolveCall sym = resolveCallPred sym (truePred sym)

resolveCallPred ::
(IsSymInterface sym) =>
sym ->
Pred sym {-^ A predicate to mux into all of the conditions -} ->
FunctionBindings p sym ext {- ^ Map from function handles to semantics -} ->
FnVal sym args ret {- ^ Function handle and any closure variables -} ->
PMT.PartitioningMuxTree sym (FnVal sym args ret) {- ^ Function handle and any closure variables -} ->
RegMap sym args {- ^ Arguments to the function -} ->
ProgramLoc {- ^ Location of the call -} ->
[SomeFrame (SimFrame sym ext)] {-^ current call stack (for exceptions) -} ->
ResolvedCall p sym ext ret
resolveCall bindings c0 args loc callStack =
IO (DLN.NonEmpty (ResolvedCall p sym ext ret, Pred sym))
resolveCallPred sym p0 bindings mt0 args loc callStack = do
elts <- T.traverse (\(v, cond) -> (v,) <$> andPred sym p0 cond) (PMT.viewPartitioningMuxTree mt0)
calls <- T.traverse (resolveCallFnVal sym bindings args loc callStack) elts
return (join calls)

resolveCallFnVal
:: (IsSymInterface sym)
=> sym
-> FunctionBindings p sym ext
-> RegMap sym args
-> ProgramLoc
-> [SomeFrame (SimFrame sym ext)]
-> (FnVal sym args ret, Pred sym)
-> IO (DLN.NonEmpty (ResolvedCall p sym ext ret, Pred sym))
resolveCallFnVal sym bindings args loc callStack (c0, p) =
case c0 of
ClosureFnVal c tp v -> do
resolveCall bindings c (assignReg tp v args) loc callStack
ClosureFnVal (_argReprs Ctx.:> tp) _retRepr c v -> do
resolveCallPred sym p bindings c (assignReg tp v args) loc callStack

VarargsFnVal h addlTypes ->
resolveCall bindings (HandleFnVal h) (packVarargs addlTypes args) loc callStack
resolveCallFnVal sym bindings (packVarargs addlTypes args) loc callStack (HandleFnVal h, p)

HandleFnVal h -> do
case lookupHandleMap h bindings of
Expand All @@ -381,9 +414,9 @@ resolveCall bindings c0 args loc callStack =
, _overrideHandle = SomeHandle h
, _overrideRegMap = args
}
in OverrideCall o f
Just (UseCFG g pdInfo) -> do
CrucibleCall (cfgEntryBlockID g) (mkCallFrame g pdInfo args)
in return ((OverrideCall o f, p) DLN.:| [])
Just (UseCFG g pdInfo) ->
return ((CrucibleCall (cfgEntryBlockID g) (mkCallFrame g pdInfo args), p) DLN.:| [])


resolvedCallName :: ResolvedCall p sym ext ret -> FunctionName
Expand Down Expand Up @@ -536,30 +569,51 @@ returnValue arg =


callFunction ::
IsExprBuilder sym =>
FnVal sym args ret {- ^ Function handle and any closure variables -} ->
IsSymInterface sym =>
RegValue sym (FunctionHandleType args ret) {- ^ Function handle and any closure variables -} ->
RegMap sym args {- ^ Arguments to the function -} ->
ReturnHandler ret p sym ext rtp f a {- ^ How to modify the caller's scope with the return value -} ->
ProgramLoc {-^ location of call -} ->
ExecCont p sym ext rtp f a
callFunction fn args retHandler loc =
do bindings <- view (stateContext.functionBindings)
callStack <- view (stateTree . to activeFrames)
let rcall = resolveCall bindings fn args loc callStack
ReaderT $ return . CallState retHandler rcall
sym <- view (stateContext . ctxSymInterface)
rcalls <- liftIO $ resolveCall sym bindings fn args loc callStack

ReaderT $ \ctx ->
case rcalls of
(firstTarget, muxPred) DLN.:| otherTargets ->
case otherTargets of
[] ->
-- In the degenerate case of a single target, just issue a standard call/return
--
-- We ignore the predicate for a single target. NOTE: We could
-- assert it (it should just be True)
return (CallState retHandler firstTarget ctx)
rest1 : restn -> do
-- Otherwise, wrap the return state up in a special return that will
-- run each possible call (sequentially) with a captured initial
-- state. When the first target returns, the modified handler calls
-- the next possible target. When callers are exhausted, it merges
-- the results together and performs the original type of return.
let initTargets = InitialSuspendedCall (rest1 DLN.:| restn)
let cpsHandler = CallNext loc retHandler ctx initTargets muxPred
return (CallState cpsHandler firstTarget ctx)

tailCallFunction ::
FrameRetType f ~ ret =>
FnVal sym args ret {- ^ Function handle and any closure variables -} ->
(FrameRetType f ~ ret, IsSymInterface sym) =>
RegValue sym (FunctionHandleType args ret) {- ^ Function handle and any closure variables -} ->
RegMap sym args {- ^ Arguments to the function -} ->
ValueFromValue p sym ext rtp ret ->
ProgramLoc {-^ location of call -} ->
ExecCont p sym ext rtp f a
tailCallFunction fn args vfv loc =
do bindings <- view (stateContext.functionBindings)
callStack <- view (stateTree . to activeFrames)
let rcall = resolveCall bindings fn args loc callStack
ReaderT $ return . TailCallState vfv rcall
sym <- view (stateContext . ctxSymInterface)
rcalls <- liftIO $ resolveCall sym bindings fn args loc callStack
ReaderT $ return . TailCallState vfv (error "rcalls")


-- | Immediately transition to the 'BranchMergeState'.
Expand Down Expand Up @@ -811,6 +865,11 @@ handleSimReturn fnName vfv return_value =


-- | Resolve the return value, and begin executing in the caller's context again.
--
-- FIXME: Modify this to handle the CPS-ed call (issuing another call here)
--
-- Challenge: We need the simulator state from *before* we issue the call (so
-- that all possible callees start with the same state)
performReturn ::
IsSymInterface sym =>
FunctionName {- ^ Name of the function we are returning from -} ->
Expand Down Expand Up @@ -838,6 +897,49 @@ performReturn fnName ctx0 v = do
(stateTree .~ ActiveTree ctx (pres & partialValue . gpValue .~ OF f))
(ReaderT (k v))

VFVCall ctx frm (CallNext callLoc origHandler callCtx suspendedCallees returnPred) -> do
case suspendedCallees of
InitialSuspendedCall ((nextCallee, calleeValidWhen) DLN.:| otherCallees) -> do
nextCtx <- ask
let sym = nextCtx ^. stateSymInterface
let intrinsics = nextCtx ^. stateIntrinsicTypes
-- Start the CPS multi-call sequence
--
-- This is a separate case (from the 'SuspendedCallees' case) because
-- we don't have any initial values we need to mux here.
let firstResult = nextCtx ^. stateTree . actResult
frame <- liftIO $ resultDefinedWhen sym callLoc returnPred firstResult
let nextSuspendedCallees = SuspendedCallees otherCallees firstResult v
let call = CallNext callLoc origHandler callCtx nextSuspendedCallees calleeValidWhen
return (CallState call nextCallee callCtx)
SuspendedCallees [] prevPartialRes prevRetVal -> do
nextCtx <- ask
let sym = nextCtx ^. stateSymInterface
let intrinsics = nextCtx ^. stateIntrinsicTypes
-- We have called all of the potential callees, so now we just issue a
-- normal return (that returns the collection of muxed return values)
let ctx1 = VFVCall ctx frm origHandler
let thisRes = nextCtx ^. stateTree . actResult
thisFrame <- liftIO $ mergePartialResult nextCtx ReturnTarget returnPred prevPartialRes thisRes
mergedRetVal <- liftIO $ muxRegEntry sym intrinsics returnPred v prevRetVal
withReaderT (& stateTree . actResult .~ thisFrame) $ performReturn fnName ctx1 mergedRetVal
SuspendedCallees ((nextCallee, calleeValidWhen) : remainingTargets) prevPartialRes prevRetVal -> do
nextCtx <- ask
let sym = nextCtx ^. stateSymInterface
let intrinsics = nextCtx ^. stateIntrinsicTypes
let thisRes = nextCtx ^. stateTree . actResult
-- This is morally appealing, but the 'CrucibleBranchTarget' seems to
-- be very difficult to create here. Do we need to create that when
-- we issue the call?
thisFrame <- liftIO $ mergePartialResult nextCtx ReturnTarget returnPred prevPartialRes thisRes
mergedVal <- liftIO $ muxRegEntry sym intrinsics returnPred v prevRetVal
let nextSuspendedCallees = SuspendedCallees remainingTargets thisFrame mergedVal
let call = CallNext callLoc origHandler callCtx nextSuspendedCallees calleeValidWhen
-- NOTE: We have saved the initial context of the call so that we can
-- issue all of the calls with the same initial state (and so that
-- they cannot observe each other's effects)
return (CallState call nextCallee callCtx)

VFVPartial ctx loc pred r ->
do sym <- view stateSymInterface
ActiveTree oldctx pres <- view stateTree
Expand All @@ -852,6 +954,28 @@ performReturn fnName ctx0 v = do
ActiveTree _oldctx pres <- view stateTree
return $! ResultState $ FinishedResult simctx (pres & partialValue . gpValue .~ v)

-- | Mark a value as defined when the given 'Pred' holds
--
-- If the value was originally a 'TotalRes', it becomes a 'PartialRes'
--
-- Otherwise, the condition is conjoined with the existing 'Pred'
resultDefinedWhen ::
IsSymInterface sym =>
sym ->
ProgramLoc ->
Pred sym ->
PartialResult sym ext (SimFrame sym ext l args) ->
IO (PartialResult sym ext (SimFrame sym ext l args))
resultDefinedWhen sym loc p v =
case v of
TotalRes gp -> do
let aborted = AbortedExec InfeasibleBranch gp
return (PartialRes loc p gp aborted)
PartialRes _loc0 p0 gp aborted -> do
combinedPred <- orPred sym p0 p
return (PartialRes loc combinedPred gp aborted)


cruciblePausedFrame ::
ResolvedJump sym b ->
GlobalPair sym (SimFrame sym ext (CrucibleLang b r) ('Just a)) ->
Expand Down
3 changes: 2 additions & 1 deletion crucible/src/Lang/Crucible/Simulator/RegMap.hs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ import Lang.Crucible.Simulator.RegValue
import Lang.Crucible.Simulator.SimError
import Lang.Crucible.Types
import Lang.Crucible.Utils.MuxTree
import qualified Lang.Crucible.Utils.PartitioningMuxTree as PMT
import Lang.Crucible.Backend
import Lang.Crucible.Panic

Expand Down Expand Up @@ -229,7 +230,7 @@ muxRegForType s itefns p =
case isPosNat w of
Nothing -> \_ x _ -> return x
Just LeqProof -> bvIte s
FunctionHandleRepr _ _ -> muxReg s p
FunctionHandleRepr _ _ -> PMT.mergePartitioningMuxTree s muxFnValSymbolicPart

MaybeRepr r -> mergePartExpr s (muxRegForType s itefns r)
VectorRepr r -> muxVector s (muxRegForType s itefns r)
Expand Down
Loading