Skip to content

Commit

Permalink
feat!: add ToBytes and FromPointer class for converting to and from e…
Browse files Browse the repository at this point in the history
…xtism memory
  • Loading branch information
zshipko committed Sep 15, 2023
1 parent 4a379e3 commit e4d16de
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 41 deletions.
111 changes: 75 additions & 36 deletions src/Extism.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{-# LANGUAGE FlexibleInstances #-}

module Extism (
module Extism.Manifest,
Function(..),
Expand All @@ -9,7 +11,7 @@ module Extism (
toByteString,
fromByteString,
extismVersion,
plugin,
newPlugin,
pluginFromManifest,
isValid,
setConfig,
Expand All @@ -19,7 +21,10 @@ module Extism (
cancelHandle,
cancel,
pluginID,
unwrap
unwrap,
ToBytes(..),
FromPointer(..),
JSONValue(..)
) where

import Foreign.ForeignPtr
Expand All @@ -34,7 +39,8 @@ import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Data.ByteString.Internal (c2w, w2c)
import Data.ByteString.Unsafe (unsafeUseAsCString)
import qualified Text.JSON (encode, toJSObject, showJSON)
import qualified Text.JSON (encode, decode, toJSObject, showJSON, Result(..))
import qualified Extism.JSON (JSON(..))
import Extism.Manifest (Manifest, toString)
import Extism.Bindings
import qualified Data.UUID (UUID, fromByteString)
Expand Down Expand Up @@ -73,18 +79,18 @@ extismVersion () = do

-- | Create a 'Plugin' from a WASM module, `useWasi` determines if WASI should
-- | be linked
plugin :: B.ByteString -> [Function] -> Bool -> IO (Result Plugin)
plugin wasm functions useWasi =
newPlugin :: B.ByteString -> [Function] -> Bool -> IO (Result Plugin)
newPlugin wasm functions useWasi =
let nfunctions = fromIntegral (length functions) in
let length = fromIntegral (B.length wasm) in
let length' = fromIntegral (B.length wasm) in
let wasi = fromInteger (if useWasi then 1 else 0) in
do
funcs <- mapM (\(Function ptr _) -> withForeignPtr ptr (\x -> do return x)) functions
alloca (\e-> do
let errmsg = (e :: Ptr CString)
p <- unsafeUseAsCString wasm (\s ->
withArray funcs (\funcs ->
extism_plugin_new (castPtr s) length funcs nfunctions wasi errmsg ))
extism_plugin_new (castPtr s) length' funcs nfunctions wasi errmsg ))
if p == nullPtr then do
err <- peek errmsg
e <- peekCString err
Expand All @@ -98,7 +104,7 @@ plugin wasm functions useWasi =
pluginFromManifest :: Manifest -> [Function] -> Bool -> IO (Result Plugin)
pluginFromManifest manifest functions useWasi =
let wasm = toByteString $ toString manifest in
plugin wasm functions useWasi
newPlugin wasm functions useWasi

-- | Check if a 'Plugin' is valid
isValid :: Plugin -> IO Bool
Expand All @@ -109,10 +115,10 @@ setConfig :: Plugin -> [(String, Maybe String)] -> IO Bool
setConfig (Plugin plugin) x =
let obj = Text.JSON.toJSObject [(k, Text.JSON.showJSON v) | (k, v) <- x] in
let bs = toByteString (Text.JSON.encode obj) in
let length = fromIntegral (B.length bs) in
let length' = fromIntegral (B.length bs) in
unsafeUseAsCString bs (\s -> do
withForeignPtr plugin (\plugin-> do
b <- extism_plugin_config plugin (castPtr s) length
withForeignPtr plugin (\plugin'-> do
b <- extism_plugin_config plugin' (castPtr s) length'
return $ b /= 0))

levelStr LogError = "error"
Expand All @@ -133,40 +139,73 @@ setLogFile filename level =
-- | Check if a function exists in the given plugin
functionExists :: Plugin -> String -> IO Bool
functionExists (Plugin plugin) name = do
withForeignPtr plugin (\plugin -> do
b <- withCString name (extism_plugin_function_exists plugin)
withForeignPtr plugin (\plugin' -> do
b <- withCString name (extism_plugin_function_exists plugin')
if b == 1 then return True else return False)

class ToBytes a where
toBytes :: a -> B.ByteString

class FromPointer a where
fromPointer :: CString -> Int -> IO (Result a)

instance ToBytes B.ByteString where
toBytes x = x

instance FromPointer B.ByteString where
fromPointer ptr len = do
x <- B.packCStringLen (castPtr ptr, fromIntegral len)
return $ Right x

instance ToBytes [Char] where
toBytes x = toByteString x

Check warning on line 161 in src/Extism.hs

View workflow job for this annotation

GitHub Actions / Haskell (ubuntu-latest)

Warning in module Extism: Eta reduce ▫︎ Found: "toBytes x = toByteString x" ▫︎ Perhaps: "toBytes = toByteString"

instance FromPointer [Char] where
fromPointer ptr len = do
bs <- fromPointer ptr len
case bs of
Left e -> return $ Left e
Right bs -> return $ Right $ fromByteString bs

newtype JSONValue x = JSONValue x

instance Extism.JSON.JSON a => ToBytes (JSONValue a) where
toBytes (JSONValue x) =
toByteString $ Text.JSON.encode x


instance Extism.JSON.JSON a => FromPointer (JSONValue a) where
fromPointer ptr len = do
s <- fromPointer ptr len
case s of
Left e -> return $ Left e
Right s ->
case Text.JSON.decode s of
Text.JSON.Error x -> return $ Left (ExtismError x)
Text.JSON.Ok x -> return $ Right (JSONValue x)


--- | Call a function provided by the given plugin
call :: Plugin -> String -> B.ByteString -> IO (Result B.ByteString)
call (Plugin plugin) name input =
let length = fromIntegral (B.length input) in
call :: (ToBytes a, FromPointer b) => Plugin -> String -> a -> IO (Result b)
call (Plugin plugin) name inp =
let input = toBytes inp in
let length' = fromIntegral (B.length input) in
do
withForeignPtr plugin (\plugin -> do
rc <- withCString name (\name ->
unsafeUseAsCString input (\input ->
extism_plugin_call plugin name (castPtr input) length))
err <- extism_error plugin
withForeignPtr plugin (\plugin' -> do
rc <- withCString name (\name' ->
unsafeUseAsCString input (\input' ->
extism_plugin_call plugin' name' (castPtr input') length'))
err <- extism_error plugin'
if err /= nullPtr
then do e <- peekCString err
return $ Left (ExtismError e)
else if rc == 0
then do
length <- extism_plugin_output_length plugin
ptr <- extism_plugin_output_data plugin
buf <- B.packCStringLen (castPtr ptr, fromIntegral length)
return $ Right buf
len <- extism_plugin_output_length plugin'
ptr <- extism_plugin_output_data plugin'
fromPointer (castPtr ptr) (fromIntegral len)
else return $ Left (ExtismError "Call failed"))

-- | Call a function with a string argument and return a string
callString :: Plugin -> String -> String -> IO (Result String)
callString p name input = do
res <- call p name (toByteString input)
case res of
Left x -> return $ Left x
Right x -> return $ Right (fromByteString x)


-- | Create a new 'CancelHandle' that can be used to cancel a running plugin
-- | from another thread.
cancelHandle :: Plugin -> IO CancelHandle
Expand All @@ -181,8 +220,8 @@ cancel (CancelHandle handle) =

pluginID :: Plugin -> IO Data.UUID.UUID
pluginID (Plugin plugin) =
withForeignPtr plugin (\plugin -> do
ptr <- extism_plugin_id plugin
withForeignPtr plugin (\plugin' -> do
ptr <- extism_plugin_id plugin'
buf <- B.packCStringLen (castPtr ptr, 16)
case Data.UUID.fromByteString (BL.fromStrict buf) of
Nothing -> error "Invalid Plugin ID"
Expand Down
9 changes: 5 additions & 4 deletions src/Extism/Bindings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ instance Storable Val where
I64 -> ValI64 <$> peekByteOff ptr offs
F32 -> ValF32 <$> peekByteOff ptr offs
F64 -> ValF64 <$> peekByteOff ptr offs
poke ptr x = do
_ -> error "Unsupported val type"
poke ptr a = do
let offs = if _32Bit then 4 else 8
pokeByteOff ptr 0 (typeOfVal x)
case x of
pokeByteOff ptr 0 (typeOfVal a)
case a of
ValI32 x -> pokeByteOff ptr offs x
ValI64 x -> pokeByteOff ptr offs x
ValF32 x -> pokeByteOff ptr offs x
Expand Down Expand Up @@ -103,7 +104,7 @@ foreign import ccall safe "extism.h extism_current_plugin_memory_free" extism_cu

freePtr ptr = do
let s = castPtrToStablePtr ptr
(a, b, c) <- deRefStablePtr s
(_, b, c) <- deRefStablePtr s
freeHaskellFunPtr b
freeHaskellFunPtr c
freeStablePtr s
Expand Down
18 changes: 17 additions & 1 deletion src/Extism/HostFunction.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ module Extism.HostFunction(
memoryOffset,
memoryBytes,
memoryString,
memoryGet,
allocBytes,
allocString,
alloc,
toI32,
toI64,
toF32,
Expand Down Expand Up @@ -80,7 +82,14 @@ memoryString plugin offs = do
ptr <- memoryOffset plugin offs
len <- memoryLength plugin offs
arr <- peekArray (fromIntegral len) ptr
return $ fromByteString $ B.pack arr
return $ fromByteString $ B.pack arr

-- | Access the data associated with a handle and convert it into a Haskell type
memoryGet :: FromPointer a => CurrentPlugin -> MemoryHandle -> IO (Result a)
memoryGet plugin offs = do
ptr <- memoryOffset plugin offs
len <- memoryLength plugin offs
fromPointer (castPtr ptr) (fromIntegral len)

-- | Allocate memory and copy an existing 'ByteString' into it
allocBytes :: CurrentPlugin -> B.ByteString -> IO MemoryHandle
Expand All @@ -101,6 +110,13 @@ allocString plugin s = do
pokeArray ptr (Prelude.map BS.c2w s)
return offs


alloc :: ToBytes a => CurrentPlugin -> a -> IO MemoryHandle
alloc plugin x =
let a = toBytes x in
allocBytes plugin a


-- | Create a new I32 'Val'
toI32 :: Integral a => a -> Val
toI32 x = ValI32 (fromIntegral x)
Expand Down

0 comments on commit e4d16de

Please sign in to comment.