Skip to content

Commit

Permalink
LLVM Strings (UTF-8-interpreted byte buffers) (#142)
Browse files Browse the repository at this point in the history
* LLVM: utf8 strings

Co-authored-by: Jonathan <jonathan@b-studios.de>
Co-authored-by: Philipp Schuster <philipp.schuster@uni-tuebingen.de>
  • Loading branch information
3 people authored Oct 4, 2022
1 parent 46952f1 commit 964d5b3
Show file tree
Hide file tree
Showing 23 changed files with 421 additions and 169 deletions.
3 changes: 0 additions & 3 deletions effekt/jvm/src/test/scala/effekt/LLVMTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ class LLVMTests extends EffektTests {

// mutable state not support, yet
examplesDir / "llvm" / "gids.effekt",

// strings not supported, yet
examplesDir / "llvm" / "hello-world.effekt"
)

def runTestFor(f: File, expected: String) =
Expand Down
10 changes: 7 additions & 3 deletions effekt/shared/src/main/scala/effekt/generator/llvm/LLVM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@ import effekt.context.assertions.*
import effekt.util.paths.*

object LLVM extends Backend {
def compileWhole(main: CoreTransformed, dependencies: List[CoreTransformed])(using C: Context): Option[Compiled] = {
def compileWhole(main: CoreTransformed, dependencies: List[CoreTransformed])(using Context): Option[Compiled] = {
val mainFile = path(main.mod)
val machineMod = machine.Transformer.transform(main, dependencies)
val llvmDefinitions = Transformer.transform(machineMod)

val llvmString = effekt.llvm.PrettyPrinter.show(llvmDefinitions)

val result = Document(llvmString, emptyLinks)

Some(Compiled(mainFile, Map(mainFile -> result)))
}

def path(m: Module)(implicit C: Context): String =
def path(m: Module)(using C: Context): String =
(C.config.outputPath() / m.path.replace('/', '_')).unixPath + ".ll"

/**
Expand All @@ -34,12 +36,14 @@ object LLVM extends Backend {
* This could be optimized in the future to (for instance) not show the standard library
* and the prelude.
*/
def compileSeparate(main: CoreTransformed)(implicit C: Context): Option[Document] = {
def compileSeparate(main: CoreTransformed)(using Context): Option[Document] = {
val machine.Program(decls, prog) = machine.Transformer.transform(main, Nil)

// we don't print declarations here.
val llvmDefinitions = Transformer.transform(machine.Program(Nil, prog))

val llvmString = effekt.llvm.PrettyPrinter.show(llvmDefinitions)

Some(Document(llvmString, emptyLinks))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ object PrettyPrinter {

type LLVMString = String

def show(definitions: List[Definition])(using C: Context): LLVMString =
def show(definitions: List[Definition])(using Context): LLVMString =
definitions.map(show).mkString("\n\n")

def show(definition: Definition)(using C: Context): LLVMString = definition match {
Expand All @@ -24,9 +24,20 @@ define ${show(returnType)} ${globalName(name)}(${commaSeparated(parameters.map(s
}
"""
case Verbatim(content) => content

case GlobalVariableArray(name, IntegerType8(), ConstantArray(IntegerType8(), members)) =>
val bytes = members.map { ini => ini match {
case ConstantInteger8(b) => b
case _ => ???
}}
val escaped = bytes.map(b => "\\" + f"$b%02x").mkString;
s"@$name = private constant [${bytes.length} x i8] c\"$escaped\""

case GlobalVariableArray(name, typ, initializer) =>
C.abort(s"cannot compile non-i8 constant array: $name = [ x ${typ}] ${initializer}")
}

def show(basicBlock: BasicBlock)(using C: Context): LLVMString = basicBlock match {
def show(basicBlock: BasicBlock)(using Context): LLVMString = basicBlock match {
case BasicBlock(name, instructions, terminator) =>
s"""
${name}:
Expand Down Expand Up @@ -83,9 +94,9 @@ ${indentedLines(instructions.map(show).mkString("\n"))}
case ExtractValue(result, aggregate, index) =>
s"${localName(result)} = extractvalue ${show(aggregate)}, $index"

// let us hope that `msg` does not contain e.g. a newline
case Comment(msg) =>
s"; $msg"
val sanitized = msg.map((c: Char) => if (' ' <= c && c != '\\' && c <= '~') c else '?').mkString
s"; $sanitized"
}

def show(terminator: Terminator): LLVMString = terminator match {
Expand All @@ -99,23 +110,26 @@ ${indentedLines(instructions.map(show).mkString("\n"))}
}

def show(operand: Operand): LLVMString = operand match {
case LocalReference(tpe, name) => s"${show(tpe)} ${localName(name)}"
case ConstantGlobal(tpe, name) => s"${show(tpe)} ${globalName(name)}"
case ConstantInt(n) => s"i64 $n"
case ConstantDouble(n) => s"double $n"
case ConstantAggregateZero(tpe) => s"${show(tpe)} zeroinitializer"
case ConstantNull(tpe) => s"${show(tpe)} null"
case LocalReference(tpe, name) => s"${show(tpe)} ${localName(name)}"
case ConstantGlobal(tpe, name) => s"${show(tpe)} ${globalName(name)}"
case ConstantInt(n) => s"i64 $n"
case ConstantDouble(n) => s"double $n"
case ConstantAggregateZero(tpe) => s"${show(tpe)} zeroinitializer"
case ConstantNull(tpe) => s"${show(tpe)} null"
case ConstantArray(memberType, members) => s"[${members.length} x ${show(memberType)}]"
case ConstantInteger8(b) => s"i8 $b"
}

def show(tpe: Type): LLVMString = tpe match {
case VoidType() => "void"
case IntegerType64() => "i64"
case IntegerType8() => "i8" // required for `void*` (which only exists as `i8*` in LLVM)
case IntegerType1() => "i1"
case NamedType(name) => localName(name)
case IntegerType8() => "i8"
case IntegerType64() => "i64"
case PointerType(referentType) => s"${show(referentType)}*"
case ArrayType(size, of) => s"[$size x ${show(of)}]"
case StructureType(elementTypes) => s"{${commaSeparated(elementTypes.map(show))}}"
case FunctionType(returnType, argumentTypes) => s"${show(returnType)} (${commaSeparated(argumentTypes.map(show))})"
case NamedType(name) => localName(name)
}

def show(parameter: Parameter): LLVMString = parameter match {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
package effekt
package llvm
package effekt.llvm

import effekt.machine
import effekt.machine.analysis.*
Expand Down Expand Up @@ -29,6 +28,11 @@ object Transformer {
declarations.map(transform) ++ definitions :+ entryFunction
}

// context getters
private def MC(using MC: ModuleContext): ModuleContext = MC
private def FC(using FC: FunctionContext): FunctionContext = FC
private def BC(using BC: BlockContext): BlockContext = BC

def transform(declaration: machine.Declaration): Definition =
declaration match {
case machine.Extern(functionName, parameters, returnType, body) =>
Expand Down Expand Up @@ -273,9 +277,23 @@ object Transformer {
emit(FAdd(name, ConstantDouble(x), ConstantDouble(0)));
transform(rest)

case machine.LiteralUTF8String(v@machine.Variable(bind, _), utf8, rest) =>
emit(GlobalVariableArray(s"$bind.lit", IntegerType8(), ConstantArray(IntegerType8(), utf8.map { b => ConstantInteger8(b) }.toList)))

emit(BitCast(s"$bind.lit.decayed", ConstantGlobal(PointerType(ArrayType(utf8.length, IntegerType8())), s"$bind.lit"), PointerType(IntegerType8())))

val res = positiveType
val args = List(ConstantInt(utf8.size), LocalReference(PointerType(IntegerType8()), s"$bind.lit.decayed"))
val argsT = List(IntegerType64(), PointerType(IntegerType8()))
emit(Call(bind, res, ConstantGlobal(FunctionType(res, argsT), "c_buffer_construct"), args))

eraseValues(List(v), freeVariables(rest));
transform(rest)

case machine.ForeignCall(machine.Variable(resultName, resultType), foreign, values, rest) =>
// TODO careful with calling convention?!?
val functionType = PointerType(FunctionType(transform(resultType), values.map { case machine.Variable(_, tpe) => transform(tpe) }));
shareValues(values, freeVariables(rest));
emit(Call(resultName, transform(resultType), ConstantGlobal(functionType, foreign), values.map(transform)));
transform(rest)
}
Expand Down Expand Up @@ -308,6 +326,7 @@ object Transformer {
case machine.Negative(_) => negativeType
case machine.Type.Int() => NamedType("Int")
case machine.Type.Double() => NamedType("Double")
case machine.Type.String() => positiveType
case machine.Type.Stack() => stkType
}

Expand All @@ -316,11 +335,12 @@ object Transformer {

def typeSize(tpe: machine.Type): Int =
tpe match {
case machine.Positive(_) => 16
case machine.Negative(_) => 16
case machine.Type.Int() => 8 // TODO Make fat?
case machine.Type.Double() => 8 // TODO Make fat?
case machine.Type.Stack() => 8 // TODO Make fat?
case machine.Positive(_) => 16
case machine.Negative(_) => 16
case machine.Type.Int() => 8 // TODO Make fat?
case machine.Type.Double() => 8 // TODO Make fat?
case machine.Type.String() => 16
case machine.Type.Stack() => 8 // TODO Make fat?
}

def defineFunction(name: String, parameters: List[Parameter])(prog: (FunctionContext, BlockContext) ?=> Terminator): ModuleContext ?=> Unit = {
Expand Down Expand Up @@ -503,6 +523,7 @@ object Transformer {
case machine.Type.Stack() => emit(Call("_", VoidType(), shareStack, List(transform(value))))
case machine.Type.Int() => ()
case machine.Type.Double() => ()
case machine.Type.String() => emit(Call("_", VoidType(), shareString, List(transform(value))))
}
}

Expand All @@ -513,6 +534,7 @@ object Transformer {
case machine.Type.Stack() => emit(Call("_", VoidType(), eraseStack, List(transform(value))))
case machine.Type.Int() => ()
case machine.Type.Double() => ()
case machine.Type.String() => emit(Call("_", VoidType(), eraseString, List(transform(value))))
}
}

Expand Down Expand Up @@ -585,12 +607,14 @@ object Transformer {
def shareNegative = ConstantGlobal(PointerType(FunctionType(VoidType(),List(negativeType))), "shareNegative");
def shareStack = ConstantGlobal(PointerType(FunctionType(VoidType(),List(stkType))), "shareStack");
def shareFrames = ConstantGlobal(PointerType(FunctionType(VoidType(),List(spType))), "shareFrames");
def shareString = ConstantGlobal(PointerType(FunctionType(VoidType(),List(positiveType))), "c_buffer_refcount_increment");

def eraseObject = ConstantGlobal(PointerType(FunctionType(VoidType(),List(objType))), "eraseObject");
def erasePositive = ConstantGlobal(PointerType(FunctionType(VoidType(),List(positiveType))), "erasePositive");
def eraseNegative = ConstantGlobal(PointerType(FunctionType(VoidType(),List(negativeType))), "eraseNegative");
def eraseStack = ConstantGlobal(PointerType(FunctionType(VoidType(),List(stkType))), "eraseStack");
def eraseFrames = ConstantGlobal(PointerType(FunctionType(VoidType(),List(spType))), "eraseFrames");
def eraseString = ConstantGlobal(PointerType(FunctionType(VoidType(),List(positiveType))), "c_buffer_refcount_decrement");

def newStack = ConstantGlobal(PointerType(FunctionType(stkType,List())), "newStack");
def pushStack = ConstantGlobal(PointerType(FunctionType(spType,List(stkType, spType))), "pushStack");
Expand Down Expand Up @@ -647,5 +671,4 @@ object Transformer {

def setStackPointer(stackPointer: Operand)(using C: BlockContext) =
C.stackPointer = stackPointer;

}
16 changes: 10 additions & 6 deletions effekt/shared/src/main/scala/effekt/generator/llvm/Tree.scala
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package effekt
package llvm


/**
* see: https://hackage.haskell.org/package/llvm-hs-pure-9.0.0/docs/LLVM-AST.html#t:Definition
*/
enum Definition {
case Function(returnType: Type, name: String, parameters: List[Parameter], basicBlocks: List[BasicBlock])
case VerbatimFunction(returnType: Type, name: String, parameters: List[Parameter], body: String)
case Verbatim(content: String)
case GlobalVariableArray(name: String, typ: Type, initializer: Operand) // initializer should be constant
}
export Definition.*

Expand All @@ -27,7 +29,6 @@ enum Instruction {
case FAdd(result: String, operand0: Operand, operand1: Operand)
case InsertValue(result: String, aggregate: Operand, element: Operand, index: Int)
case ExtractValue(result: String, aggregate: Operand, index: Int)

case Comment(msg: String)
}
export Instruction.*
Expand All @@ -41,8 +42,8 @@ export Terminator.*

case class Parameter(typ: Type, name: String)

// Operands cannot be an enum since we use the more specific types massively. Scala 3 will perform widening way
// too often.
// Operands cannot be an enum since we use the more specific types massively.
// Scala 3 will perform widening way too often.
sealed trait Operand
object Operand {
case class LocalReference(tpe: Type, name: String) extends Operand
Expand All @@ -51,6 +52,8 @@ object Operand {
case class ConstantAggregateZero(typ: Type) extends Operand
case class ConstantGlobal(tpe: Type, name: String) extends Operand
case class ConstantNull(tpe: Type) extends Operand
case class ConstantArray(memberType: Type, members: List[Operand]) extends Operand // members should be homogeneous
case class ConstantInteger8(b: Byte) extends Operand
}
export Operand.*

Expand All @@ -59,12 +62,13 @@ export Operand.*
*/
enum Type {
case VoidType()
case IntegerType64()
case IntegerType8()
case IntegerType1()
case IntegerType8() // required for `void*` (which only exists as `i8*` in LLVM) and `char*`
case IntegerType64()
case PointerType(pointerReferent: Type)
case ArrayType(size: Int, of: Type)
case StructureType(elementTypes: List[Type])
case FunctionType(resultType: Type, argumentTypes: List[Type])
case PointerType(pointerReferent: Type)
case NamedType(name: String)
}
export Type.*
28 changes: 2 additions & 26 deletions effekt/shared/src/main/scala/effekt/machine/Analysis.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,32 +41,8 @@ def freeVariables(statement: Statement): Set[Variable] =
freeVariables(rest) - name
case LiteralDouble(name, value, rest) =>
freeVariables(rest) - name
case LiteralUTF8String(name, utf8, rest) =>
freeVariables(rest) - name
case ForeignCall(name, builtin, arguments, rest) =>
arguments.toSet ++ freeVariables(rest) - name
}

def substitute(subst: Substitution, v: Variable): Variable =
subst.findLast { case (k, _) => k == v }.map { _._2 }.getOrElse(v)

def substitute(subst: Substitution, into: Substitution): Substitution = into map {
case (k, v) => (k, substitute(subst, v))
}

def substitute(subst: Substitution, stmt: Statement): Statement = stmt match {
case Def(label, body, rest) => substitute(subst, stmt)
case Jump(label) => Substitute(subst, Jump(label))
// we float substitutions downwards...
case Substitute(bindings, rest) => substitute(substitute(subst, bindings) ++ subst, rest)
case Construct(name, tag, values, rest) => ???
case Switch(value, clauses) => ???
case New(name, clauses, rest) => ???
case Invoke(value, tag, values) => ???
case PushFrame(frame, rest) => ???
case Return(environment) => ???
case NewStack(name, frame, rest) => ???
case PushStack(value, rest) => ???
case PopStack(name, rest) => ???
case LiteralInt(name, value, rest) => ???
case LiteralDouble(name, value, rest) => ???
case ForeignCall(name, builtin, arguments, rest) => ???
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ object PrettyPrinter extends ParenPrettyPrinter {
case Negative(name) => name
case Type.Int() => "Int"
case Type.Double() => "Double"
case Type.String() => "String"
case Type.Stack() => "Stack"
}

Expand All @@ -40,7 +41,7 @@ object PrettyPrinter extends ParenPrettyPrinter {
"jump" <+> label

case Substitute(bindings, rest) =>
"subst" <+> brackets(bindings map { case (to, from) => from <+> "!->" <+> to }) <> ";" <> line <> toDoc(rest)
"subst" <+> brackets(bindings map { case (left, right) => left <+> "!->" <+> right }) <> ";" <> line <> toDoc(rest)

case Construct(name, tag, arguments, rest) =>
"let" <+> name <+> "=" <+> tag.toString <> parens(arguments map toDoc) <> ";" <> line <> toDoc(rest)
Expand Down Expand Up @@ -79,6 +80,9 @@ object PrettyPrinter extends ParenPrettyPrinter {

case LiteralDouble(name, value, rest) =>
"let" <+> name <+> "=" <+> value.toString <> ";" <> line <> toDoc(rest)

case LiteralUTF8String(name, utf8, rest) =>
"let" <+> name <+> "=" <+> ("\"" + (utf8.map { b => "\\" + f"$b%02x" }).mkString + "\"") <> ";" <> line <> toDoc(rest)
}

def nested(content: Doc): Doc = group(nest(line <> content))
Expand Down
11 changes: 9 additions & 2 deletions effekt/shared/src/main/scala/effekt/machine/Transformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,12 @@ object Transformer {
LiteralDouble(literal_binding, v, k(literal_binding))
}

case lifted.StringLit(javastring) =>
val literal_binding = Variable(freshName("utf8_string_literal"), Type.String());
Binding { k =>
LiteralUTF8String(literal_binding, javastring.getBytes("utf-8"), k(literal_binding))
}

case lifted.PureApp(lifted.BlockVar(blockName: symbols.BuiltinFunction), List(), args) =>
val variable = Variable(freshName("x"), transform(blockName.result))
transform(args).flatMap { values =>
Expand Down Expand Up @@ -342,6 +348,8 @@ object Transformer {

case symbols.builtins.TDouble => Type.Double()

case symbols.builtins.TString => Type.String()

case symbols.FunctionType(Nil, Nil, vparams, Nil, _, _) => Negative("<function>")

case symbols.Interface(name, List(), _) => Negative(name.name)
Expand All @@ -351,8 +359,7 @@ object Transformer {
case symbols.Record(name, List(), _, _) => Positive(name.name)

case _ =>
System.err.println(s"UNSUPPORTED TYPE: getClass($tpe) = ${tpe.getClass}")
Context.abort(s"unsupported type $tpe")
Context.abort(s"unsupported type: $tpe (class = ${tpe.getClass})")
}

def transform(id: Symbol): String =
Expand Down
Loading

0 comments on commit 964d5b3

Please sign in to comment.