Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cairo: Support complex data types in program input #2822

Merged
merged 9 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ env:
SKIP: ormolu,format-juvix-files,typecheck-juvix-examples
VAMPIRREPO: anoma/vamp-ir
VAMPIRVERSION: v0.1.3
CAIRO_VM_VERSION: 6bb5330aede3fc8049b498012a6efbf12bc9432a
CAIRO_VM_VERSION: ec4e2547c201983595254efbe72d9b4cfa450ad9
RISC0_VM_VERSION: v1.0.1
JUST_ARGS: runtimeCcArg=$CC runtimeLibtoolArg=$LIBTOOL
STACK_BUILD_ARGS: --pedantic -j4 --ghc-options=-j
Expand Down
7 changes: 6 additions & 1 deletion app/Commands/Dev/Casm/Compile.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ runCommand opts = do
runReader entryPoint
. runError @JuvixError
. casmToCairo
$ Casm.Result labi code []
$ Casm.Result
{ _resultLabelInfo = labi,
_resultCode = code,
_resultBuiltins = [],
_resultOutputSize = 1
}
res <- getRight r
liftIO $ JSON.encodeFile (toFilePath cairoFile) res
where
Expand Down
2 changes: 1 addition & 1 deletion cntlines.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function count_ext () {
}

RUNTIME_C=$(count runtime/c/src/juvix)
RUNTIME_RUST=$(count runtime/rust/src)
RUNTIME_RUST=$(count runtime/rust/juvix/src)
RUNTIME_VAMPIR=$(count_ext '*.pir' runtime/vampir)
RUNTIME_JVT=$(count_ext '*.jvt' runtime/tree)
RUNTIME_CASM=$(count_ext '*.casm' runtime/casm)
Expand Down
7 changes: 7 additions & 0 deletions runtime/casm/stdlib.casm
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,13 @@
-- sargs + argsnum = total numer of arguments for the function
-- Maximum number of function arguments: 8

-- Constructor layout: [ cid | arguments... ]
-- cid -- constructor id: 2 * tag + 1, where tag is the 0-based index
-- of the constructor within its inductive type
-- Make sure this spec is followed by:
-- * Juvix.Compiler.Casm.Translation.FromReg
-- * get_cid() in juvix_hint_processor/hint_processor.rs in juvix-cairo-vm

-- after calling juvix_get_regs:
-- [ap - 4] = fp
-- [ap - 3] = pc
Expand Down
29 changes: 19 additions & 10 deletions src/Juvix/Compiler/Backend/Cairo/Extra/Serialization.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import Data.Bits
import Juvix.Compiler.Backend.Cairo.Data.Result
import Juvix.Compiler.Backend.Cairo.Language

serialize :: [Text] -> [Element] -> Result
serialize builtins elems =
serialize :: Int -> [Text] -> [Element] -> Result
serialize outputSize builtins elems =
Result
{ _resultData =
initializeBuiltins
Expand Down Expand Up @@ -48,17 +48,26 @@ serialize builtins elems =

finalizeBuiltins :: [Text]
finalizeBuiltins =
-- [[fp]] = [ap - 1] -- [output_ptr] = [ap - 1]
-- [ap] = [fp] + 1; ap++ -- output_ptr
[ "0x4002800080007fff",
"0x4826800180008000",
"0x1"
]
-- [[fp] + i] = [ap - outputSize + i]
-- [output_ptr + i] = [ap - outputSize + i]
map
( \i ->
toHexText (0x4002800080008000 - outputSize' + i + shift i 32)
)
[0 .. outputSize' - 1]
++
-- [ap] = [fp] + outputSize; ap++
-- output_ptr = output_ptr + outputSize
[ "0x4826800180008000",
toHexText outputSize'
]
++
-- [ap] = [ap - builtinsNum - 2]; ap++
-- [ap] = [ap - 1 - builtinsNum - outputSize]; ap++
replicate
builtinsNum
(toHexText (0x48107ffe7fff8000 - shift builtinsNum 32))
(toHexText (0x48107fff7fff8000 - shift (builtinsNum + outputSize') 32))
where
outputSize' = fromIntegral outputSize

finalizeJump :: [Text]
finalizeJump =
Expand Down
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Casm/Data/Result.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import Juvix.Compiler.Casm.Language
data Result = Result
{ _resultLabelInfo :: LabelInfo,
_resultCode :: [Instruction],
_resultBuiltins :: [Builtin]
_resultBuiltins :: [Builtin],
_resultOutputSize :: Int
}

makeLenses ''Result
3 changes: 2 additions & 1 deletion src/Juvix/Compiler/Casm/Translation/FromCairo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ fromCairo elems0 =
Result
{ _resultLabelInfo = mempty,
_resultCode = go 0 [] elems0,
_resultBuiltins = mempty
_resultBuiltins = mempty,
_resultOutputSize = 0
}
where
errorMsg :: Address -> Text -> a
Expand Down
53 changes: 42 additions & 11 deletions src/Juvix/Compiler/Casm/Translation/FromReg.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,44 +29,71 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
registerLabelName startSym startName
registerLabelAddress startSym startAddr
let mainSym = fromJust $ tab ^. Reg.infoMainFunction
mainInfo = fromJust (HashMap.lookup mainSym (tab ^. Reg.infoFunctions))
mainInfo = Reg.lookupFunInfo tab mainSym
mainName = mainInfo ^. Reg.functionName
mainResultType = Reg.typeTarget (mainInfo ^. Reg.functionType)
mainArgs = getInputArgs (mainInfo ^. Reg.functionArgsNum) (mainInfo ^. Reg.functionArgNames)
bnum = toOffset builtinsNum
callStartInstr = mkCallRel (Lab startLab)
initBuiltinsInstr = mkAssignAp (Binop $ BinopValue FieldAdd (MemRef Fp (-2)) (Imm 1))
callMainInstr = mkCallRel (Lab $ LabelRef mainSym (Just mainName))
jmpEndInstr = mkJumpRel (Val $ Lab endLab)
margs = concat $ reverse $ map mkLoadInputArg mainArgs
loadInputArgsInstrs = concat $ reverse $ map mkLoadInputArg mainArgs
-- [ap] = [[ap - 2 - k] + k]; ap++
bltsRet = map (\k -> mkAssignAp (Load $ LoadValue (MemRef Ap (-2 - k)) k)) [0 .. bnum - 1]
resRetInstr = mkAssignAp (Val $ Ref $ MemRef Ap (-bnum - 1))
resRetInstrs = mkResultInstrs bnum mainResultType
pinstrs =
callStartInstr
: jmpEndInstr
: Label startLab
: initBuiltinsInstr
: margs
: loadInputArgsInstrs
++ callMainInstr
: bltsRet
++ [resRetInstr, Return]
++ resRetInstrs
++ [Return]
(blts, binstrs) <- addStdlibBuiltins (length pinstrs)
let cinstrs = concatMap (mkFunCall . fst) $ sortOn snd $ HashMap.toList (info ^. Reg.extraInfoFUIDs)
(addr, instrs) <- second (concat . reverse) <$> foldM (goFun blts endLab) (length pinstrs + length binstrs + length cinstrs, []) (tab ^. Reg.infoFunctions)
eassert (addr == length instrs + length cinstrs + length binstrs + length pinstrs)
registerLabelName endSym endName
registerLabelAddress endSym addr
return $
( allElements,
return
( length resRetInstrs,
allElements,
pinstrs
++ binstrs
++ cinstrs
++ instrs
++ [Label endLab]
)
where
mkResult :: (LabelInfo, ([Builtin], Code)) -> Result
mkResult (labi, (blts, code)) = Result labi code blts
mkResult :: (LabelInfo, (Int, [Builtin], Code)) -> Result
mkResult (labi, (outSize, blts, code)) =
Result
{ _resultLabelInfo = labi,
_resultCode = code,
_resultBuiltins = blts,
_resultOutputSize = outSize
}

mkResultInstrs :: Offset -> Reg.Type -> [Instruction]
mkResultInstrs off = \case
Reg.TyInductive Reg.TypeInductive {..} -> goRecord _typeInductiveSymbol
Reg.TyConstr Reg.TypeConstr {..} -> goRecord _typeConstrInductive
_ -> [mkAssignAp (Val $ Ref $ MemRef Ap (-off - 1))]
where
goRecord :: Symbol -> [Instruction]
goRecord sym = case indInfo ^. Reg.inductiveConstructors of
[tag] -> case Reg.lookupConstrInfo tab tag of
Reg.ConstructorInfo {..} ->
map mkOutInstr [1 .. toOffset _constructorArgsNum]
where
mkOutInstr :: Offset -> Instruction
mkOutInstr i = mkAssignAp (Load $ LoadValue (MemRef Ap (-off - i)) i)
_ -> impossible
where
indInfo = Reg.lookupInductiveInfo tab sym

mkLoadInputArg :: Text -> [Instruction]
mkLoadInputArg arg = [Hint (HintInput arg), mkAssignAp (Val $ Ref $ MemRef Ap 0)]
Expand All @@ -87,6 +114,10 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
Nop
]

-- To make it convenient with relative jumps, Cairo constructor tag is `2 *
-- tag + 1` where `tag` is the 0-based constructor number within the
-- inductive type. Make sure this corresponds with the relative jump code in
-- `goCase`.
getTagId :: Tag -> Int
getTagId tag =
1 + 2 * fromJust (HashMap.lookup tag (info ^. Reg.extraInfoCIDs))
Expand Down Expand Up @@ -170,7 +201,7 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI

goCallBlock :: Bool -> Maybe Reg.VarRef -> HashSet Reg.VarRef -> Sem r ()
goCallBlock updatedBuiltins outVar liveVars = do
let liveVars' = toList (maybe liveVars (flip HashSet.delete liveVars) outVar)
let liveVars' = toList (maybe liveVars (`HashSet.delete` liveVars) outVar)
n = length liveVars'
bltOff =
if
Expand Down Expand Up @@ -578,7 +609,7 @@ fromReg tab = mkResult $ run $ runLabelInfoBuilderWithNextId (Reg.getNextSymbolI
syms <- replicateM (length tags) freshSymbol
symEnd <- freshSymbol
let symMap = HashMap.fromList $ zip tags syms
labs = map (flip LabelRef Nothing) syms
labs = map (`LabelRef` Nothing) syms
labEnd = LabelRef symEnd Nothing
jmps = map (mkJumpRel . Val . Lab) labs
-- we need the Nop instructions to ensure that the relative jump
Expand Down
72 changes: 67 additions & 5 deletions src/Juvix/Compiler/Core/Transformation/Check/Cairo.hs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ module Juvix.Compiler.Core.Transformation.Check.Cairo where

import Juvix.Compiler.Core.Error
import Juvix.Compiler.Core.Extra
import Juvix.Compiler.Core.Pretty
import Juvix.Compiler.Core.Transformation.Base
import Juvix.Compiler.Core.Transformation.Check.Base
import Juvix.Data.PPOutput

checkCairo :: forall r. (Member (Error CoreError) r) => Module -> Sem r Module
checkCairo md = do
Expand All @@ -19,7 +19,7 @@ checkCairo md = do
unless (checkType (ii ^. identifierType)) $
throw
CoreError
{ _coreErrorMsg = ppOutput "for this target the arguments and the result of the `main` function must be numbers or field elements",
{ _coreErrorMsg = ppOutput "for this target the arguments the `main` function need to be field elements, numbers, booleans, records or lists, and the result needs to be a field element, number, boolean or a record of field elements, numbers and booleans",
_coreErrorLoc = fromMaybe defaultLoc (ii ^. identifierLocation),
_coreErrorNode = Nothing
}
Expand All @@ -29,7 +29,69 @@ checkCairo md = do
checkType :: Node -> Bool
checkType ty =
let (tyargs, tgt) = unfoldPi' ty
in all isPrimIntegerOrField (tgt : tyargs)
in all isArgType tyargs && isTargetType tgt
where
isPrimIntegerOrField ty' =
isTypeInteger ty' || isTypeField ty' || isDynamic ty'
isArgType :: Node -> Bool
isArgType = \case
NPi {} -> False
NUniv {} -> False
NTyp x -> isRecordOrList x
NPrim x -> isAllowedPrim x
NDyn {} -> True
_ -> False

isTargetType :: Node -> Bool
isTargetType = \case
NPi {} -> False
NUniv {} -> False
NTyp x -> isFlatRecord x
NPrim x -> isAllowedPrim x
NDyn {} -> True
_ -> False

isPrimType :: Node -> Bool
isPrimType = \case
NPrim x -> isAllowedPrim x
_ -> False

isAllowedPrim :: TypePrim -> Bool
isAllowedPrim TypePrim {..} = case _typePrimPrimitive of
PrimInteger {} -> True
PrimBool {} -> True
PrimField {} -> True
PrimString {} -> False

isRecordOrList :: TypeConstr -> Bool
isRecordOrList TypeConstr {..} = case ii ^. inductiveBuiltin of
Just (BuiltinTypeInductive BuiltinList) ->
all isArgType _typeConstrArgs
Just {} ->
False
Nothing ->
case ii ^. inductiveConstructors of
[tag] ->
all isArgType tyargs
where
ci = lookupConstructorInfo md tag
cty = ci ^. constructorType
nParams = length _typeConstrArgs
tyargs =
map (substs _typeConstrArgs)
. drop nParams
. typeArgs
$ cty
_ -> False
where
ii = lookupInductiveInfo md _typeConstrSymbol

isFlatRecord :: TypeConstr -> Bool
isFlatRecord TypeConstr {..} =
case ii ^. inductiveConstructors of
[tag]
| null _typeConstrArgs ->
all isPrimType (typeArgs (ci ^. constructorType))
where
ci = lookupConstructorInfo md tag
_ -> False
where
ii = lookupInductiveInfo md _typeConstrSymbol
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ regToCasm = Reg.toCasm >=> return . Casm.fromReg
casmToCairo :: Casm.Result -> Sem r Cairo.Result
casmToCairo Casm.Result {..} =
return
. Cairo.serialize (map Casm.builtinName _resultBuiltins)
. Cairo.serialize _resultOutputSize (map Casm.builtinName _resultBuiltins)
$ Cairo.fromCasm _resultCode

regToCairo :: Reg.InfoTable -> Sem r Cairo.Result
Expand Down
3 changes: 2 additions & 1 deletion test/Casm/Compilation.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module Casm.Compilation where

import Base
import Casm.Compilation.Negative qualified as N
import Casm.Compilation.Positive qualified as P

allTests :: TestTree
allTests = testGroup "Juvix to CASM compilation" [P.allTests, P.allTestsNoOptimize]
allTests = testGroup "Juvix to CASM compilation" [P.allTests, P.allTestsNoOptimize, N.allTests]
17 changes: 15 additions & 2 deletions test/Casm/Compilation/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ compileAssertion ::
Bool ->
Int ->
Path Abs File ->
Maybe (Path Abs File) ->
Path Abs File ->
(String -> IO ()) ->
Assertion
Expand All @@ -26,10 +27,11 @@ compileAssertionEntry ::
Bool ->
Int ->
Path Abs File ->
Maybe (Path Abs File) ->
Path Abs File ->
(String -> IO ()) ->
Assertion
compileAssertionEntry adjustEntry root' bInterp bRunVM optLevel mainFile expectedFile step = do
compileAssertionEntry adjustEntry root' bInterp bRunVM optLevel mainFile inputFile expectedFile step = do
step "Translate to JuvixCore"
entryPoint <- adjustEntry <$> testDefaultEntryPointIO root' mainFile
PipelineResult {..} <- snd <$> testRunIO entryPoint upToStoredCore
Expand All @@ -44,4 +46,15 @@ compileAssertionEntry adjustEntry root' bInterp bRunVM optLevel mainFile expecte
step "Pretty print"
writeFileEnsureLn tmpFile (toPlainText $ ppProgram _resultCode)
)
casmRunAssertion' bInterp bRunVM _resultLabelInfo _resultCode _resultBuiltins Nothing expectedFile step
casmRunAssertion' bInterp bRunVM _resultLabelInfo _resultCode _resultBuiltins _resultOutputSize inputFile expectedFile step

compileErrorAssertion :: Path Abs Dir -> Path Abs File -> (String -> IO ()) -> Assertion
compileErrorAssertion root' mainFile step = do
step "Translate to JuvixCore"
entryPoint <- testDefaultEntryPointIO root' mainFile
let entryPoint' = entryPoint {_entryPointFieldSize = cairoFieldSize}
PipelineResult {..} <- snd <$> testRunIO entryPoint' upToStoredCore
step "Translate to CASM"
case run $ runError @JuvixError $ runReader entryPoint $ storedCoreToCasm (_pipelineResult ^. Core.coreResultModule) of
Left {} -> assertBool "" True
Right {} -> assertFailure "no error"
Loading
Loading