package amyc
package codegen

import amyc.analyzer._
import amyc.ast.Identifier
import amyc.ast.SymbolicTreeModule.{And => AmyAnd, Call => AmyCall, Div => AmyDiv, Or => AmyOr, _}
import amyc.codegen.Utils._
import amyc.utils.{Context, Pipeline}
import amyc.wasm.Instructions._
import amyc.wasm._

/** Pipeline stage that turns an Amy program and its symbol table into a WebAssembly module. */
object CodeGen extends Pipeline[(Program, SymbolTable), Module] {

  def run(ctx: Context)(v: (Program, SymbolTable)): Module = {
    val program = v._1
    val table = v._2

    /** Compiles one Amy module: every non-built-in function first, then (if the module
      * has a top-level expression) a synthetic "main" function wrapping that expression.
      */
    def cgModule(moduleDef: ModuleDef): List[Function] = {
      val ModuleDef(name, defs, optExpr) = moduleDef

      // Functions listed in builtInFunctions are skipped here.
      val userFunctions = defs.collect {
        case fd: FunDef if !builtInFunctions(fullName(name, fd.name)) =>
          cgFunction(fd, name, isMain = false)
      }

      // The module expression, when present, becomes the body of a fresh parameterless function.
      val mainFunction = optExpr.toList.map { expr =>
        val mainFd = FunDef(Identifier.fresh("main"), Nil, TypeTree(IntType), expr)
        cgFunction(mainFd, name, isMain = true)
      }

      userFunctions ++ mainFunction
    }

    /** Compiles function `fd` belonging to module `owner`.
      *
      * The wasm function is named after both the module and the function, because all
      * Amy modules end up in one wasm module. Parameters occupy the first local slots,
      * in declaration order.
      */
    def cgFunction(fd: FunDef, owner: Identifier, isMain: Boolean): Function =
      Function(fullName(owner, fd.name), fd.params.size, isMain) { lh =>
        val paramSlots = fd.paramNames.zipWithIndex.toMap
        val bodyCode = cgExpr(fd.body)(paramSlots, lh)
        // A main function returns nothing, so the value left by its body is discarded.
        if (!isMain) bodyCode else bodyCode <:> Drop
      }

    /** Compiles expression `expr` to code that leaves the expression's value on the stack.
      *
      * @param locals maps each identifier in scope (parameter or variable) to its wasm local index
      * @param lh     hands out fresh local slots on demand
      */
    def cgExpr(expr: Expr)(implicit locals: Map[Identifier, Int], lh: LocalsHandler): Code = {

      // Left operand, right operand, then the combining code. `op` is by-name so that it is
      // built only after both operands have been compiled.
      def cgBinary(lhs: Expr, rhs: Expr, op: => Code): Code =
        cgExpr(lhs) <:> cgExpr(rhs) <:> op

      expr match {
        // --- Variables and sequencing ---
        case Let(paramDef, bound, rest) =>
          val slot = lh.getFreshLocal()
          val boundCode = cgExpr(bound)
          val restCode = cgExpr(rest)(locals + (paramDef.name -> slot), lh)
          boundCode <:> SetLocal(slot) <:> restCode

        case Variable(name) =>
          GetLocal(locals(name))

        case Sequence(first, second) =>
          // The value of the first expression is thrown away.
          cgExpr(first) <:> Drop <:> cgExpr(second)

        case Ite(cond, thenBranch, elseBranch) =>
          val condCode = cgExpr(cond)
          val thenCode = cgExpr(thenBranch)
          val elseCode = cgExpr(elseBranch)
          condCode <:> If_i32 <:> thenCode <:> Else <:> elseCode <:> End

        // --- Binary operators ---
        case Concat(lhs, rhs)     => cgBinary(lhs, rhs, Call(concatImpl.name))
        case Plus(lhs, rhs)       => cgBinary(lhs, rhs, Add)
        case Minus(lhs, rhs)      => cgBinary(lhs, rhs, Sub)
        case Times(lhs, rhs)      => cgBinary(lhs, rhs, Mul)
        case AmyDiv(lhs, rhs)     => cgBinary(lhs, rhs, Div)
        case Mod(lhs, rhs)        => cgBinary(lhs, rhs, Rem)
        case Equals(lhs, rhs)     => cgBinary(lhs, rhs, Eq)
        case LessEquals(lhs, rhs) => cgBinary(lhs, rhs, Le_s)
        case LessThan(lhs, rhs)   => cgBinary(lhs, rhs, Lt_s)

        // --- Short-circuiting boolean operators ---
        case AmyAnd(lhs, rhs) =>
          val first = cgExpr(lhs)
          val second = cgExpr(rhs)
          // Right operand only runs when the left one is true; otherwise the result is 0.
          first <:> If_i32 <:> second <:> Else <:> Const(0) <:> End

        case AmyOr(lhs, rhs) =>
          val first = cgExpr(lhs)
          val second = cgExpr(rhs)
          // Right operand only runs when the left one is false; otherwise the result is 1.
          first <:> If_i32 <:> Const(1) <:> Else <:> second <:> End

        // --- Unary operators ---
        case Neg(operand) =>
          // Computed as 0 - operand.
          Const(0) <:> cgExpr(operand) <:> Sub

        case Not(operand) =>
          cgExpr(operand) <:> Eqz

        // --- Literals ---
        case IntLiteral(value)     => Const(value)
        case StringLiteral(value)  => mkString(value)
        case BooleanLiteral(value) => Const(if (value) 1 else 0)
        case UnitLiteral()         => Const(0)

        // --- Function calls and constructor applications ---
        case AmyCall(qname, args) =>
          table.getConstructor(qname) match {
            case Some(sig) =>
              // Constructor: reserve space at the current memory boundary, write the
              // constructor index at the base address, then each argument at the address
              // given by adtField(i). The base pointer is the resulting value.
              val objPtr = lh.getFreshLocal()
              val allocate: Code =
                GetGlobal(memoryBoundary) <:> SetLocal(objPtr) <:>
                  GetGlobal(memoryBoundary) <:> adtField(args.size) <:> SetGlobal(memoryBoundary)
              val storeTag: Code =
                GetLocal(objPtr) <:> Const(sig.index) <:> Store
              val storeFields: List[Code] = args.zipWithIndex.map { case (arg, fieldIdx) =>
                GetLocal(objPtr) <:> adtField(fieldIdx) <:> cgExpr(arg) <:> Store
              }
              allocate <:> storeTag <:> storeFields <:> GetLocal(objPtr)

            case None =>
              // Plain function: push the arguments in order, then call by full name.
              val argCodes = args.map(arg => cgExpr(arg))
              val callee = table.getFunction(qname).get
              argCodes <:> Call(fullName(callee.owner, qname))
          }

        // --- Pattern matching ---
        case Match(scrut, cases) =>
          // For a pattern, produces (a) code that consumes the value on top of the stack and
          // pushes 1 if it matches the pattern, 0 otherwise, and (b) the local slots assigned
          // to the identifiers the pattern binds.
          def matchAndBind(pat: Pattern): (Code, Map[Identifier, Int]) = pat match {
            case WildcardPattern() =>
              (Drop <:> Const(1), Map.empty)

            case IdPattern(name) =>
              val boundSlot = lh.getFreshLocal()
              (SetLocal(boundSlot) <:> Const(1), Map(name -> boundSlot))

            case LiteralPattern(lit) =>
              (cgExpr(lit) <:> Eq, Map.empty)

            case CaseClassPattern(constr, args) =>
              val valueSlot = lh.getFreshLocal()

              // One (test, bindings) pair per sub-pattern; each test first loads its field.
              val subResults = args.zipWithIndex.map { case (subPat, fieldIdx) =>
                val (subTest, subBinds) = matchAndBind(subPat)
                (GetLocal(valueSlot) <:> adtField(fieldIdx) <:> Load <:> subTest, subBinds)
              }
              val subTests = subResults.map(_._1)

              // All sub-tests are run and their results and-ed together: n tests need n - 1 Ands.
              // With no sub-patterns the fields trivially match.
              val fieldsMatch: Code = args match {
                case Nil       => Const(1)
                case _ :: rest => subTests <:> rest.map(_ => And)
              }

              val expectedIndex = table.getConstructor(constr).get.index
              val test: Code =
                SetLocal(valueSlot) <:> GetLocal(valueSlot) <:> Load <:> Const(expectedIndex) <:> Eq <:>
                  If_i32 <:> fieldsMatch <:> Else <:> Const(0) <:> End
              val bindings = subResults.foldLeft(Map.empty[Identifier, Int]) {
                case (acc, (_, subBinds)) => acc ++ subBinds
              }
              (test, bindings)

            case _ =>
              (Unreachable, Map.empty)
          }

          val scrutSlot = lh.getFreshLocal()

          // Each case opens an if/else whose else-branch holds the remaining cases.
          val caseCodes = cases.map { cse =>
            val (test, bindings) = matchAndBind(cse.pat)
            val bodyCode = cgExpr(cse.expr)(locals ++ bindings, lh)
            GetLocal(scrutSlot) <:> test <:> If_i32 <:> bodyCode <:> Else
          }

          val scrutCode = cgExpr(scrut)
          // Reached only when no case matched: report and trap.
          val matchFailure: Code =
            mkString("Match error!") <:> Call("Std_printString") <:> Unreachable
          // One End per opened If.
          val closers = cases.map(_ => End)

          scrutCode <:> SetLocal(scrutSlot) <:> caseCodes <:> matchFailure <:> closers

        // --- Runtime errors ---
        case Error(msg) =>
          // Prints "Error: " followed by the message, then traps.
          cgExpr(Concat(StringLiteral("Error: "), msg)) <:> Call("Std_printString") <:> Unreachable
      }
    }

    // The wasm module is named after the last Amy module of the program.
    Module(
      program.modules.last.name.name,
      defaultImports,
      globalsNo,
      wasmFunctions ++ program.modules.flatMap(cgModule)
    )
  }
}