diff --git a/typed-racket-lib/typed-racket/infer/infer-unit.rkt b/typed-racket-lib/typed-racket/infer/infer-unit.rkt index da1db939b..86d112655 100644 --- a/typed-racket-lib/typed-racket/infer/infer-unit.rkt +++ b/typed-racket-lib/typed-racket/infer/infer-unit.rkt @@ -22,6 +22,7 @@ "signatures.rkt" "fail.rkt" "promote-demote.rkt" racket/match + (only-in racket/function curry curryr thunk) ;racket/trace (contract-req) (for-syntax @@ -38,6 +39,7 @@ (define (empty-set) '()) (define current-seen (make-parameter (empty-set))) +(define infered-tvar-map (make-parameter (hash))) ;; Type Type -> Pair ;; construct a pair for the set of seen type pairs @@ -65,6 +67,7 @@ [(context V X Y) (context (append bounds V) (append vars X) (append indices Y))])) + (define (inferable-index? ctx bound) (match ctx [(context _ _ Y) @@ -492,6 +495,31 @@ ;; this constrains just x (which is a single var) (define (singleton S x T) (insert empty x S T)) + + (define (constrain tvar-a tvar-b #:above above) + (define (maybe-type-app t) + (match t + [(App: t1 (list (F: var))) #:when (hash-has-key? (infered-tvar-map) var) + (define v (hash-ref (infered-tvar-map) var)) + (-App t1 (list v))] + [_ t])) + + (match-define (F: var (app maybe-type-app maybe-type-bound)) tvar-a) + + (define-values (default sub sing) (if above + (values Univ + (thunk (subtype tvar-b maybe-type-bound obj)) + (curry singleton (var-promote tvar-b (context-bounds context)) var)) + (values -Bottom + (thunk (subtype maybe-type-bound tvar-b obj)) + (curryr singleton var (var-demote tvar-b (context-bounds context)))))) + (cond + [(not maybe-type-bound) (sing default)] + [(sub) + (infered-tvar-map (hash-set (infered-tvar-map) var maybe-type-bound)) + (sing maybe-type-bound)] + [else #f])) + ;; FIXME -- figure out how to use parameters less here ;; subtyping doesn't need to use it quite as much (define cs (current-seen)) @@ -568,24 +596,24 @@ ;; variables that are in X and should be constrained ;; all other variables are compatible only with themselves - [((F: (? (inferable-var? context) v)) T) + [((F: (? (inferable-var? context))) T) #:return-when (match T ;; fail when v* is an index variable [(F: v*) (and (bound-index? v*) (not (bound-tvar? v*)))] [_ #f]) #f - ;; constrain v to be below T (but don't mention bounds) - (singleton -Bottom v (var-demote T (context-bounds context)))] + ;; constrain S to be below T (but don't mention bounds) + (constrain S T #:above #f)] - [(S (F: (? (inferable-var? context) v))) + [(S (F: (? (inferable-var? context)))) #:return-when (match S [(F: v*) (and (bound-index? v*) (not (bound-tvar? v*)))] [_ #f]) #f - ;; constrain v to be above S (but don't mention bounds) - (singleton (var-promote S (context-bounds context)) v Univ)] + ;; constrain T to be above S (but don't mention bounds) + (constrain T S #:above #t)] ;; recursive names should get resolved as they're seen [(s (? Name? t)) @@ -595,6 +623,10 @@ (let ([s (resolve-once s)]) (and s (cg s t obj)))] + [((F: var (? Type? bound)) t) + (let ([s (resolve-once bound)]) + (and s (cg s t obj)))] + ;; constrain b1 to be below T, but don't mention the new vars [((Poly: v1 b1) T) (cgen (context-add context #:bounds v1) b1 T)] @@ -966,6 +998,7 @@ (build-subst md)) (build-subst (stream-first (cset-maps C))))) + ;; context : the context of what to infer/not infer ;; S : a list of types to be the subtypes of T ;; T : a list of types @@ -983,9 +1016,9 @@ (for/list/fail ([s (in-list S)] [t (in-list T)] [obj (in-list/rest objs #f)]) - ;; We meet early to prune the csets to a reasonable size. - ;; This weakens the inference a bit, but sometimes avoids - ;; constraint explosion. + ;; We meet early to prune the csets to a reasonable size. + ;; This weakens the inference a bit, but sometimes avoids + ;; constraint explosion. (% cset-meet (cgen context s t obj) expected-cset))))) @@ -1031,16 +1064,17 @@ ;; like infer, but T-var is the vararg type: (define (infer/vararg X Y S T T-var R [expected #f] #:objs [objs '()]) - (and ((length S) . >= . (length T)) - (let* ([fewer-ts (- (length S) (length T))] - [new-T (match T-var - [(? Type? var-t) (list-extend S T var-t)] - [(Rest: rst-ts) - #:when (zero? (remainder fewer-ts (length rst-ts))) - (append T (repeat-list rst-ts - (quotient fewer-ts (length rst-ts))))] - [_ T])]) - (infer X Y S new-T R expected #:objs objs)))) + (parameterize ([infered-tvar-map (hash)]) + (and ((length S) . >= . (length T)) + (let* ([fewer-ts (- (length S) (length T))] + [new-T (match T-var + [(? Type? var-t) (list-extend S T var-t)] + [(Rest: rst-ts) + #:when (zero? (remainder fewer-ts (length rst-ts))) + (append T (repeat-list rst-ts + (quotient fewer-ts (length rst-ts))))] + [_ T])]) + (infer X Y S new-T R expected #:objs objs))))) ;; like infer, but dotted-var is the bound on the ... ;; and T-dotted is the repeated type diff --git a/typed-racket-lib/typed-racket/private/parse-type.rkt b/typed-racket-lib/typed-racket/private/parse-type.rkt index 64a7412da..ec55ceffd 100644 --- a/typed-racket-lib/typed-racket/private/parse-type.rkt +++ b/typed-racket-lib/typed-racket/private/parse-type.rkt @@ -193,6 +193,11 @@ (parse-literal-alls #'t.type))] [_ null])) +(define-syntax-class maybe-bounded + #:datum-literals (<:) + #:attributes (name bound) + (pattern (name:id <: bound:expr)) + (pattern name:id #:attr bound #f)) ;; Syntax -> Type ;; Parse a Forall type @@ -215,7 +220,21 @@ "variable" (syntax-e maybe-dup))) (let* ([vars (stx-map syntax-e #'(vars ...))]) (extend-tvars vars - (make-Poly vars (parse-type #'t.type))))] + (make-Poly vars (parse-type #'t.type))))] + [(:All^ (vars:maybe-bounded ...) . t:omit-parens) + (define maybe-dup (check-duplicate-identifier (attribute vars.name))) + (when maybe-dup + (parse-error "duplicate type variable" + "variable" (syntax-e maybe-dup))) + (define bounds (for/fold ([acc (hash)]) + ([i (stx-map syntax-e (attribute vars.name))] + [j (attribute vars.bound)] + #:when j) + (hash-set acc i + (extend-tvars (hash-keys acc) (parse-type j))))) + (let* ([vars (stx-map syntax-e (attribute vars.name))]) + (extend-tvars vars + (make-Poly vars (parse-type #'t.type) #:bounds bounds)))] ;; Next two are row polymorphic cases [(:All^ (var:id #:row) . t:omit-parens) (add-disappeared-use #'kw) diff --git a/typed-racket-lib/typed-racket/rep/type-rep.rkt b/typed-racket-lib/typed-racket/rep/type-rep.rkt index 257eb1523..334aed47b 100644 --- a/typed-racket-lib/typed-racket/rep/type-rep.rkt +++ b/typed-racket-lib/typed-racket/rep/type-rep.rkt @@ -24,6 +24,7 @@ syntax/id-set racket/contract racket/lazy-require + racket/syntax racket/unsafe/undefined (for-syntax racket/base racket/syntax @@ -46,6 +47,9 @@ PolyDots-unsafe: Mu? Poly? PolyDots? PolyRow? Poly-n + F-n + F-bound + F? PolyDots-n Class? Row? Row: free-vars* @@ -84,6 +88,8 @@ [PolyDots:* PolyDots:] [PolyRow:* PolyRow:] [Mu* make-Mu] + [F* make-F] + [F:* F:] [make-Mu unsafe-make-Mu] [Poly* make-Poly] [PolyDots* make-PolyDots] @@ -105,6 +111,7 @@ (App? x))) (lazy-require + ("../types/substitute.rkt" (subst)) ("../types/overlap.rkt" (overlap?)) ("../types/prop-ops.rkt" (-and)) ("../types/resolve.rkt" (resolve-app)) @@ -139,13 +146,26 @@ ;; free type variables ;; n is a Name -(def-type F ([n symbol?]) +(def-type F ([n symbol?] + [bound (or/c #f Type?)]) + #:no-provide [#:frees [#:vars (_) (single-free-var n)] [#:idxs (_) empty-free-vars]] [#:fmap (_ #:self self) self] [#:for-each (_) (void)]) +(define (F* n [bound #f]) + (make-F n bound)) + + +(define-match-expander F:* + (lambda (stx) + (syntax-case stx () + [(_ n) #'(F: n _)] + [(_ n b) #'(F: n b)]))) + + (define Name-table (make-free-id-table)) ;; Name, an indirection of a type through the environment @@ -519,10 +539,14 @@ ;; n is how many variables are bound here ;; body is a type (def-type Poly ([n exact-nonnegative-integer?] + [bounds (hash/c exact-nonnegative-integer? + Type? + #:immutable #t + #:flat #t)] [body Type?]) #:no-provide [#:frees (f) (f body)] - [#:fmap (f) (make-Poly n (f body))] + [#:fmap (f) (make-Poly n bounds (f body))] [#:for-each (f) (f body)] [#:mask (λ (t) (mask (Poly-body t)))]) @@ -1456,7 +1480,7 @@ ;; De Bruijn indices [(B: idx) (transform idx lvl cur #f)] ;; Type variables - [(F: var) (transform var lvl cur #f)] + [(F: var _) (transform var lvl cur #f)] ;; forms w/ dotted type vars/indices [(RestDots: ty d) (make-RestDots (rec ty) (transform d lvl d #t))] @@ -1477,8 +1501,8 @@ (make-PolyRow constraints (rec/lvl body (add1 lvl)))] [(PolyDots: n body) (make-PolyDots n (rec/lvl body (+ n lvl)))] - [(Poly: n body) - (make-Poly n (rec/lvl body (+ n lvl)))] + [(Poly: n bounds body) + (make-Poly n bounds (rec/lvl body (+ n lvl)))] [_ (Rep-fmap cur rec)]))) @@ -1618,7 +1642,7 @@ (define (Mu-body* name t) (match t [(Mu: body) - (instantiate-type body (make-F name))])) + (instantiate-type body (F* name))])) ;; unfold : Mu -> Type (define/cond-contract (unfold t) @@ -1638,19 +1662,42 @@ ;; ;; list type #:original-names list -> type ;; -(define (Poly* names body #:original-names [orig names]) + +(define (Poly* names body #:bounds [bounds '#hash()] #:original-names [orig names]) (if (null? names) body - (let ([v (make-Poly (length names) (abstract-type body names))]) + (let* ([len (length names)] + [new-bounds (for/hash ([(n v) bounds]) + (values (index-of names n) v))] + [v (make-Poly len new-bounds (abstract-type body names))]) (hash-set! type-var-name-table v orig) v))) +(define (unsubst ty orig-names names) + (for/fold ([acc ty]) + ([o orig-names] + [n names]) + (subst o (make-F n #f) acc) + #; + (subst o (make-Name (format-id #f "~a" n) 0 #f) acc))) + ;; Poly 'smart' destructor (define (Poly-body* names t) + (define orig-names (hash-ref type-var-name-table t null)) (match t - [(Poly: n body) + [(Poly: n bounds body) + (define new-bounds (for/hash ([(idx v) bounds]) + (values (list-ref names idx) (unsubst v orig-names names)))) (unless (= (length names) n) (int-err "Wrong number of names: expected ~a got ~a" n (length names))) - (instantiate-type body (map make-F names))])) + (instantiate-type body + (map (lambda (n) + (define v (match (hash-ref new-bounds n #f) + [(App: rator (list (F: vb _))) + #:when (hash-has-key? new-bounds vb) + (make-App rator (list (hash-ref new-bounds vb)))] + [_else _else])) + (make-F n v)) + names))])) ;; PolyDots 'smart' constructor (define (PolyDots* names body) @@ -1665,7 +1712,7 @@ [(PolyDots: n body) (unless (= (length names) n) (int-err "Wrong number of names: expected ~a got ~a" n (length names))) - (instantiate-type body (map make-F names))])) + (instantiate-type body (map F* names))])) ;; PolyRow 'smart' constructor @@ -1683,7 +1730,7 @@ (define (PolyRow-body* names t) (match t [(PolyRow: constraints body) - (instantiate-type body (map make-F names))])) + (instantiate-type body (map F* names))])) ;;*************************************************************** @@ -1939,7 +1986,7 @@ [(Some: n body) (unless (= (length names) n) (int-err "Wrong number of names: expected ~a got ~a" n (length names))) - (instantiate-type body (map make-F names))])) + (instantiate-type body (map F* names))])) (define-match-expander Some-names: diff --git a/typed-racket-lib/typed-racket/types/printer.rkt b/typed-racket-lib/typed-racket/types/printer.rkt index 7abb422d8..8a75d1dd5 100644 --- a/typed-racket-lib/typed-racket/types/printer.rkt +++ b/typed-racket-lib/typed-racket/types/printer.rkt @@ -714,10 +714,11 @@ [(? Base?) (Base-name type)] [(Pair: l r) `(Pairof ,(t->s l) ,(t->s r))] [(ListDots: dty dbound) `(List ,(t->s dty) ... ,dbound)] - [(F: nm) + [(F: nm bound) (cond [(eq? nm imp-var) "Imp"] [(eq? nm self-var) "Self"] + [(Type? bound) (format "~a <: ~a" nm (t->s bound))] [else nm])] [(Param: in out) (if (equal? in out) diff --git a/typed-racket-lib/typed-racket/types/subtype.rkt b/typed-racket-lib/typed-racket/types/subtype.rkt index 3d580858e..3673eb946 100644 --- a/typed-racket-lib/typed-racket/types/subtype.rkt +++ b/typed-racket-lib/typed-racket/types/subtype.rkt @@ -802,11 +802,12 @@ (match t2 [(Evt: result2) (subtype* A result1 result2)] [_ (continue<: A t1 t2 obj)])] - [(case: F (F: var1)) - (match t2 + [(case: F (F: var1 bound)) + (match* (t2 bound) ;; tvars are equal if they are the same variable - [(F: var2) (and (eq? var1 var2) A)] - [_ (continue<: A t1 t2 obj)])] + [((F: var2) _) (and (eq? var1 var2) A)] + [(_ (? Type?)) (subtype* A bound t2 obj)] + [(_ _) (continue<: A t1 t2 obj)])] [(case: Fun (Fun: arrows1)) (match* (t2 arrows1) ;; special case when t1 can be collapsed into simpler arrow diff --git a/typed-racket-test/fail/bounded-poly.rkt b/typed-racket-test/fail/bounded-poly.rkt new file mode 100644 index 000000000..9f23e6b81 --- /dev/null +++ b/typed-racket-test/fail/bounded-poly.rkt @@ -0,0 +1,9 @@ +#lang typed/racket/base + +(: a-func (All ([X <: Integer]) + (-> X String))) +(define (a-func a) + (number->string a)) + + +(a-func 19.999) ;; fail diff --git a/typed-racket-test/main.rkt b/typed-racket-test/main.rkt index d40e4a308..a131c5639 100644 --- a/typed-racket-test/main.rkt +++ b/typed-racket-test/main.rkt @@ -192,7 +192,7 @@ (define missed-opt? (make-parameter #f)) (define bench? (make-parameter #f)) (define math? (make-parameter #f)) - (define excl (make-parameter (list))) + (define excl (make-parameter (set))) (define single (make-parameter #f)) (current-namespace (make-base-namespace)) (command-line diff --git a/typed-racket-test/succeed/bounded-poly.rkt b/typed-racket-test/succeed/bounded-poly.rkt new file mode 100644 index 000000000..445db0bd9 --- /dev/null +++ b/typed-racket-test/succeed/bounded-poly.rkt @@ -0,0 +1,26 @@ +#lang typed/racket/base + +(: a-func (All ([X <: Integer]) + (-> X String))) +(define (a-func a) + (number->string a)) + + +(a-func 10) ;; pass + + +(struct (A) foo ([a : A]) #:type-name Foo) + +(: c-func (All ([X <: Integer] [Y <: (Foo X)]) + (-> X Y String))) +(define (c-func a b) + (number->string a)) + +(c-func 10 (foo 10)) + +(: d-func (All ([X <: Integer] [Y <: (Foo X)]) + (-> X Y String))) +(define (d-func a b) + (number->string (foo-a b))) + +(d-func 42 (foo 10))