#lang plait

;; ---------------------------------------------------------------------
;; Abstract syntax and types
;; ---------------------------------------------------------------------

;; Type annotations as written in source programs.
(define-type TypeExp
  [numTE]
  [boolTE]
  [arrowTE (arg : TypeExp) (result : TypeExp)]
  [objTE (fields : (Listof (Symbol * TypeExp)))])

;; Expressions of the object language.
(define-type Exp
  [numE (n : Number)]
  [boolE (b : Boolean)]
  [plusE (left : Exp) (right : Exp)]
  [timesE (left : Exp) (right : Exp)]
  [minusE (left : Exp) (right : Exp)]
  [leqE (left : Exp) (right : Exp)]
  [lamE (var : Symbol) (te : TypeExp) (body : Exp)]
  [appE (fun : Exp) (arg : Exp)]
  [varE (name : Symbol)]
  [ifE (check : Exp) (zero : Exp) (non-zero : Exp)]
  [let1E (var : Symbol) (te : TypeExp) (value : Exp) (body : Exp)]
  [recE (var : Symbol) (te : TypeExp) (value : Exp) (body : Exp)]
  [objE (fields : (Listof (Symbol * Exp)))]
  [msgE (obj : Exp) (selector : Symbol)])

;; Types computed by the typechecker.  Object field types are stored in
;; a hash keyed by field name, so field order is irrelevant.
(define-type Type
  [numT]
  [boolT]
  [arrowT (arg : Type) (result : Type)]
  [objT (fields : (Hashof Symbol Type))])

;; ---------------------------------------------------------------------
;; Type environments
;; ---------------------------------------------------------------------

(define-type-alias TypeEnv (Hashof Symbol Type))

;; The empty type environment.
(define mt-type-env (hash empty))

;; Returns the type bound to `s` in `n`; raises "<s>: not bound" if absent.
(define (type-lookup (s : Symbol) (n : TypeEnv))
  (type-case (Optionof Type) (hash-ref n s)
    [(none) (error s "not bound")]
    [(some b) b]))

(module+ test
  (test/exn (type-lookup 'x mt-type-env) "not bound"))

;; Returns a new environment in which `s` is bound to `t`, shadowing any
;; previous binding of `s`.
(define (type-extend (env : TypeEnv) (s : Symbol) (t : Type))
  (hash-set env s t))

;; Converts a source type annotation into a Type.
(define (interp-te te)
  (type-case TypeExp te
    [(numTE) (numT)]
    [(boolTE) (boolT)]
    [(arrowTE a b) (arrowT (interp-te a) (interp-te b))]
    [(objTE fields)
     (objT (hash (map (lambda (key-val)
                        (pair (fst key-val) (interp-te (snd key-val))))
                      fields)))]))

(module+ test
  (test (interp-te (objTE (list (pair 'add1 (arrowTE (numTE) (numTE)))
                                (pair 'compare (arrowTE (numTE) (boolTE))))))
        (objT (hash (list (pair 'add1 (arrowT (numT) (numT)))
                          (pair 'compare (arrowT (numT) (boolT))))))))

;; ---------------------------------------------------------------------
;; Subtyping
;; ---------------------------------------------------------------------

;; (subtype? X Y) is #t when a value of type X may be used wherever a
;; value of type Y is expected:
;;   - num and bool are subtypes only of themselves;
;;   - arrows are contravariant in the argument, covariant in the result;
;;   - objects: X must have every field Y has (extra fields are allowed),
;;     and each shared field's type in X must be a subtype of its type in Y.
(define (subtype? [X : Type] [Y : Type]) : Boolean
  (type-case Type X
    [(numT) (numT? Y)]
    [(boolT) (boolT? Y)]
    [(arrowT X-arg X-result)
     (type-case Type Y
       [(arrowT Y-arg Y-result)
        (and (subtype? Y-arg X-arg)          ; contravariant argument
             (subtype? X-result Y-result))]  ; covariant result
       [else #f])]
    [(objT X-fields)
     (type-case Type Y
       [(objT Y-fields)
        (andmap (lambda (key)
                  (type-case (Optionof Type) (hash-ref X-fields key)
                    [(none) #f]  ; X lacks a field that Y requires
                    [(some X-type)
                     ;; `key` came from Y-fields, so this lookup always succeeds.
                     (subtype? X-type (some-v (hash-ref Y-fields key)))]))
                (hash-keys Y-fields))]
       [else #f])]))

(module+ test
  (define hello-t (objT (hash (list (pair 'hello (numT))))))
  (define hello-goodbye-t (objT (hash (list (pair 'hello (numT))
                                            (pair 'goodbye (boolT))))))
  (test (subtype? (numT) (boolT)) #f)
  (test (subtype? (numT) (numT)) #t)
  (test (subtype? (numT) hello-t) #f)
  (test (subtype? hello-t (objT (hash (list (pair 'hello (boolT)))))) #f)
  (test (subtype? hello-goodbye-t hello-t) #t)
  (test (subtype? hello-t hello-goodbye-t) #f))

;; ---------------------------------------------------------------------
;; Typechecker
;; ---------------------------------------------------------------------

;; Computes the type of `exp` under `env`, or raises a 'typecheck error.
;; Subsumption applies wherever a value flows into an expected type:
;; function arguments, let1/rec annotations, and if branches.
(define (typecheck [exp : Exp] [env : TypeEnv]) : Type
  (local [;; Both operands must be numbers; the whole expression has `type`.
          (define (num2 l r type)
            (let ([left-t (typecheck l env)]
                  [right-t (typecheck r env)])
              (if (and (equal? (numT) left-t) (equal? (numT) right-t))
                  type
                  (error 'typecheck "expected 2 num"))))
          ;; Checks that the bound value's type `val-t` fits the annotated
          ;; type `var-t`, then typechecks `body` with `var` : `var-t`.
          (define (check-annotated var var-t val-t body)
            (if (subtype? val-t var-t)
                (typecheck body (type-extend env var var-t))
                (error 'typecheck "type does not match annotation")))]
    (type-case Exp exp
      [(numE n) (numT)]
      [(boolE b) (boolT)]
      [(plusE l r) (num2 l r (numT))]
      [(minusE l r) (num2 l r (numT))]
      [(timesE l r) (num2 l r (numT))]
      [(leqE l r) (num2 l r (boolT))]
      [(varE s) (type-lookup s env)]
      [(lamE name te body)
       (let ([arg-type (interp-te te)])
         (arrowT arg-type (typecheck body (type-extend env name arg-type))))]
      [(appE fn arg)
       (type-case Type (typecheck fn env)
         [(arrowT arg-type result-type)
          (if (subtype? (typecheck arg env) arg-type)
              result-type
              (error 'typecheck "argument type"))]
         [else (error 'typecheck "not function")])]
      [(ifE c t f)
       (if (equal? (typecheck c env) (boolT))
           (let ([t-type (typecheck t env)]
                 [f-type (typecheck f env)])
             ;; If one branch's type subsumes the other's, the whole `if`
             ;; has the larger (super) type; identical types pick either.
             (cond
               [(subtype? t-type f-type) f-type]
               [(subtype? f-type t-type) t-type]
               [else (error 'typecheck "branches must have same type")]))
           (error 'typecheck "expected boolean"))]
      [(let1E var te val body)
       (check-annotated var (interp-te te) (typecheck val env) body)]
      [(recE var te val body)
       (let ([var-t (interp-te te)])
         ;; The bound value may refer to `var` itself.
         (check-annotated var var-t
                          (typecheck val (type-extend env var var-t))
                          body))]
      [(objE fields)
       (objT (hash (map (lambda (field)
                          (pair (fst field) (typecheck (snd field) env)))
                        fields)))]
      [(msgE obj selector)
       ;; Any expression of object type may receive a message.  The whole
       ;; receiver is typechecked, including fields that are not selected.
       (type-case Type (typecheck obj env)
         [(objT fields)
          (type-case (Optionof Type) (hash-ref fields selector)
            [(none) (error 'typecheck "unknown field")]
            [(some field-t) field-t])]
         [else (error 'typecheck "passing message to non-object")])])))

;; ---------------------------------------------------------------------
;; Parser
;; ---------------------------------------------------------------------

;; Raises a 'parse error that includes the offending s-expression.
(define (parse-error sx)
  (error 'parse (string-append "parse error: " (to-string sx))))

(module+ test
  (test/exn (parse `"strings are not in our language") "parse")
  (test/exn (parse `{& 1 2}) "parse"))

;; The n-th (0-based) element of the list s-expression `sx`.
(define (sx-ref sx n)
  (list-ref (s-exp->list sx) n))

;; Parses a type annotation:  num | bool | (T -> T) | (obj (field T) ...)
;; Anything else is a 'parse error.  The error is raised inline, with the
;; same message format as `parse-error`, so parse-te does not depend on
;; using that helper at a second result type.
(define (parse-te sx)
  (cond
    [(s-exp-symbol? sx)
     (case (s-exp->symbol sx)
       [(num) (numTE)]
       [(bool) (boolTE)]
       [else (error 'parse (string-append "parse error: " (to-string sx)))])]
    [(s-exp-match? `(ANY -> ANY) sx)
     (arrowTE (parse-te (sx-ref sx 0)) (parse-te (sx-ref sx 2)))]
    [(s-exp-match? `(obj (SYMBOL ANY) ...) sx)
     (objTE (map (lambda (element)
                   (pair (s-exp->symbol (sx-ref element 0))
                         (parse-te (sx-ref element 1))))
                 (rest (s-exp->list sx))))]
    [else (error 'parse (string-append "parse error: " (to-string sx)))]))

;; Parses an s-expression into an Exp.  Each form is recognized by its
;; exact shape, so wrong arity or an unknown head is a 'parse error.
(define (parse sx)
  (local [(define (px i) (parse (sx-ref sx i)))
          ;; For forms shaped (kw (id : type) ...): the bound id and its type.
          (define (binder-id) (s-exp->symbol (sx-ref (sx-ref sx 1) 0)))
          (define (binder-te) (parse-te (sx-ref (sx-ref sx 1) 2)))]
    (cond
      [(s-exp-number? sx) (numE (s-exp->number sx))]
      [(s-exp-symbol? sx)
       (let ([sym (s-exp->symbol sx)])
         (case sym
           [(true) (boolE #t)]
           [(false) (boolE #f)]
           [else (varE sym)]))]
      [(s-exp-match? `(msg ANY SYMBOL) sx)
       (msgE (px 1) (s-exp->symbol (sx-ref sx 2)))]
      [(s-exp-match? `(obj (SYMBOL ANY) ...) sx)
       (objE (map (lambda (element)
                    (pair (s-exp->symbol (sx-ref element 0))
                          (parse (sx-ref element 1))))
                  (rest (s-exp->list sx))))]
      [(s-exp-match? `(lam (SYMBOL : ANY) ANY) sx)
       (lamE (binder-id) (binder-te) (px 2))]
      [(s-exp-match? `(let1 (SYMBOL : ANY) ANY ANY) sx)
       (let1E (binder-id) (binder-te) (px 2) (px 3))]
      [(s-exp-match? `(rec (SYMBOL : ANY) ANY ANY) sx)
       (recE (binder-id) (binder-te) (px 2) (px 3))]
      [(s-exp-match? `(ANY ANY) sx) (appE (px 0) (px 1))]
      [(s-exp-match? `(+ ANY ANY) sx) (plusE (px 1) (px 2))]
      [(s-exp-match? `(- ANY ANY) sx) (minusE (px 1) (px 2))]
      [(s-exp-match? `(* ANY ANY) sx) (timesE (px 1) (px 2))]
      [(s-exp-match? `(<= ANY ANY) sx) (leqE (px 1) (px 2))]
      [(s-exp-match? `(if ANY ANY ANY) sx) (ifE (px 1) (px 2) (px 3))]
      [else (parse-error sx)])))

(module+ test
  (test (parse `{obj {hello true} {goodbye 42}})
        (objE (list (pair 'hello (boolE #t)) (pair 'goodbye (numE 42)))))
  (test (parse `{lam {x : (obj (n-func (num -> num)))} x})
        (lamE 'x (objTE (list (pair 'n-func (arrowTE (numTE) (numTE)))))
              (varE 'x)))
  ;; malformed input is reported as a parse error
  (test/exn (parse `{lam {x : foo} x}) "parse")
  (test/exn (parse `{+ 1 2 3}) "parse")
  (test/exn (parse `{if true 1}) "parse")
  (test/exn (parse `{{lam {x : num} x} 1 2}) "parse"))

;; ---------------------------------------------------------------------
;; Entry point
;; ---------------------------------------------------------------------

;; Parses and typechecks a program in the empty environment.
(tc : (S-Exp -> Type))
(define (tc s)
  (typecheck (parse s) mt-type-env))

(module+ test
  (test (tc `{+ 1 2}) (numT))
  (test/exn (tc `{+ true 2}) "expected 2 num")
  (test/exn (tc `{1 1}) "function")
  (test/exn (tc `{{lam {b : bool} false} 1}) "argument type")
  (test/exn (tc `{if false 1 true}) "branches")
  (test/exn (tc `{if 1 false true}) "boolean")
  (test/exn (tc `{let1 [x : num] true x}) "annotation")
  (test/exn (tc `{rec [x : num] true x}) "annotation"))

(module+ test
  (define sampler
    `{obj {hello true}
          {goodbye false}
          {a-num 42}
          {n-func {lam {x : num} x}}
          {b-func {lam {x : bool} x}}})
  (test (tc sampler)
        (objT (hash (list (pair 'hello (boolT))
                          (pair 'goodbye (boolT))
                          (pair 'a-num (numT))
                          (pair 'n-func (arrowT (numT) (numT)))
                          (pair 'b-func (arrowT (boolT) (boolT)))))))
  (test (tc `{msg ,sampler hello}) (boolT))
  (test/exn (tc `{msg 1 hello}) "object")
  (test/exn (tc `{msg ,sampler blah}) "unknown field")
  (define obj-fun
    `{lam {x : (obj (n-func (num -> num)))} {{msg x n-func} 3}})
  (test (tc obj-fun)
        (arrowT (objT (hash (list (pair 'n-func (arrowT (numT) (numT))))))
                (numT)))
  (test (tc `{,obj-fun {obj {n-func {lam {x : num} x}}}}) (numT))
  (test/exn (tc `{,obj-fun 2}) "argument type")
  (test/exn (tc `{if true ,obj-fun 2}) "branches")
  (test (tc `{rec {fact : (obj (run (num -> num)))}
                  {obj {run {lam {n : num}
                              {if {<= n 0}
                                  1
                                  {* n {{msg fact run} {- n 1}}}}}}}
                  {{msg fact run} 10}})
        (numT)))

(module+ test
  (test (tc `{,obj-fun {obj {n-func {lam {x : num} x}}
                            {b-func {lam {x : bool} x}}}})
        (numT))
  (test (tc `{let1 {f : {(obj (n-func (num -> num))) -> num}}
                   ,obj-fun
                   {f ,sampler}})
        (numT)))

(module+ test
  (test (subtype? (arrowT hello-t hello-t) (arrowT hello-t hello-t)) #t)
  (test (subtype? (arrowT hello-t hello-t) (arrowT hello-t hello-goodbye-t)) #f)
  (test (subtype? (arrowT hello-t hello-goodbye-t) (arrowT hello-t hello-t)) #t)
  (test (subtype? (arrowT hello-t hello-goodbye-t)
                  (arrowT hello-goodbye-t hello-t))
        #t)
  (test (subtype? (arrowT hello-goodbye-t hello-goodbye-t)
                  (arrowT hello-t hello-t))
        #f)
  ;; for coverage
  (test (subtype? (arrowT (numT) (numT)) (numT)) #f)
  (test (subtype? (numT) (arrowT (numT) (numT))) #f))

(module+ test
  ;; msg accepts any object-typed receiver, not only literals and variables
  (test (tc `{msg {msg {obj {inner {obj {x 1}}}} inner} x}) (numT))
  (test (tc `{msg {{lam {o : (obj (hello num))} o} {obj {hello 1}}} hello})
        (numT))
  ;; every field of a literal receiver is checked, even unselected ones
  (test/exn (tc `{msg {obj {a 1} {b {+ true 1}}} a}) "expected 2 num")
  ;; unknown field on a variable receiver
  (test/exn (tc `{{lam {o : (obj (hello num))} {msg o nope}} {obj {hello 1}}})
            "unknown field")
  ;; let1 / rec annotations admit subtypes
  (test (tc `{let1 {o : (obj (hello num))}
                   {obj {hello 1} {goodbye true}}
                   {msg o hello}})
        (numT))
  (test (tc `{rec {o : (obj (hello num))}
                  {obj {hello 1} {goodbye true}}
                  {msg o hello}})
        (numT))
  ;; if branches related by subtyping get the supertype
  (test (tc `{if true {obj {hello 1} {goodbye true}} {obj {hello 2}}})
        (objT (hash (list (pair 'hello (numT)))))))