Skip to content

Commit

Permalink
cleanup: implement ToBytes/FromBytes for integer and float types (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
zshipko authored Sep 26, 2023
1 parent df330e7 commit 831a4c0
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Example.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ main = do
setLogFile "stdout" LogError
let m = manifest [wasmFile "wasm/code-functions.wasm"]
f <- hostFunction "hello_world" [I64] [I64] hello "Hello, again"
plugin <- unwrap <$> pluginFromManifest m [f] True
plugin <- unwrap <$> newPlugin m [f] True
id <- pluginID plugin
print id
res <- unwrap <$> call plugin "count_vowels" "this is a test"
Expand Down
3 changes: 2 additions & 1 deletion extism.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ library
bytestring >= 0.11.3 && <= 0.12,
json >= 0.10 && <= 0.11,
extism-manifest >= 0.0.0 && <= 1.0.0,
uuid >= 1.3 && < 2
uuid >= 1.3 && < 2,
binary >= 0.8.9 && < 0.9.0

test-suite extism-example
type: exitcode-stdio-1.0
Expand Down
5 changes: 5 additions & 0 deletions manifest/Extism/Manifest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ withTimeout :: Manifest -> Int -> Manifest
withTimeout m t =
m { timeout = nonNull t }

-- | Set memory.max_pages
withMaxPages :: Manifest -> Int -> Manifest
withMaxPages m pages =
m { memory = NotNull $ Memory (NotNull pages) }

toString :: (JSON a) => a -> String
toString v =
encode (showJSON v)
97 changes: 83 additions & 14 deletions src/Extism.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ module Extism (
fromByteString,
extismVersion,
newPlugin,
pluginFromManifest,
isValid,
setConfig,
setLogFile,
Expand Down Expand Up @@ -46,6 +45,10 @@ import qualified Extism.JSON (JSON(..))
import Extism.Manifest (Manifest, toString)
import Extism.Bindings
import qualified Data.UUID (UUID, fromByteString)
import Data.Binary.Get (runGetOrFail, getWord32le, getInt32le, getWord64le, getInt64le, getFloatle, getDoublele)
import Data.Binary.Put (runPut, putWord32le, putInt32le, putWord64le, putInt64le, putFloatle, putDoublele)
import Data.Int
import Data.Word

-- | Host function, see 'Extism.HostFunction.hostFunction'
data Function = Function (ForeignPtr ExtismFunction) (StablePtr ()) deriving Eq
Expand Down Expand Up @@ -79,10 +82,20 @@ extismVersion () = do
v <- extism_version
peekCString v

class PluginInput a where
pluginInput :: a -> B.ByteString

instance PluginInput B.ByteString where
pluginInput = id

instance PluginInput Manifest where
pluginInput m = toByteString $ toString m

-- | Create a 'Plugin' from a WASM module, `useWasi` determines if WASI should
-- | be linked
newPlugin :: B.ByteString -> [Function] -> Bool -> IO (Result Plugin)
newPlugin wasm functions useWasi =
newPlugin :: PluginInput a => a -> [Function] -> Bool -> IO (Result Plugin)
newPlugin input functions useWasi =
let wasm = pluginInput input in
let nfunctions = fromIntegral (length functions) in
let length' = fromIntegral (B.length wasm) in
let wasi = fromInteger (if useWasi then 1 else 0) in
Expand All @@ -95,19 +108,13 @@ newPlugin wasm functions useWasi =
extism_plugin_new (castPtr s) length' funcs nfunctions wasi errmsg ))
if p == nullPtr then do
err <- peek errmsg
e <- peekCString err
e <- peekCString err
extism_plugin_new_error_free err
return $ Left (ExtismError e)
else do
ptr <- Foreign.Concurrent.newForeignPtr p (extism_plugin_free p)
return $ Right (Plugin ptr))

-- | Create a 'Plugin' from a 'Manifest'
pluginFromManifest :: Manifest -> [Function] -> Bool -> IO (Result Plugin)
pluginFromManifest manifest functions useWasi =
let wasm = toByteString $ toString manifest in
newPlugin wasm functions useWasi

-- | Check if a 'Plugin' is valid
isValid :: Plugin -> IO Bool
isValid (Plugin p) = withForeignPtr p (\x -> return (x /= nullPtr))
Expand Down Expand Up @@ -156,6 +163,12 @@ class FromBytes a where
-- Encoding is used to indicate a type implements both `ToBytes` and `FromBytes`
class (ToBytes a, FromBytes a) => Encoding a where

instance ToBytes () where
toBytes () = toByteString ""

instance FromBytes () where
fromBytes _ = Right ()

instance ToBytes B.ByteString where
toBytes x = x

Expand All @@ -166,17 +179,73 @@ instance ToBytes [Char] where
toBytes = toByteString

instance FromBytes [Char] where
fromBytes bs =
fromBytes bs =
Right $ fromByteString bs

instance ToBytes Int32 where
toBytes i = B.toStrict (runPut (putInt32le i))

instance FromBytes Int32 where
fromBytes bs =
case runGetOrFail getInt32le (B.fromStrict bs) of
Left (_, _, e) -> Left (ExtismError e)
Right (_, _, x) -> Right x

instance ToBytes Int64 where
toBytes i = B.toStrict (runPut (putInt64le i))

instance FromBytes Int64 where
fromBytes bs =
case runGetOrFail getInt64le (B.fromStrict bs) of
Left (_, _, e) -> Left (ExtismError e)
Right (_, _, x) -> Right x


instance ToBytes Word32 where
toBytes i = B.toStrict (runPut (putWord32le i))

instance FromBytes Word32 where
fromBytes bs =
case runGetOrFail getWord32le (B.fromStrict bs) of
Left (_, _, e) -> Left (ExtismError e)
Right (_, _, x) -> Right x

instance ToBytes Word64 where
toBytes i = B.toStrict (runPut (putWord64le i))

instance FromBytes Word64 where
fromBytes bs =
case runGetOrFail getWord64le (B.fromStrict bs) of
Left (_, _, e) -> Left (ExtismError e)
Right (_, _, x) -> Right x

instance ToBytes Float where
toBytes i = B.toStrict (runPut (putFloatle i))

instance FromBytes Float where
fromBytes bs =
case runGetOrFail getFloatle (B.fromStrict bs) of
Left (_, _, e) -> Left (ExtismError e)
Right (_, _, x) -> Right x

instance ToBytes Double where
toBytes i = B.toStrict (runPut (putDoublele i))

instance FromBytes Double where
fromBytes bs =
case runGetOrFail getDoublele (B.fromStrict bs) of
Left (_, _, e) -> Left (ExtismError e)
Right (_, _, x) -> Right x


-- Wraps a `JSON` value for input/output
newtype JSONValue x = JSONValue x

instance Extism.JSON.JSON a => ToBytes (JSONValue a) where
toBytes (JSONValue x) =
toByteString $ Text.JSON.encode x


instance Extism.JSON.JSON a => FromBytes (JSONValue a) where
fromBytes bs = do
case Text.JSON.decode (fromByteString bs) of
Expand Down Expand Up @@ -223,10 +292,10 @@ pluginID (Plugin plugin) =
ptr <- extism_plugin_id plugin'
buf <- B.packCStringLen (castPtr ptr, 16)
case Data.UUID.fromByteString (BL.fromStrict buf) of
Nothing -> error "Invalid Plugin ID"
Nothing -> error "Invalid Plugin ID"
Just x -> return x)


unwrap (Right x) = x
unwrap (Left (ExtismError msg)) =
error msg
4 changes: 2 additions & 2 deletions test/Test.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ hostFunctionManifest = manifest [wasmFile "../wasm/code-functions.wasm"]

initPlugin :: IO Plugin
initPlugin =
Extism.pluginFromManifest defaultManifest [] False >>= assertUnwrap
Extism.newPlugin defaultManifest [] False >>= assertUnwrap

pluginFunctionExists = do
p <- initPlugin
Expand All @@ -38,7 +38,7 @@ hello plugin () = do
output plugin 0 "{\"count\": 999}"

pluginCallHostFunction = do
p <- Extism.pluginFromManifest hostFunctionManifest [] False >>= assertUnwrap
p <- Extism.newPlugin hostFunctionManifest [] False >>= assertUnwrap
res <- call p "count_vowels" (toByteString "this is a test") >>= assertUnwrap
assertEqual "count vowels output" "{\"count\": 999}" (fromByteString res)

Expand Down

0 comments on commit 831a4c0

Please sign in to comment.