Disabled external gits
This commit is contained in:
283
cs320-clp/src/amyc/analyzer/NameAnalyzer.scala
Normal file
283
cs320-clp/src/amyc/analyzer/NameAnalyzer.scala
Normal file
@@ -0,0 +1,283 @@
|
||||
package amyc
|
||||
package analyzer
|
||||
|
||||
import amyc.ast.{Identifier, NominalTreeModule => N, SymbolicTreeModule => S}
|
||||
import amyc.utils._
|
||||
|
||||
// Name analyzer for Amy
|
||||
// Takes a nominal program (names are plain strings, qualified names are string pairs)
|
||||
// and returns a symbolic program, where all names have been resolved to unique Identifiers.
|
||||
// Rejects programs that violate the Amy naming rules.
|
||||
// Also populates and returns the symbol table.
|
||||
object NameAnalyzer extends Pipeline[N.Program, (S.Program, SymbolTable)] {
|
||||
def run(ctx: Context)(p: N.Program): (S.Program, SymbolTable) = {
|
||||
import ctx.reporter._
|
||||
|
||||
// Step 0: Initialize symbol table
|
||||
val table = new SymbolTable
|
||||
|
||||
// Step 1: Add modules to table
|
||||
val modNames = p.modules.groupBy(_.name)
|
||||
modNames.foreach { case (name, modules) =>
|
||||
if (modules.size > 1) {
|
||||
fatal(s"Two modules named $name in program", modules.head.position)
|
||||
}
|
||||
}
|
||||
|
||||
modNames.keys.toList foreach table.addModule
|
||||
|
||||
|
||||
// Helper method: will transform a nominal type 'tt' to a symbolic type,
|
||||
// given that we are within module 'inModule'.
|
||||
def transformType(tt: N.TypeTree, inModule: String): S.Type = {
|
||||
tt.tpe match {
|
||||
case N.IntType => S.IntType
|
||||
case N.BooleanType => S.BooleanType
|
||||
case N.StringType => S.StringType
|
||||
case N.UnitType => S.UnitType
|
||||
case N.ClassType(qn@N.QualifiedName(module, name)) =>
|
||||
table.getType(module getOrElse inModule, name) match {
|
||||
case Some(symbol) =>
|
||||
S.ClassType(symbol)
|
||||
case None =>
|
||||
fatal(s"Could not find type $qn", tt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Check name uniqueness of definitions in each module
|
||||
p.modules.foreach{
|
||||
mod => mod.defs.groupBy(_.name).foreach {
|
||||
case (name, ld) =>
|
||||
if (ld.size > 1) fatal(s"Multiple definitions of module: $name in ${mod.name}")
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Discover types and add them to symbol table
|
||||
p.modules.foreach{
|
||||
mod => mod.defs.foreach {
|
||||
case N.AbstractClassDef(name) =>
|
||||
table.addType(
|
||||
mod.name,
|
||||
name
|
||||
)
|
||||
case _ => //Ignore if not abstract-class
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Discover type constructors, add them to table
|
||||
p.modules.foreach{
|
||||
mod => mod.defs.foreach {
|
||||
case N.CaseClassDef(name, fields, parent) =>
|
||||
table.addConstructor(
|
||||
mod.name,
|
||||
name,
|
||||
fields.map(tt => transformType(tt, mod.name)),
|
||||
table.getType(mod.name, parent).getOrElse(fatal(s"Non-existant: ${mod.name}"))
|
||||
)
|
||||
case _ => //Ignore if not case-class
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Discover functions signatures, add them to table
|
||||
p.modules.foreach {
|
||||
mod => mod.defs.foreach{
|
||||
case N.FunDef(name, params, rt, _) =>
|
||||
table.addFunction(
|
||||
mod.name,
|
||||
name,
|
||||
params.map(p => p.tt).map(tt => transformType(tt, mod.name)),
|
||||
transformType(rt, mod.name)
|
||||
)
|
||||
case _ => //Ignore if not fun (note: last assignment was fun)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Step 6: We now know all definitions in the program.
|
||||
// Reconstruct modules and analyse function bodies/ expressions
|
||||
|
||||
// This part is split into three transfrom functions,
|
||||
// for definitions, FunDefs, and expressions.
|
||||
// Keep in mind that we transform constructs of the NominalTreeModule 'N' to respective constructs of the SymbolicTreeModule 'S'.
|
||||
// transformFunDef is given as an example, as well as some code for the other ones
|
||||
|
||||
def transformDef(df: N.ClassOrFunDef, module: String): S.ClassOrFunDef = { df match {
|
||||
case N.AbstractClassDef(name) =>
|
||||
table.getType(module, name) match {
|
||||
case Some(n)=> S.AbstractClassDef(n)
|
||||
case None => fatal(s"Could not find class name")
|
||||
}
|
||||
case N.CaseClassDef(name, _, _) =>
|
||||
table.getConstructor(module, name) match {
|
||||
case Some(c)=>
|
||||
S.CaseClassDef(c._1,c._2.argTypes.map(arg => S.TypeTree(arg)),c._2.parent)
|
||||
case None => fatal(s"Could not find class name")
|
||||
}
|
||||
case fd: N.FunDef =>
|
||||
transformFunDef(fd, module)
|
||||
}}.setPos(df)
|
||||
|
||||
def transformFunDef(fd: N.FunDef, module: String): S.FunDef = {
|
||||
val N.FunDef(name, params, retType, body) = fd
|
||||
val Some((sym, sig)) = table.getFunction(module, name)
|
||||
|
||||
params.groupBy(_.name).foreach { case (name, ps) =>
|
||||
if (ps.size > 1) {
|
||||
fatal(s"Two parameters named $name in function ${fd.name}", fd)
|
||||
}
|
||||
}
|
||||
|
||||
val paramNames = params.map(_.name)
|
||||
|
||||
val newParams = params zip sig.argTypes map { case (pd@N.ParamDef(name, tt), tpe) =>
|
||||
val s = Identifier.fresh(name)
|
||||
S.ParamDef(s, S.TypeTree(tpe).setPos(tt)).setPos(pd)
|
||||
}
|
||||
|
||||
val paramsMap = paramNames.zip(newParams.map(_.name)).toMap
|
||||
|
||||
S.FunDef(
|
||||
sym,
|
||||
newParams,
|
||||
S.TypeTree(sig.retType).setPos(retType),
|
||||
transformExpr(body)(module, (paramsMap, Map()))
|
||||
).setPos(fd)
|
||||
}
|
||||
|
||||
// This function takes as implicit a pair of two maps:
|
||||
// The first is a map from names of parameters to their unique identifiers,
|
||||
// the second is similar for local variables.
|
||||
// Make sure to update them correctly if needed given the scoping rules of Amy
|
||||
def transformExpr(expr: N.Expr)
|
||||
(implicit module: String, names: (Map[String, Identifier], Map[String, Identifier])): S.Expr = {
|
||||
val (params, locals) = names
|
||||
val res : S.Expr = expr match {
|
||||
|
||||
|
||||
// L1
|
||||
case N.Let(df, value, body) =>
|
||||
if (locals.contains(df.name))
|
||||
fatal(s"Redefinition of variable ${df.name}", df.position)
|
||||
val sn = Identifier.fresh(df.name)
|
||||
S.Let(S.ParamDef(sn, S.TypeTree(transformType(df.tt, module))), transformExpr(value), transformExpr(body)(module, (params, locals + (df.name -> sn))))
|
||||
|
||||
case N.Sequence(s1, s2) => S.Sequence(transformExpr(s1),transformExpr(s2))
|
||||
|
||||
|
||||
// L2
|
||||
case N.Match(scrut, cases) =>
|
||||
// Returns a transformed pattern along with all bindings
|
||||
// from strings to unique identifiers for names bound in the pattern.
|
||||
// Also, calls 'fatal' if a new name violates the Amy naming rules.
|
||||
def transformPattern(pat: N.Pattern): (S.Pattern, List[(String, Identifier)]) = {
|
||||
pat match{
|
||||
case N.LiteralPattern(l) => l match{
|
||||
case N.BooleanLiteral(x) => (S.LiteralPattern(S.BooleanLiteral(x)), List())
|
||||
case N.IntLiteral(x) => (S.LiteralPattern(S.IntLiteral(x)), List())
|
||||
case N.StringLiteral(x) => (S.LiteralPattern(S.StringLiteral(x)), List())
|
||||
case N.UnitLiteral() => (S.LiteralPattern(S.UnitLiteral()), List())
|
||||
}
|
||||
case N.IdPattern(name) =>
|
||||
if(locals.contains(name)){
|
||||
fatal(s"Duplicate variable $name")
|
||||
}
|
||||
if(name == s"Nil"){
|
||||
warning("Maybe you meant to write `Nil()` ?")
|
||||
}
|
||||
val id = Identifier.fresh(name)
|
||||
(S.IdPattern(id), List((name,id)))
|
||||
case N.WildcardPattern() => (S.WildcardPattern(), List())
|
||||
case N.CaseClassPattern(p_name,args) =>
|
||||
val owner = p_name.module.getOrElse(module)
|
||||
val name = p_name.name
|
||||
val constructor = table.getConstructor(owner,name).getOrElse(fatal(s"Constructor not found for Pattern $owner.$name"))
|
||||
if(constructor._2.argTypes.size != args.size)
|
||||
fatal(s"Invalid Arg Count for $owner.$name: had ${args.size}, expected ${constructor._2.argTypes.size}")
|
||||
|
||||
val n_args = args.map(arg => transformPattern(arg))
|
||||
|
||||
val pn_args = n_args.map(_._1)
|
||||
val ln_args = n_args.flatMap(_._2)
|
||||
|
||||
ln_args.groupBy(_._1).foreach(el => {
|
||||
if(el._2.size > 1)
|
||||
fatal(s"Duplicate local variable ${el._1}")
|
||||
if(locals.contains(el._1))
|
||||
warning(s"Shadowing an existing variable ${el._1}")
|
||||
})
|
||||
|
||||
(S.CaseClassPattern(constructor._1, pn_args),ln_args)
|
||||
}
|
||||
}
|
||||
|
||||
def transformCase(cse: N.MatchCase) : S.MatchCase = {
|
||||
val N.MatchCase(pat, rhs) = cse
|
||||
val (trans_pat, local_locals) = transformPattern(pat)
|
||||
S.MatchCase(trans_pat,transformExpr(rhs)(module, (params,locals ++ local_locals)))
|
||||
}
|
||||
|
||||
S.Match(transformExpr(scrut), cases.map(transformCase))
|
||||
|
||||
case N.Ite(cond, thenn, elze) => S.Ite(transformExpr(cond), transformExpr(thenn), transformExpr(elze))
|
||||
|
||||
|
||||
// L3
|
||||
case N.Plus(lhs,rhs) => S.Plus(transformExpr(lhs),transformExpr(rhs))
|
||||
case N.Minus(lhs,rhs) => S.Minus(transformExpr(lhs),transformExpr(rhs))
|
||||
case N.Times(lhs,rhs) => S.Times(transformExpr(lhs),transformExpr(rhs))
|
||||
case N.Div(lhs,rhs) => S.Div(transformExpr(lhs),transformExpr(rhs))
|
||||
case N.Mod(lhs,rhs) => S.Mod(transformExpr(lhs),transformExpr(rhs))
|
||||
case N.LessThan(lhs,rhs) => S.LessThan(transformExpr(lhs),transformExpr(rhs))
|
||||
case N.LessEquals(lhs,rhs) => S.LessEquals(transformExpr(lhs),transformExpr(rhs))
|
||||
case N.And(lhs,rhs) => S.And(transformExpr(lhs),transformExpr(rhs))
|
||||
case N.Or(lhs,rhs) => S.Or(transformExpr(lhs),transformExpr(rhs))
|
||||
case N.Equals(lhs,rhs) => S.Equals(transformExpr(lhs),transformExpr(rhs))
|
||||
case N.Concat(lhs,rhs) => S.Concat(transformExpr(lhs),transformExpr(rhs))
|
||||
|
||||
|
||||
// L4
|
||||
case N.Not(e) => S.Not(transformExpr(e))
|
||||
case N.Neg(e) => S.Neg(transformExpr(e))
|
||||
|
||||
|
||||
// L5
|
||||
case N.IntLiteral(v) => S.IntLiteral(v)
|
||||
case N.BooleanLiteral(v) => S.BooleanLiteral(v)
|
||||
case N.StringLiteral(v) => S.StringLiteral(v)
|
||||
case N.UnitLiteral() => S.UnitLiteral()
|
||||
|
||||
case N.Error(err) => S.Error(transformExpr(err))
|
||||
|
||||
case N.Variable(n) => S.Variable(locals.getOrElse(n, params.getOrElse(n, fatal(s"No variable $n", expr))))
|
||||
|
||||
case N.Call(qname, args) =>
|
||||
val owner = qname.module.getOrElse(module)
|
||||
val name = qname.name
|
||||
|
||||
val (sn, cs) = table.getConstructor(owner, name).getOrElse(table.getFunction(owner,name).getOrElse(fatal(s"No function found for $owner.$name")))
|
||||
if (args.size != cs.argTypes.size)
|
||||
fatal(s"Invalid Arg count for $owner.$name")
|
||||
val fun_args = args.map(arg => transformExpr(arg))
|
||||
S.Call(sn, fun_args)
|
||||
|
||||
//case _ =>// TODO: Implement the rest of the cases
|
||||
}
|
||||
res.setPos(expr)
|
||||
}
|
||||
|
||||
// Putting it all together to construct the final program for step 6.
|
||||
val newProgram = S.Program(
|
||||
p.modules map { case mod@N.ModuleDef(name, defs, optExpr) =>
|
||||
S.ModuleDef(
|
||||
table.getModule(name).get,
|
||||
defs map (transformDef(_, name)),
|
||||
optExpr map (transformExpr(_)(name, (Map(), Map())))
|
||||
).setPos(mod)
|
||||
}
|
||||
).setPos(p)
|
||||
|
||||
(newProgram, table)
|
||||
|
||||
}
|
||||
}
|
85
cs320-clp/src/amyc/analyzer/SymbolTable.scala
Normal file
85
cs320-clp/src/amyc/analyzer/SymbolTable.scala
Normal file
@@ -0,0 +1,85 @@
|
||||
package amyc.analyzer
|
||||
|
||||
import amyc.ast.Identifier
|
||||
import amyc.ast.SymbolicTreeModule._
|
||||
import amyc.utils.UniqueCounter
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
|
||||
trait Signature[RT <: Type]{
|
||||
val argTypes: List[Type]
|
||||
val retType: RT
|
||||
}
|
||||
// The signature of a function in the symbol table
|
||||
case class FunSig(argTypes: List[Type], retType: Type, owner: Identifier) extends Signature[Type]
|
||||
// The signature of a constructor in the symbol table
|
||||
case class ConstrSig(argTypes: List[Type], parent: Identifier, index: Int) extends Signature[ClassType] {
|
||||
val retType = ClassType(parent)
|
||||
}
|
||||
|
||||
// A class that represents a dictionary of symbols for an Amy program
|
||||
class SymbolTable {
|
||||
private val defsByName = HashMap[(String, String), Identifier]()
|
||||
private val modules = HashMap[String, Identifier]()
|
||||
|
||||
private val types = HashMap[Identifier, Identifier]()
|
||||
private val functions = HashMap[Identifier, FunSig]()
|
||||
private val constructors = HashMap[Identifier, ConstrSig]()
|
||||
|
||||
private val typesToConstructors = HashMap[Identifier, List[Identifier]]()
|
||||
|
||||
private val constrIndexes = new UniqueCounter[Identifier]
|
||||
|
||||
def addModule(name: String) = {
|
||||
val s = Identifier.fresh(name)
|
||||
modules += name -> s
|
||||
s
|
||||
}
|
||||
def getModule(name: String) = modules.get(name)
|
||||
|
||||
def addType(owner: String, name: String) = {
|
||||
val s = Identifier.fresh(name)
|
||||
defsByName += (owner, name) -> s
|
||||
types += (s -> modules.getOrElse(owner, sys.error(s"Module $name not found!")))
|
||||
s
|
||||
}
|
||||
def getType(owner: String, name: String) =
|
||||
defsByName.get(owner,name) filter types.contains
|
||||
def getType(symbol: Identifier) = types.get(symbol)
|
||||
|
||||
def addConstructor(owner: String, name: String, argTypes: List[Type], parent: Identifier) = {
|
||||
val s = Identifier.fresh(name)
|
||||
defsByName += (owner, name) -> s
|
||||
constructors += s -> ConstrSig(
|
||||
argTypes,
|
||||
parent,
|
||||
constrIndexes.next(parent)
|
||||
)
|
||||
typesToConstructors += parent -> (typesToConstructors.getOrElse(parent, Nil) :+ s)
|
||||
s
|
||||
}
|
||||
def getConstructor(owner: String, name: String): Option[(Identifier, ConstrSig)] = {
|
||||
for {
|
||||
sym <- defsByName.get(owner, name)
|
||||
sig <- constructors.get(sym)
|
||||
} yield (sym, sig)
|
||||
}
|
||||
def getConstructor(symbol: Identifier) = constructors.get(symbol)
|
||||
|
||||
def getConstructorsForType(t: Identifier) = typesToConstructors.get(t)
|
||||
|
||||
def addFunction(owner: String, name: String, argTypes: List[Type], retType: Type) = {
|
||||
val s = Identifier.fresh(name)
|
||||
defsByName += (owner, name) -> s
|
||||
functions += s -> FunSig(argTypes, retType, getModule(owner).getOrElse(sys.error(s"Module $owner not found!")))
|
||||
s
|
||||
}
|
||||
def getFunction(owner: String, name: String): Option[(Identifier, FunSig)] = {
|
||||
for {
|
||||
sym <- defsByName.get(owner, name)
|
||||
sig <- functions.get(sym)
|
||||
} yield (sym, sig)
|
||||
}
|
||||
def getFunction(symbol: Identifier) = functions.get(symbol)
|
||||
|
||||
}
|
175
cs320-clp/src/amyc/analyzer/TypeChecker.scala
Normal file
175
cs320-clp/src/amyc/analyzer/TypeChecker.scala
Normal file
@@ -0,0 +1,175 @@
|
||||
package amyc
|
||||
package analyzer
|
||||
|
||||
import amyc.ast.Identifier
|
||||
import amyc.ast.SymbolicTreeModule._
|
||||
import amyc.utils._
|
||||
|
||||
// The type checker for Amy
|
||||
// Takes a symbolic program and rejects it if it does not follow the Amy typing rules.
|
||||
object TypeChecker extends Pipeline[(Program, SymbolTable), (Program, SymbolTable)] {
|
||||
|
||||
def run(ctx: Context)(v: (Program, SymbolTable)): (Program, SymbolTable) = {
|
||||
import ctx.reporter._
|
||||
|
||||
val (program, table) = v
|
||||
|
||||
case class Constraint(found: Type, expected: Type, pos: Position)
|
||||
|
||||
// Represents a type variable.
|
||||
// It extends Type, but it is meant only for internal type checker use,
|
||||
// since no Amy value can have such type.
|
||||
case class TypeVariable private (id: Int) extends Type
|
||||
object TypeVariable {
|
||||
private val c = new UniqueCounter[Unit]
|
||||
def fresh(): TypeVariable = TypeVariable(c.next(()))
|
||||
}
|
||||
|
||||
// Generates typing constraints for an expression `e` with a given expected type.
|
||||
// The environment `env` contains all currently available bindings (you will have to
|
||||
// extend these, e.g., to account for local variables).
|
||||
// Returns a list of constraints among types. These will later be solved via unification.
|
||||
def genConstraints(e: Expr, expected: Type)(implicit env: Map[Identifier, Type]): List[Constraint] = {
|
||||
|
||||
// This helper returns a list of a single constraint recording the type
|
||||
// that we found (or generated) for the current expression `e`
|
||||
def topLevelConstraint(found: Type): List[Constraint] =
|
||||
List(Constraint(found, expected, e.position))
|
||||
|
||||
e match {
|
||||
case IntLiteral(_) => topLevelConstraint(IntType)
|
||||
case BooleanLiteral(_) => topLevelConstraint(BooleanType)
|
||||
case StringLiteral(_) => topLevelConstraint(StringType)
|
||||
case UnitLiteral() => topLevelConstraint(UnitType)
|
||||
case Variable(name) => topLevelConstraint(env.getOrElse(name,UnitType))
|
||||
|
||||
case Plus(lhs, rhs) => topLevelConstraint(IntType) ++ genConstraints(lhs,IntType) ++ genConstraints(rhs,IntType)
|
||||
case Minus(lhs, rhs) => topLevelConstraint(IntType) ++ genConstraints(lhs,IntType) ++ genConstraints(rhs,IntType)
|
||||
case Times(lhs, rhs) => topLevelConstraint(IntType) ++ genConstraints(lhs,IntType) ++ genConstraints(rhs,IntType)
|
||||
case Div(lhs, rhs) => topLevelConstraint(IntType) ++ genConstraints(lhs,IntType) ++ genConstraints(rhs,IntType)
|
||||
case Mod(lhs, rhs) => topLevelConstraint(IntType) ++ genConstraints(lhs,IntType) ++ genConstraints(rhs,IntType)
|
||||
|
||||
case Neg(e) => topLevelConstraint(IntType) ++ genConstraints(e,IntType)
|
||||
case Not(e) => topLevelConstraint(BooleanType) ++ genConstraints(e,BooleanType)
|
||||
|
||||
case LessThan(lhs, rhs) => topLevelConstraint(BooleanType) ++ genConstraints(lhs,IntType) ++ genConstraints(rhs,IntType)
|
||||
case LessEquals(lhs, rhs) => topLevelConstraint(BooleanType) ++ genConstraints(lhs,IntType) ++ genConstraints(rhs,IntType)
|
||||
case Equals(lhs, rhs) =>
|
||||
val ltype = TypeVariable.fresh()
|
||||
topLevelConstraint(BooleanType) ++ genConstraints(lhs,ltype) ++ genConstraints(rhs,ltype)
|
||||
|
||||
case And(lhs, rhs) => topLevelConstraint(BooleanType) ++ genConstraints(lhs,BooleanType) ++ genConstraints(rhs,BooleanType)
|
||||
case Or(lhs, rhs) => topLevelConstraint(BooleanType) ++ genConstraints(lhs,BooleanType) ++ genConstraints(rhs,BooleanType)
|
||||
|
||||
case Concat(lhs, rhs) => topLevelConstraint(StringType) ++ genConstraints(lhs,StringType) ++ genConstraints(rhs,StringType)
|
||||
|
||||
case Error(msg) => topLevelConstraint(expected) ++ genConstraints(msg,StringType)
|
||||
|
||||
case Ite(cond, thenn, elze) =>
|
||||
topLevelConstraint(expected) ++ genConstraints(cond,BooleanType) ++ genConstraints(thenn,expected) ++ genConstraints(elze,expected)
|
||||
|
||||
case Sequence(e1, e2) =>
|
||||
topLevelConstraint(expected) ++ genConstraints(e1,TypeVariable.fresh()) ++ genConstraints(e2,expected)
|
||||
|
||||
case Let(df, value, body) =>
|
||||
topLevelConstraint(expected) ++ genConstraints(value,df.tt.tpe) ++ genConstraints(body, expected)(env + (df.name -> df.tt.tpe))
|
||||
|
||||
case Call(qname, args) =>
|
||||
val (sign, constr) = (table.getFunction(qname),table.getConstructor(qname)) match{
|
||||
case (Some(s), None) => (s, Constraint(s.retType, expected, e.position))
|
||||
case (None, Some(s)) => (s, Constraint(ClassType(s.parent), expected, e.position))
|
||||
case _ => throw new MatchError("Invalid Case")
|
||||
}
|
||||
(constr :: (args.zip(sign.argTypes)).flatMap(p => genConstraints(p._1, p._2)))
|
||||
|
||||
case Match(scrut, cases) =>
|
||||
// Returns additional constraints from within the pattern with all bindings
|
||||
// from identifiers to types for names bound in the pattern.
|
||||
// (This is analogous to `transformPattern` in NameAnalyzer.)
|
||||
def handlePattern(pat: Pattern, scrutExpected: Type):
|
||||
(List[Constraint], Map[Identifier, Type]) = pat match {
|
||||
case CaseClassPattern(constr, args) =>
|
||||
val constr_sig = table.getConstructor(constr).get
|
||||
val in_pat = args.zip(constr_sig.argTypes).map(p => handlePattern(p._1, p._2))
|
||||
(Constraint(ClassType(constr_sig.parent), scrutExpected, pat.position) ::
|
||||
in_pat.flatMap(_._1), in_pat.flatMap(_._2.toList).toMap)
|
||||
case IdPattern(name) => (List(), Map(name->scrutExpected))
|
||||
case LiteralPattern(lit) => (genConstraints(lit,scrutExpected), Map())
|
||||
case WildcardPattern() => (List(), Map())
|
||||
}
|
||||
|
||||
def handleCase(cse: MatchCase, scrutExpected: Type): List[Constraint] = {
|
||||
val (patConstraints, moreEnv) = handlePattern(cse.pat, scrutExpected)
|
||||
patConstraints ++ genConstraints(cse.expr, expected)(env ++ moreEnv)
|
||||
}
|
||||
|
||||
val st = TypeVariable.fresh()
|
||||
genConstraints(scrut, st) ++ cases.flatMap(cse => handleCase(cse, st))
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Given a list of constraints `constraints`, replace every occurence of type variable
|
||||
// with id `from` by type `to`.
|
||||
def subst_*(constraints: List[Constraint], from: Int, to: Type): List[Constraint] = {
|
||||
// Do a single substitution.
|
||||
def subst(tpe: Type, from: Int, to: Type): Type = {
|
||||
tpe match {
|
||||
case TypeVariable(`from`) => to
|
||||
case other => other
|
||||
}
|
||||
}
|
||||
|
||||
constraints map { case Constraint(found, expected, pos) =>
|
||||
Constraint(subst(found, from, to), subst(expected, from, to), pos)
|
||||
}
|
||||
}
|
||||
|
||||
// Solve the given set of typing constraints and
|
||||
// call `typeError` if they are not satisfiable.
|
||||
// We consider a set of constraints to be satisfiable exactly if they unify.
|
||||
def solveConstraints(constraints: List[Constraint]): Unit = {
|
||||
constraints match {
|
||||
case Nil => ()
|
||||
case Constraint(found, expected, pos) :: more =>
|
||||
// HINT: You can use the `subst_*` helper above to replace a type variable
|
||||
// by another type in your current set of constraints.
|
||||
(found, expected) match {
|
||||
case (TypeVariable(id1), tpe@TypeVariable(id2)) =>
|
||||
if (id1 == id2) {
|
||||
solveConstraints(more)
|
||||
} else {
|
||||
solveConstraints(subst_*(constraints, id1, tpe))
|
||||
}
|
||||
case (tpe, TypeVariable(id)) =>
|
||||
solveConstraints(subst_*(constraints, id, tpe))
|
||||
case (tpe1, tpe2) =>
|
||||
if(tpe1 == tpe2){
|
||||
solveConstraints(more)
|
||||
}else{
|
||||
error(s"Error in TypeChecking, Types Not matching, found: ${found.toString}, expected ${expected.toString}.",pos)
|
||||
solveConstraints(more)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Putting it all together to type-check each module's functions and main expression.
|
||||
program.modules.foreach { mod =>
|
||||
// Put function parameters to the symbol table, then typecheck them against the return type
|
||||
mod.defs.collect { case FunDef(_, params, retType, body) =>
|
||||
val env = params.map{ case ParamDef(name, tt) => name -> tt.tpe }.toMap
|
||||
solveConstraints(genConstraints(body, retType.tpe)(env))
|
||||
}
|
||||
|
||||
// Type-check expression if present. We allow the result to be of an arbitrary type by
|
||||
// passing a fresh (and therefore unconstrained) type variable as the expected type.
|
||||
val tv = TypeVariable.fresh()
|
||||
mod.optExpr.foreach(e => solveConstraints(genConstraints(e, tv)(Map())))
|
||||
}
|
||||
|
||||
v
|
||||
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user