Files
epfl-archive/cs320-clp/src/amyc/codegen/CodeGen.scala
2022-04-07 18:43:21 +02:00

153 lines
6.5 KiB
Scala

package amyc
package codegen
import amyc.analyzer._
import amyc.ast.Identifier
import amyc.ast.SymbolicTreeModule.{And => AmyAnd, Call => AmyCall, Div => AmyDiv, Or => AmyOr, _}
import amyc.codegen.Utils._
import amyc.utils.{Context, Pipeline}
import amyc.wasm.Instructions._
import amyc.wasm._
/** Generates WebAssembly code for an Amy program.
  *
  * The pipeline input is the program together with its symbol table; the output is a single
  * wasm [[Module]] containing the functions of every Amy module of the program.
  */
object CodeGen extends Pipeline[(Program, SymbolTable), Module] {

  /** Translates the whole program into one wasm module.
    *
    * @param ctx compilation context (not consulted by this phase)
    * @param v   the program and the symbol table used to resolve constructors and functions
    * @return a wasm module named after the last Amy module of the program
    * @throws NoSuchElementException if the program has no modules, or if a called name or a
    *                                constructor used in a pattern is missing from the symbol table
    */
  def run(ctx: Context)(v: (Program, SymbolTable)): Module = {
    val (program, table) = v

    // Generate code for an Amy module: one wasm function per non-built-in function definition,
    // followed by a "main" function wrapping the module expression when there is one.
    def cgModule(moduleDef: ModuleDef): List[Function] = {
      val ModuleDef(name, defs, optExpr) = moduleDef

      // Generate code for all functions that are not provided as built-ins
      val functions = defs.collect {
        case fd: FunDef if !builtInFunctions(fullName(name, fd.name)) =>
          cgFunction(fd, name, isMain = false)
      }

      // Generate code for the "main" function, which contains the module expression
      val mainFunction = optExpr.toList.map { expr =>
        val mainFd = FunDef(Identifier.fresh("main"), Nil, TypeTree(IntType), expr)
        cgFunction(mainFd, name, isMain = true)
      }

      functions ++ mainFunction
    }

    // Generate code for a function in module 'owner'
    def cgFunction(fd: FunDef, owner: Identifier, isMain: Boolean): Function = {
      // Note: We create the wasm function name from a combination of
      // module and function name, since we put everything in the same wasm module.
      val name = fullName(owner, fd.name)
      Function(name, fd.params.size, isMain) { lh =>
        // Parameters occupy the first local slots, in declaration order
        val locals = fd.paramNames.zipWithIndex.toMap
        val body = cgExpr(fd.body)(locals, lh)
        if (isMain) {
          // Main functions do not return a value,
          // so we need to drop the value generated by their body
          body <:> Drop
        } else {
          body
        }
      }
    }

    // Generate code for an expression expr.
    // Additional arguments are a mapping from identifiers (parameters and variables) to
    // their index in the wasm local variables, and a LocalsHandler which will generate
    // fresh local slots as required.
    // The generated code leaves exactly one value on the stack (or traps).
    def cgExpr(expr: Expr)(implicit locals: Map[Identifier, Int], lh: LocalsHandler): Code = expr match {
      // Store the bound value in a fresh local, then evaluate the body with that binding visible
      case Let(df, value, body) =>
        val address = lh.getFreshLocal()
        cgExpr(value) <:> SetLocal(address) <:> cgExpr(body)(locals + (df.name -> address), lh)

      case Variable(name) => GetLocal(locals(name))

      case Concat(lhs, rhs) => cgExpr(lhs) <:> cgExpr(rhs) <:> Call(concatImpl.name)

      // The value of the first expression is discarded
      case Sequence(expr1, expr2) => cgExpr(expr1) <:> Drop <:> cgExpr(expr2)

      case Ite(condition, thenBlock, elseBlock) =>
        cgExpr(condition) <:> If_i32 <:> cgExpr(thenBlock) <:> Else <:> cgExpr(elseBlock) <:> End

      // Arithmetic and comparison operators: push both operands, then apply the instruction
      case Plus(lhs, rhs) => cgExpr(lhs) <:> cgExpr(rhs) <:> Add
      case Minus(lhs, rhs) => cgExpr(lhs) <:> cgExpr(rhs) <:> Sub
      case Times(lhs, rhs) => cgExpr(lhs) <:> cgExpr(rhs) <:> Mul
      case AmyDiv(lhs, rhs) => cgExpr(lhs) <:> cgExpr(rhs) <:> Div
      case Mod(lhs, rhs) => cgExpr(lhs) <:> cgExpr(rhs) <:> Rem
      case Equals(lhs, rhs) => cgExpr(lhs) <:> cgExpr(rhs) <:> Eq
      case LessEquals(lhs, rhs) => cgExpr(lhs) <:> cgExpr(rhs) <:> Le_s
      case LessThan(lhs, rhs) => cgExpr(lhs) <:> cgExpr(rhs) <:> Lt_s

      // Short-circuiting: the right operand is only evaluated when the left one does not decide
      case AmyAnd(lhs, rhs) => cgExpr(lhs) <:> If_i32 <:> cgExpr(rhs) <:> Else <:> Const(0) <:> End
      case AmyOr(lhs, rhs) => cgExpr(lhs) <:> If_i32 <:> Const(1) <:> Else <:> cgExpr(rhs) <:> End

      // Negation is computed as 0 - e
      case Neg(expr1) => Const(0) <:> cgExpr(expr1) <:> Sub
      case Not(expr1) => cgExpr(expr1) <:> Eqz

      case IntLiteral(lit) => Const(lit)
      case StringLiteral(lit) => mkString(lit)
      // Booleans are encoded as 1 (true) and 0 (false); unit is encoded as 0
      case BooleanLiteral(lit) => if (lit) Const(1) else Const(0)
      case UnitLiteral() => Const(0)

      case AmyCall(qname, args) => table.getConstructor(qname) match {
        // Constructor application: allocate a block at the current memory boundary, store the
        // constructor index at its base and each argument in the following fields, and leave
        // the base address on the stack.
        case Some(sig) =>
          val newLocal = lh.getFreshLocal()
          // zipWithIndex keeps this linear; indexing a List by position is linear per access
          val fieldStores = args.zipWithIndex.map { case (arg, i) =>
            GetLocal(newLocal) <:> adtField(i) <:> cgExpr(arg) <:> Store
          }
          GetGlobal(memoryBoundary) <:> SetLocal(newLocal) <:>
            GetGlobal(memoryBoundary) <:> adtField(args.size) <:> SetGlobal(memoryBoundary) <:>
            GetLocal(newLocal) <:> Const(sig.index) <:> Store <:>
            fieldStores <:>
            GetLocal(newLocal)

        // Plain function call: push the arguments in order, then call the fully qualified name
        case None =>
          val argsCode = args.map(cgExpr)
          val fun = table.getFunction(qname).getOrElse(
            throw new NoSuchElementException(
              s"CodeGen: '${qname.name}' is neither a constructor nor a function in the symbol table"))
          argsCode <:> Call(fullName(fun.owner, qname))
      }

      case Match(scrut, cases) =>
        // Returns the code that tests a value against 'pat' together with the locals bound by
        // the pattern. The code expects the value on top of the stack, consumes it, and
        // leaves 1 (match) or 0 (no match) in its place.
        def matchAndBind(pat: Pattern): (Code, Map[Identifier, Int]) = pat match {
          case WildcardPattern() => (Drop <:> Const(1), Map.empty)

          case IdPattern(name) =>
            val idLocal = lh.getFreshLocal()
            (SetLocal(idLocal) <:> Const(1), Map(name -> idLocal))

          case LiteralPattern(lit) => (cgExpr(lit) <:> Eq, Map.empty)

          case CaseClassPattern(constr, args) =>
            val caseLocal = lh.getFreshLocal()
            // For every sub-pattern: load the corresponding field and test it recursively
            val subPatterns = args.zipWithIndex.map { case (argPat, i) =>
              val (check, bound) = matchAndBind(argPat)
              (GetLocal(caseLocal) <:> adtField(i) <:> Load <:> check, bound)
            }
            val fieldChecks: List[Code] = subPatterns.map(_._1)
            // Combine the per-field results: n checks need n - 1 And instructions
            val argsCode: Code = {
              if (args.isEmpty) Const(1)
              else if (args.lengthCompare(1) == 0) fieldChecks
              else fieldChecks <:> args.tail.map(_ => And)
            }
            val idx = table.getConstructor(constr).getOrElse(
              throw new NoSuchElementException(
                s"CodeGen: constructor '${constr.name}' used in a pattern is not in the symbol table")
            ).index
            val bindings = subPatterns.map(_._2).foldLeft(Map.empty[Identifier, Int])(_ ++ _)
            // Fields are only inspected when the stored constructor index is the expected one
            (SetLocal(caseLocal) <:> GetLocal(caseLocal) <:> Load <:> Const(idx) <:> Eq <:>
              If_i32 <:> argsCode <:> Else <:> Const(0) <:> End,
              bindings)

          // Any other kind of pattern traps at runtime
          case _ => (Unreachable, Map.empty)
        }

        val newLocal = lh.getFreshLocal()
        // Each case becomes "if (pattern matches) body else <next case>"; the chain of open
        // Else branches is closed by one End per case at the very end.
        val caseCodes = cases.map { cse =>
          val (check, bindings) = matchAndBind(cse.pat)
          GetLocal(newLocal) <:> check <:> If_i32 <:> cgExpr(cse.expr)(locals ++ bindings, lh) <:> Else
        }
        // When no case matches: report a match error and trap
        cgExpr(scrut) <:> SetLocal(newLocal) <:> caseCodes <:>
          mkString("Match error!") <:> Call("Std_printString") <:> Unreachable <:>
          cases.map(_ => End)

      // Print "Error: " followed by the message, then trap
      case Error(msg) =>
        cgExpr(StringLiteral("Error: ")) <:> cgExpr(msg) <:> Call(concatImpl.name) <:>
          Call("Std_printString") <:> Unreachable
    }

    // The wasm module takes the name of the last Amy module of the program
    val lastModule = program.modules.lastOption.getOrElse(
      throw new NoSuchElementException("CodeGen: cannot generate code for a program without modules"))

    Module(
      lastModule.name.name,
      defaultImports,
      globalsNo,
      wasmFunctions ++ program.modules.flatMap(cgModule)
    )
  }
}