Disabled external gits
This commit is contained in:
@@ -0,0 +1,131 @@
|
||||
package fos
|
||||
|
||||
object Infer {
|
||||
case class TypeScheme(params: List[TypeVar], tp: Type)
|
||||
type Env = List[(String, TypeScheme)]
|
||||
type Constraint = (Type, Type)
|
||||
|
||||
case class TypeError(msg: String) extends Exception(msg)
|
||||
|
||||
extension (t: Type)
|
||||
def in(s: Type): Boolean = getTV(s).contains(t)
|
||||
|
||||
extension (e: Env)
|
||||
def has(t: TypeVar): Boolean = e.find(_._2.params.contains(t)).nonEmpty
|
||||
|
||||
private var ftv_count = -1
|
||||
def freshTV(): TypeVar = {
|
||||
ftv_count += 1
|
||||
TypeVar("v"+ftv_count.toString)
|
||||
}
|
||||
|
||||
def getTV(t: Type): List[TypeVar] = getTV(t, Nil)
|
||||
def getTV(t: Type, ex: List[TypeVar]): List[TypeVar] = t match {
|
||||
case tpe@TypeVar(_) if !(ex.contains(tpe))=> List(tpe)
|
||||
case FunType(tpe1, tpe2) => getTV(tpe1,ex) ::: getTV(tpe2,ex)
|
||||
case _ => Nil
|
||||
}
|
||||
|
||||
def generalizeTV(tp: Type, env: Env) : TypeScheme = TypeScheme(getTV(tp).filter(env.has), tp)
|
||||
|
||||
|
||||
|
||||
def collectHelper(env: Env, t: Term, tpe: Type): (Type, List[Constraint]) =
|
||||
collectHelper(env, t, tpe, tpe)
|
||||
|
||||
def collectHelper(env: Env, t: Term, tpe1: Type, tpe2: Type): (Type, List[Constraint]) =
|
||||
val (tp1, c1) = collect(env,t)
|
||||
(tpe1, (tp1, tpe2) :: c1)
|
||||
|
||||
def collectHelperAbs(env: Env, x: String, t: Term, tpe: Type): (Type, List[Constraint]) =
|
||||
val (tp1, c1) = collect((x, TypeScheme(Nil,tpe)) :: env, t)
|
||||
(FunType(tpe, tp1), c1)
|
||||
|
||||
def collectHelperIf(env: Env, t1: Term, t2: Term, t3: Term): (Type, List[Constraint]) =
|
||||
val (tp1, c1) = collect(env, t1)
|
||||
val (tp2, c2) = collect(env, t2)
|
||||
val (tp3, c3) = collect(env, t3)
|
||||
(tp2, (tp1, BoolType) :: (tp2, tp3) :: c1 ::: c2 ::: c3)
|
||||
|
||||
def collectHelperApp(env: Env, t1: Term, t2: Term, tpe: Type): (Type, List[Constraint]) =
|
||||
val (tp1, c1) = collect(env, t1)
|
||||
val (tp2, c2) = collect(env, t2)
|
||||
(tpe, (tp1, FunType(tp2, tpe)) :: c1 ::: c2)
|
||||
|
||||
def collectHelperLet(env: Env, x: String, t1: Term, t2: Term): (Type, List[Constraint]) =
|
||||
val (tp1, c1) = collect(env, t1)
|
||||
val subst = unify(c1)
|
||||
val tp1Unified = subst(tp1)
|
||||
val fenv = env
|
||||
.map((s,ts) => (s, ts.params, subst(ts.tp)))
|
||||
.map((s,tsp,sub) => (s, TypeScheme(tsp.filter(_.in(sub)), sub)))
|
||||
val ts = generalizeTV(tp1Unified, fenv)
|
||||
val (tp2, c2) = collect((x, ts) :: fenv, t2)
|
||||
(tp2, c1 ::: c2)
|
||||
|
||||
|
||||
|
||||
def collect(env: Env, t: Term): (Type, List[Constraint]) =
|
||||
// println("Term: "+t)
|
||||
// println("Env: "+env)
|
||||
t match {
|
||||
case True | False => (BoolType, Nil)
|
||||
case Zero => (NatType, Nil)
|
||||
case Pred(t1) => collectHelper(env, t1, NatType)
|
||||
case Succ(t1) => collectHelper(env, t1, NatType)
|
||||
case IsZero(t1) => collectHelper(env, t1, BoolType, NatType)
|
||||
case If(t1, t2, t3) => collectHelperIf(env, t1, t2, t3)
|
||||
case Var(x) if (env.exists(_._1 == x)) => ( env.find(_._1 == x).get._2.tp, Nil)
|
||||
case Abs(x, tp, t1) => tp match {
|
||||
case EmptyTypeTree() => collectHelperAbs(env, x, t1, freshTV())
|
||||
case _ => collectHelperAbs(env, x, t1, tp.tpe)
|
||||
}
|
||||
case App(t1, t2) => collectHelperApp(env, t1, t2, freshTV())
|
||||
case Let(x, tp, t1, t2) => tp match {
|
||||
case EmptyTypeTree() => collectHelperLet(env, x, t1, t2)
|
||||
case _ => collect(env, App(Abs(x, tp, t2), t1))
|
||||
}
|
||||
case _ => throw TypeError(f"Collect is stuck on term `$t` !")
|
||||
}
|
||||
|
||||
|
||||
def unify(c: List[Constraint]): Type => Type = {
|
||||
val substitutions = unifyRec(c, Map())
|
||||
// println("Subst: "+ substitutions)
|
||||
return (x) => subst(x,substitutions)
|
||||
}
|
||||
|
||||
def subst(tpe: Type, constraints: Map[Type, Type]): Type = tpe match {
|
||||
case FunType(tpe1, tpe2) => FunType(subst(tpe1, constraints), subst(tpe2, constraints))
|
||||
case a@TypeVar(_) if constraints.contains(a) => subst(constraints(a), constraints)
|
||||
case others => others
|
||||
}
|
||||
|
||||
def unifyRec(c: List[Constraint], substitutions: Map[Type, Type]): Map[Type, Type] = {
|
||||
if (c.isEmpty)
|
||||
return substitutions
|
||||
// println("c: "+ c)
|
||||
// println("Subst: "+ substitutions)
|
||||
c.head match {
|
||||
case (s, t) if s == t => unifyRec(c.tail, substitutions)
|
||||
case (s@TypeVar(_), t) if !(s in t) => unifyRec(applyMapping(c.tail, s -> t), substitutions + (s -> t))
|
||||
case (s, t@TypeVar(_)) if !(t in s) => unifyRec(applyMapping(c.tail, t -> s), substitutions + (t -> s))
|
||||
case (FunType(s1, s2), FunType(t1, t2)) => unifyRec((s1, t1) :: (s2, t2) :: c.tail, substitutions)
|
||||
case _ => throw new TypeError(f"Impossible to find a substitution that satisfies the constraint set `${c.head}`")
|
||||
}
|
||||
}
|
||||
|
||||
def applyMapping(constraints: List[Constraint], st: (TypeVar, Type)): List[Constraint] =
|
||||
val (s, t) = st
|
||||
def recMapping(tp: Type): Type = tp match {
|
||||
case tp if tp == s => t
|
||||
case FunType(a, b) => FunType(recMapping(a), recMapping(b))
|
||||
case tp => tp
|
||||
}
|
||||
constraints.map {
|
||||
case (`s`, `s`) => (t, t)
|
||||
case (`s`, y) => (t, y)
|
||||
case (x , `s`) => (x, t)
|
||||
case (x, y) => (recMapping(x), recMapping(y))
|
||||
}
|
||||
}
|
@@ -0,0 +1,25 @@
|
||||
package fos
|
||||
|
||||
import Parser._
|
||||
import scala.util.parsing.input._
|
||||
|
||||
object Launcher {
|
||||
def main(args: Array[String]) = {
|
||||
val stdin = new java.io.BufferedReader(new java.io.InputStreamReader(System.in))
|
||||
val tokens = new lexical.Scanner(stdin.readLine())
|
||||
phrase(term)(tokens) match {
|
||||
case Success(term, _) =>
|
||||
try {
|
||||
val (tpe, c) = Infer.collect(Nil, term)
|
||||
// println("TPE: "+tpe)
|
||||
// println("C: "+c)
|
||||
val sub = Infer.unify(c)
|
||||
println("typed: " + sub(tpe))
|
||||
} catch {
|
||||
case tperror: Exception => println("type error: " + tperror.getMessage)
|
||||
}
|
||||
case e =>
|
||||
println(e)
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,74 @@
|
||||
package fos
|
||||
|
||||
import scala.util.parsing.combinator.syntactical.StandardTokenParsers
|
||||
import scala.util.parsing.input._
|
||||
|
||||
object Parser extends StandardTokenParsers {
|
||||
lexical.delimiters ++= List("(", ")", "\\", ".", ":", "=", "->", "{", "}", ",", "*", "+")
|
||||
lexical.reserved ++= List("Bool", "Nat", "true", "false", "if", "then", "else", "succ",
|
||||
"pred", "iszero", "let", "in")
|
||||
|
||||
/** <pre>
|
||||
* Term ::= SimpleTerm { SimpleTerm }</pre>
|
||||
*/
|
||||
def term: Parser[Term] = positioned(
|
||||
simpleTerm ~ rep(simpleTerm) ^^ { case t ~ ts => (t :: ts).reduceLeft[Term](App.apply) }
|
||||
| failure("illegal start of term"))
|
||||
|
||||
/** <pre>
|
||||
* SimpleTerm ::= "true"
|
||||
* | "false"
|
||||
* | number
|
||||
* | "succ" Term
|
||||
* | "pred" Term
|
||||
* | "iszero" Term
|
||||
* | "if" Term "then" Term "else" Term
|
||||
* | ident
|
||||
* | "\" ident [":" Type] "." Term
|
||||
* | "(" Term ")"
|
||||
* | "let" ident [":" Type] "=" Term "in" Term</pre>
|
||||
*/
|
||||
def simpleTerm: Parser[Term] = positioned(
|
||||
"true" ^^^ True
|
||||
| "false" ^^^ False
|
||||
| numericLit ^^ { case chars => lit2Num(chars.toInt) }
|
||||
| "succ" ~ term ^^ { case "succ" ~ t => Succ(t) }
|
||||
| "pred" ~ term ^^ { case "pred" ~ t => Pred(t) }
|
||||
| "iszero" ~ term ^^ { case "iszero" ~ t => IsZero(t) }
|
||||
| "if" ~ term ~ "then" ~ term ~ "else" ~ term ^^ {
|
||||
case "if" ~ t1 ~ "then" ~ t2 ~ "else" ~ t3 => If(t1, t2, t3)
|
||||
}
|
||||
| ident ^^ { case id => Var(id) }
|
||||
| "\\" ~ ident ~ opt(":" ~ typ) ~ "." ~ term ^^ {
|
||||
case "\\" ~ x ~ Some(":" ~ tp) ~ "." ~ t => Abs(x, tp, t)
|
||||
case "\\" ~ x ~ None ~ "." ~ t => Abs(x, EmptyTypeTree(), t)
|
||||
}
|
||||
| "(" ~> term <~ ")" ^^ { case t => t }
|
||||
| "let" ~ ident ~ opt(":" ~ typ) ~ "=" ~ term ~ "in" ~ term ^^ {
|
||||
case "let" ~ x ~ Some(":" ~ tp) ~ "=" ~ t1 ~ "in" ~ t2 => Let(x, tp, t1, t2)
|
||||
case "let" ~ x ~ None ~ "=" ~ t1 ~ "in" ~ t2 => Let(x, EmptyTypeTree(), t1, t2)
|
||||
}
|
||||
| failure("illegal start of simple term"))
|
||||
|
||||
/** <pre>
|
||||
* Type ::= SimpleType { "->" Type }</pre>
|
||||
*/
|
||||
def typ: Parser[TypeTree] = positioned(
|
||||
baseType ~ opt("->" ~ typ) ^^ {
|
||||
case t1 ~ Some("->" ~ t2) => FunTypeTree(t1, t2)
|
||||
case t1 ~ None => t1
|
||||
}
|
||||
| failure("illegal start of type"))
|
||||
|
||||
/** <pre>
|
||||
* BaseType ::= "Bool" | "Nat" | "(" Type ")"</pre>
|
||||
*/
|
||||
def baseType: Parser[TypeTree] = positioned(
|
||||
"Bool" ^^^ BoolTypeTree()
|
||||
| "Nat" ^^^ NatTypeTree()
|
||||
| "(" ~> typ <~ ")" ^^ { case t => t }
|
||||
)
|
||||
|
||||
private def lit2Num(n: Int): Term =
|
||||
if (n == 0) Zero else Succ(lit2Num(n - 1))
|
||||
}
|
@@ -0,0 +1,83 @@
|
||||
package fos
|
||||
|
||||
import scala.util.parsing.input.Positional
|
||||
|
||||
sealed abstract class Term extends Positional
|
||||
|
||||
case object True extends Term {
|
||||
override def toString() = "true"
|
||||
}
|
||||
|
||||
case object False extends Term {
|
||||
override def toString() = "false"
|
||||
}
|
||||
|
||||
case object Zero extends Term {
|
||||
override def toString() = "0"
|
||||
}
|
||||
|
||||
case class Succ(t: Term) extends Term {
|
||||
override def toString() = "succ " + t
|
||||
}
|
||||
|
||||
case class Pred(t: Term) extends Term {
|
||||
override def toString() = "pred " + t
|
||||
}
|
||||
|
||||
case class IsZero(t: Term) extends Term {
|
||||
override def toString() = "iszero " + t
|
||||
}
|
||||
|
||||
case class If(cond: Term, t1: Term, t2: Term) extends Term {
|
||||
override def toString() = "if " + cond + " then " + t1 + " else " + t2
|
||||
}
|
||||
|
||||
case class Var(name: String) extends Term {
|
||||
override def toString() = name
|
||||
}
|
||||
|
||||
case class Abs(v: String, tp: TypeTree, t: Term) extends Term {
|
||||
override def toString() = "(\\" + v + ":" + tp + "." + t + ")"
|
||||
}
|
||||
|
||||
case class App(t1: Term, t2: Term) extends Term {
|
||||
override def toString() = t1.toString + (t2 match {
|
||||
case App(_, _) => " (" + t2.toString + ")" // left-associative
|
||||
case _ => " " + t2.toString
|
||||
})
|
||||
}
|
||||
case class Let(x: String, tp: TypeTree, v: Term, t: Term) extends Term {
|
||||
override def toString() = "let " + x + ":" + tp + " = " + v + " in " + t
|
||||
}
|
||||
|
||||
// Note that TypeTree is distinct from Type.
|
||||
// The former is how types are parsed, the latter is how types are represented.
|
||||
// We need this distinction because:
|
||||
// 1) There are type vars, which can't be written by our users, but are needed by the inferencer.
|
||||
// 2) There are empty types, which can be written, but aren't directly supported by the inferencer.
|
||||
abstract class TypeTree extends Positional {
|
||||
def tpe: Type
|
||||
}
|
||||
|
||||
case class BoolTypeTree() extends TypeTree {
|
||||
override def tpe = BoolType
|
||||
override def toString() = "Bool"
|
||||
}
|
||||
|
||||
case class NatTypeTree() extends TypeTree {
|
||||
override def tpe = NatType
|
||||
override def toString() = "Nat"
|
||||
}
|
||||
|
||||
case class FunTypeTree(t1: TypeTree, t2: TypeTree) extends TypeTree {
|
||||
override def tpe = FunType(t1.tpe, t2.tpe)
|
||||
override def toString() = (t1 match {
|
||||
case FunTypeTree(_, _) => "(" + t1 + ")" // right-associative
|
||||
case _ => t1.toString
|
||||
}) + "->" + t2
|
||||
}
|
||||
|
||||
case class EmptyTypeTree() extends TypeTree {
|
||||
override def tpe = throw new UnsupportedOperationException
|
||||
override def toString() = "_"
|
||||
}
|
@@ -0,0 +1,22 @@
|
||||
package fos
|
||||
|
||||
// Note that TypeTree is distinct from Type.
|
||||
// See a comment on TypeTree to learn more.
|
||||
abstract class Type
|
||||
|
||||
case class TypeVar(name: String) extends Type {
|
||||
override def toString() = name
|
||||
}
|
||||
|
||||
case class FunType(t1: Type, t2: Type) extends Type {
|
||||
override def toString() = "(" + t1 + " -> " + t2 + ")"
|
||||
}
|
||||
|
||||
case object NatType extends Type {
|
||||
override def toString() = "Nat"
|
||||
}
|
||||
|
||||
case object BoolType extends Type {
|
||||
override def toString() = "Bool"
|
||||
}
|
||||
|
Reference in New Issue
Block a user