diff --git a/lib/steep/ast/types/factory.rb b/lib/steep/ast/types/factory.rb index d44ce99b1..6950444fd 100644 --- a/lib/steep/ast/types/factory.rb +++ b/lib/steep/ast/types/factory.rb @@ -342,11 +342,10 @@ def interface(type, private:, self_type: type) definition.methods.each do |name, method| next if method.private? && !private - interface.methods[name] = Interface::Interface::Combination.overload( - method.method_types.map do |type| + interface.methods[name] = Interface::Interface::Entry.new( + method_types: method.method_types.map do |type| method_type(type, self_type: self_type, subst2: subst) - end, - incompatible: name == :initialize || name == :new + end ) end end @@ -363,11 +362,10 @@ def interface(type, private:, self_type: type) ) definition.methods.each do |name, method| - interface.methods[name] = Interface::Interface::Combination.overload( - method.method_types.map do |type| + interface.methods[name] = Interface::Interface::Entry.new( + method_types: method.method_types.map do |type| method_type(type, self_type: self_type, subst2: subst) - end, - incompatible: false + end ) end end @@ -389,11 +387,10 @@ def interface(type, private:, self_type: type) definition.methods.each do |name, method| next if !private && method.private? - interface.methods[name] = Interface::Interface::Combination.overload( - method.method_types.map do |type| + interface.methods[name] = Interface::Interface::Entry.new( + method_types: method.method_types.map do |type| method_type(type, self_type: self_type, subst2: subst) - end, - incompatible: false + end ) end end @@ -416,7 +413,25 @@ def interface(type, private:, self_type: type) Interface::Interface.new(type: self_type, private: private).tap do |interface| common_methods = Set.new(interface1.methods.keys) & Set.new(interface2.methods.keys) common_methods.each do |name| - interface.methods[name] = Interface::Interface::Combination.union([interface1.methods[name], interface2.methods[name]]) + types1 = interface1.methods[name].method_types + types2 = interface2.methods[name].method_types + + if types1 == types2 + interface.methods[name] = interface1.methods[name] + else + method_types = {} + + types1.each do |type1| + types2.each do |type2| + type = type1 | type2 or next + method_types[type] = true + end + end + + unless method_types.empty? + interface.methods[name] = Interface::Interface::Entry.new(method_types: method_types.keys) + end + end end end end @@ -427,11 +442,8 @@ def interface(type, private:, self_type: type) interfaces = type.types.map {|ty| interface(ty, private: private, self_type: self_type) } interfaces.inject do |interface1, interface2| Interface::Interface.new(type: self_type, private: private).tap do |interface| - all_methods = Set.new(interface1.methods.keys) + Set.new(interface2.methods.keys) - all_methods.each do |name| - methods = [interface1.methods[name], interface2.methods[name]].compact - interface.methods[name] = Interface::Interface::Combination.intersection(methods) - end + interface.methods.merge!(interface1.methods) + interface.methods.merge!(interface2.methods) end end end @@ -442,8 +454,8 @@ def interface(type, private:, self_type: type) array_type = Builtin::Array.instance_type(element_type) interface(array_type, private: private, self_type: self_type).tap do |array_interface| array_interface.methods[:[]] = array_interface.methods[:[]].yield_self do |aref| - Interface::Interface::Combination.overload( - type.types.map.with_index {|elem_type, index| + Interface::Interface::Entry.new( + method_types: type.types.map.with_index {|elem_type, index| Interface::MethodType.new( type_params: [], params: Interface::Params.new(required: [AST::Types::Literal.new(value: index)], @@ -456,14 +468,13 @@ def interface(type, private:, self_type: type) return_type: elem_type, location: nil ) - } + aref.types, - incompatible: false + } + aref.method_types ) end array_interface.methods[:[]=] = array_interface.methods[:[]=].yield_self do |update| - Interface::Interface::Combination.overload( - type.types.map.with_index {|elem_type, index| + Interface::Interface::Entry.new( + method_types: type.types.map.with_index {|elem_type, index| Interface::MethodType.new( type_params: [], params: Interface::Params.new(required: [AST::Types::Literal.new(value: index), elem_type], @@ -476,14 +487,13 @@ def interface(type, private:, self_type: type) return_type: elem_type, location: nil ) - } + update.types, - incompatible: false + } + update.method_types ) end array_interface.methods[:first] = array_interface.methods[:first].yield_self do |first| - Interface::Interface::Combination.overload( - [ + Interface::Interface::Entry.new( + method_types: [ Interface::MethodType.new( type_params: [], params: Interface::Params.empty, @@ -491,14 +501,13 @@ def interface(type, private:, self_type: type) return_type: type.types[0] || AST::Builtin.nil_type, location: nil ) - ], - incompatible: false + ] ) end array_interface.methods[:last] = array_interface.methods[:last].yield_self do |last| - Interface::Interface::Combination.overload( - [ + Interface::Interface::Entry.new( + method_types: [ Interface::MethodType.new( type_params: [], params: Interface::Params.empty, @@ -506,8 +515,7 @@ def interface(type, private:, self_type: type) return_type: type.types.last || AST::Builtin.nil_type, location: nil ) - ], - incompatible: false + ] ) end end @@ -523,8 +531,8 @@ def interface(type, private:, self_type: type) interface(hash_type, private: private, self_type: self_type).tap do |hash_interface| hash_interface.methods[:[]] = hash_interface.methods[:[]].yield_self do |ref| - Interface::Interface::Combination.overload( - type.elements.map {|key_value, value_type| + Interface::Interface::Entry.new( + method_types: type.elements.map {|key_value, value_type| key_type = Literal.new(value: key_value, location: nil) Interface::MethodType.new( type_params: [], @@ -538,14 +546,13 @@ def interface(type, private:, self_type: type) return_type: value_type, location: nil ) - } + ref.types, - incompatible: false + } + ref.method_types ) end hash_interface.methods[:[]=] = hash_interface.methods[:[]=].yield_self do |update| - Interface::Interface::Combination.overload( - type.elements.map {|key_value, value_type| + Interface::Interface::Entry.new( + method_types: type.elements.map {|key_value, value_type| key_type = Literal.new(value: key_value, location: nil) Interface::MethodType.new( type_params: [], @@ -559,8 +566,7 @@ def interface(type, private:, self_type: type) return_type: value_type, location: nil ) - } + update.types, - incompatible: false + } + update.method_types ) end end @@ -576,8 +582,8 @@ def interface(type, private:, self_type: type) location: nil ) - interface.methods[:[]] = Interface::Interface::Combination.overload([method_type], incompatible: false) - interface.methods[:call] = Interface::Interface::Combination.overload([method_type], incompatible: false) + interface.methods[:[]] = Interface::Interface::Entry.new(method_types: [method_type]) + interface.methods[:call] = Interface::Interface::Entry.new(method_types: [method_type]) end else diff --git a/lib/steep/ast/types/intersection.rb b/lib/steep/ast/types/intersection.rb index 721415b22..ee28b426e 100644 --- a/lib/steep/ast/types/intersection.rb +++ b/lib/steep/ast/types/intersection.rb @@ -28,33 +28,36 @@ def self.build(types:, location: nil) else type end - end.compact.uniq.yield_self do |tys| - if tys.size == 1 + end.compact.yield_self do |tys| + dups = Set.new(tys) + + case dups.size + when 0 + AST::Types::Top.new(location: location) + when 1 tys.first else - new(types: tys.sort_by(&:hash), location: location) + new(types: dups.to_a, location: location) end end end def ==(other) - other.is_a?(Intersection) && - other.types == types + other.is_a?(Intersection) && other.types == types end def hash - self.class.hash ^ types.hash + @hash ||= self.class.hash ^ types.hash end alias eql? == def subst(s) - self.class.build(location: location, - types: types.map {|ty| ty.subst(s) }) + self.class.build(location: location, types: types.map {|ty| ty.subst(s) }) end def to_s - "(#{types.map(&:to_s).sort.join(" & ")})" + "(#{types.map(&:to_s).join(" & ")})" end def free_variables() diff --git a/lib/steep/ast/types/union.rb b/lib/steep/ast/types/union.rb index af6cac8f9..9bc5b4fcc 100644 --- a/lib/steep/ast/types/union.rb +++ b/lib/steep/ast/types/union.rb @@ -38,29 +38,28 @@ def self.build(types:, location: nil) when 1 tys.first else - new(types: tys.sort_by(&:hash), location: location) + new(types: tys, location: location) end end end def ==(other) other.is_a?(Union) && - other.types == types + Set.new(other.types) == Set.new(types) end def hash - self.class.hash ^ types.hash + @hash ||= self.class.hash ^ types.sort_by(&:to_s).hash end alias eql? == def subst(s) - self.class.build(location: location, - types: types.map {|ty| ty.subst(s) }) + self.class.build(location: location, types: types.map {|ty| ty.subst(s) }) end def to_s - "(#{types.map(&:to_s).sort.join(" | ")})" + "(#{types.map(&:to_s).join(" | ")})" end def free_variables diff --git a/lib/steep/interface/interface.rb b/lib/steep/interface/interface.rb index fdf5db151..14f33b28f 100644 --- a/lib/steep/interface/interface.rb +++ b/lib/steep/interface/interface.rb @@ -1,72 +1,15 @@ module Steep module Interface class Interface - class Combination - attr_reader :operator - attr_reader :types + class Entry + attr_reader :method_types - def initialize(operator:, types:) - @types = types - @operator = operator - @incompatible = false - end - - def overload? - operator == :overload - end - - def union? - operator == :union - end - - def intersection? - operator == :intersection - end - - def self.overload(types, incompatible:) - new(operator: :overload, types: types).incompatible!(incompatible) - end - - def incompatible? - @incompatible - end - - def incompatible!(value) - @incompatible = value - self - end - - def self.union(types) - case types.size - when 0 - raise "Combination.union called with zero types" - when 1 - types.first - else - new(operator: :union, types: types) - end - end - - def self.intersection(types) - case types.size - when 0 - raise "Combination.intersection called with zero types" - when 1 - types.first - else - new(operator: :intersection, types: types) - end + def initialize(method_types:) + @method_types = method_types end def to_s - case operator - when :overload - "{ #{types.map(&:to_s).join(" | ")} }" - when :union - "[#{types.map(&:to_s).join(" | ")}]" - when :intersection - "[#{types.map(&:to_s).join(" & ")}]" - end + "{ #{method_types.join(" || ")} }" end end diff --git a/lib/steep/interface/method_type.rb b/lib/steep/interface/method_type.rb index 3c761580f..72bd7ead7 100644 --- a/lib/steep/interface/method_type.rb +++ b/lib/steep/interface/method_type.rb @@ -17,15 +17,14 @@ def initialize(required:, optional:, rest:, required_keywords:, optional_keyword @rest_keywords = rest_keywords end - NONE = Object.new - def update(required: NONE, optional: NONE, rest: NONE, required_keywords: NONE, optional_keywords: NONE, rest_keywords: NONE) + def update(required: self.required, optional: self.optional, rest: self.rest, required_keywords: self.required_keywords, optional_keywords: self.optional_keywords, rest_keywords: self.rest_keywords) self.class.new( - required: required.equal?(NONE) ? self.required : required, - optional: optional.equal?(NONE) ? self.optional : optional, - rest: rest.equal?(NONE) ? self.rest : rest, - required_keywords: required_keywords.equal?(NONE) ? self.required_keywords : required_keywords, - optional_keywords: optional_keywords.equal?(NONE) ? self.optional_keywords : optional_keywords, - rest_keywords: rest_keywords.equal?(NONE) ? self.rest_keywords : rest_keywords + required: required, + optional: optional, + rest: rest, + required_keywords: required_keywords, + optional_keywords: optional_keywords, + rest_keywords: rest_keywords, ) end @@ -84,6 +83,12 @@ def ==(other) other.rest_keywords == rest_keywords end + alias eql? == + + def hash + required.hash ^ optional.hash ^ rest.hash ^ required_keywords.hash ^ optional_keywords.hash ^ rest_keywords.hash + end + def flat_unnamed_params required.map {|p| [:required, p] } + optional.map {|p| [:optional, p] } end @@ -278,59 +283,61 @@ def empty? !has_positional? && !has_keywords? end - def |(other) + # self + params returns a new params for overloading. + # + def +(other) a = first_param b = other.first_param case when a.is_a?(RequiredPositional) && b.is_a?(RequiredPositional) AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first).with_first_param(RequiredPositional.new(type)) + (self.drop_first + other.drop_first).with_first_param(RequiredPositional.new(type)) end when a.is_a?(RequiredPositional) && b.is_a?(OptionalPositional) AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first).with_first_param(OptionalPositional.new(type)) + (self.drop_first + other.drop_first).with_first_param(OptionalPositional.new(type)) end when a.is_a?(RequiredPositional) && b.is_a?(RestPositional) AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other).with_first_param(OptionalPositional.new(type)) + (self.drop_first + other).with_first_param(OptionalPositional.new(type)) end when a.is_a?(RequiredPositional) && b.nil? - (self.drop_first | other).with_first_param(OptionalPositional.new(a.type)) + (self.drop_first + other).with_first_param(OptionalPositional.new(a.type)) when a.is_a?(OptionalPositional) && b.is_a?(RequiredPositional) AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first).with_first_param(OptionalPositional.new(type)) + (self.drop_first + other.drop_first).with_first_param(OptionalPositional.new(type)) end when a.is_a?(OptionalPositional) && b.is_a?(OptionalPositional) AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first).with_first_param(OptionalPositional.new(type)) + (self.drop_first + other.drop_first).with_first_param(OptionalPositional.new(type)) end when a.is_a?(OptionalPositional) && b.is_a?(RestPositional) AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other).with_first_param(OptionalPositional.new(type)) + (self.drop_first + other).with_first_param(OptionalPositional.new(type)) end when a.is_a?(OptionalPositional) && b.nil? - (self.drop_first | other).with_first_param(OptionalPositional.new(a.type)) + (self.drop_first + other).with_first_param(OptionalPositional.new(a.type)) when a.is_a?(RestPositional) && b.is_a?(RequiredPositional) AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self | other.drop_first).with_first_param(OptionalPositional.new(type)) + (self + other.drop_first).with_first_param(OptionalPositional.new(type)) end when a.is_a?(RestPositional) && b.is_a?(OptionalPositional) AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self | other.drop_first).with_first_param(OptionalPositional.new(type)) + (self + other.drop_first).with_first_param(OptionalPositional.new(type)) end when a.is_a?(RestPositional) && b.is_a?(RestPositional) AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first | other.drop_first).with_first_param(RestPositional.new(type)) + (self.drop_first + other.drop_first).with_first_param(RestPositional.new(type)) end when a.is_a?(RestPositional) && b.nil? - (self.drop_first | other).with_first_param(RestPositional.new(a.type)) + (self.drop_first + other).with_first_param(RestPositional.new(a.type)) when a.nil? && b.is_a?(RequiredPositional) - (self | other.drop_first).with_first_param(OptionalPositional.new(b.type)) + (self + other.drop_first).with_first_param(OptionalPositional.new(b.type)) when a.nil? && b.is_a?(OptionalPositional) - (self | other.drop_first).with_first_param(OptionalPositional.new(b.type)) + (self + other.drop_first).with_first_param(OptionalPositional.new(b.type)) when a.nil? && b.is_a?(RestPositional) - (self | other.drop_first).with_first_param(RestPositional.new(b.type)) + (self + other.drop_first).with_first_param(RestPositional.new(b.type)) when a.nil? && b.nil? required_keywords = {} @@ -410,6 +417,9 @@ def |(other) end end + # Returns the intersection between self and other. + # Returns nil if the intersection cannot be computed. + # def &(other) a = first_param b = other.first_param @@ -417,48 +427,48 @@ def &(other) case when a.is_a?(RequiredPositional) && b.is_a?(RequiredPositional) AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other.drop_first).with_first_param(RequiredPositional.new(type)) + (self.drop_first & other.drop_first)&.with_first_param(RequiredPositional.new(type)) end when a.is_a?(RequiredPositional) && b.is_a?(OptionalPositional) AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other.drop_first).with_first_param(RequiredPositional.new(type)) + (self.drop_first & other.drop_first)&.with_first_param(RequiredPositional.new(type)) end when a.is_a?(RequiredPositional) && b.is_a?(RestPositional) AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other).with_first_param(RequiredPositional.new(type)) + (self.drop_first & other)&.with_first_param(RequiredPositional.new(type)) end when a.is_a?(RequiredPositional) && b.nil? - (self.drop_first & other).with_first_param(RequiredPositional.new(AST::Types::Bot.new)) + nil when a.is_a?(OptionalPositional) && b.is_a?(RequiredPositional) AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other.drop_first).with_first_param(RequiredPositional.new(type)) + (self.drop_first & other.drop_first)&.with_first_param(RequiredPositional.new(type)) end when a.is_a?(OptionalPositional) && b.is_a?(OptionalPositional) AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other.drop_first).with_first_param(OptionalPositional.new(type)) + (self.drop_first & other.drop_first)&.with_first_param(OptionalPositional.new(type)) end when a.is_a?(OptionalPositional) && b.is_a?(RestPositional) AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other).with_first_param(OptionalPositional.new(type)) + (self.drop_first & other)&.with_first_param(OptionalPositional.new(type)) end when a.is_a?(OptionalPositional) && b.nil? self.drop_first & other when a.is_a?(RestPositional) && b.is_a?(RequiredPositional) AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self & other.drop_first).with_first_param(RequiredPositional.new(type)) + (self & other.drop_first)&.with_first_param(RequiredPositional.new(type)) end when a.is_a?(RestPositional) && b.is_a?(OptionalPositional) AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self & other.drop_first).with_first_param(OptionalPositional.new(type)) + (self & other.drop_first)&.with_first_param(OptionalPositional.new(type)) end when a.is_a?(RestPositional) && b.is_a?(RestPositional) AST::Types::Intersection.build(types: [a.type, b.type]).yield_self do |type| - (self.drop_first & other.drop_first).with_first_param(RestPositional.new(type)) + (self.drop_first & other.drop_first)&.with_first_param(RestPositional.new(type)) end when a.is_a?(RestPositional) && b.nil? self.drop_first & other when a.nil? && b.is_a?(RequiredPositional) - (self & other.drop_first).with_first_param(RequiredPositional.new(AST::Types::Bot.new)) + nil when a.nil? && b.is_a?(OptionalPositional) self & other.drop_first when a.nil? && b.is_a?(RestPositional) @@ -467,7 +477,7 @@ def &(other) optional_keywords = {} (Set.new(self.optional_keywords.keys) & Set.new(other.optional_keywords.keys)).each do |keyword| - self.optional_keywords[keyword] = AST::Types::Intersection.build( + optional_keywords[keyword] = AST::Types::Intersection.build( types: [ self.optional_keywords[keyword], other.optional_keywords[keyword] @@ -482,9 +492,7 @@ def &(other) when other.required_keywords.key?(keyword) required_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.required_keywords[keyword]]) when other.rest_keywords - required_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.rest_keywords]) - else - required_keywords[keyword] = t + optional_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.rest_keywords]) end end end @@ -494,9 +502,7 @@ def &(other) when self.required_keywords.key?(keyword) required_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.required_keywords[keyword]]) when self.rest_keywords - required_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.rest_keywords]) - else - required_keywords[keyword] = t + optional_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.rest_keywords]) end end end @@ -508,7 +514,7 @@ def &(other) when other.rest_keywords required_keywords[keyword] = AST::Types::Intersection.build(types: [t, other.rest_keywords]) else - required_keywords[keyword] = t + return end end end @@ -520,7 +526,7 @@ def &(other) when self.rest_keywords required_keywords[keyword] = AST::Types::Intersection.build(types: [t, self.rest_keywords]) else - required_keywords[keyword] = t + return end end end @@ -529,7 +535,153 @@ def &(other) when self.rest_keywords && other.rest_keywords AST::Types::Intersection.build(types: [self.rest_keywords, other.rest_keywords]) else - self.rest_keywords || other.rest_keywords + nil + end + + Params.new( + required: [], + optional: [], + rest: nil, + required_keywords: required_keywords, + optional_keywords: optional_keywords, + rest_keywords: rest) + end + end + + # Returns the union between self and other. + # + def |(other) + a = first_param + b = other.first_param + + case + when a.is_a?(RequiredPositional) && b.is_a?(RequiredPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(RequiredPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.is_a?(OptionalPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.is_a?(RestPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RequiredPositional) && b.nil? + self.drop_first&.with_first_param(OptionalPositional.new(a.type)) + when a.is_a?(OptionalPositional) && b.is_a?(RequiredPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.is_a?(OptionalPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.is_a?(RestPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(OptionalPositional) && b.nil? + (self.drop_first | other)&.with_first_param(a) + when a.is_a?(RestPositional) && b.is_a?(RequiredPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RestPositional) && b.is_a?(OptionalPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self | other.drop_first)&.with_first_param(OptionalPositional.new(type)) + end + when a.is_a?(RestPositional) && b.is_a?(RestPositional) + AST::Types::Union.build(types: [a.type, b.type]).yield_self do |type| + (self.drop_first | other.drop_first)&.with_first_param(RestPositional.new(type)) + end + when a.is_a?(RestPositional) && b.nil? + (self.drop_first | other)&.with_first_param(a) + when a.nil? && b.is_a?(RequiredPositional) + other.drop_first&.with_first_param(OptionalPositional.new(b.type)) + when a.nil? && b.is_a?(OptionalPositional) + (self | other.drop_first)&.with_first_param(b) + when a.nil? && b.is_a?(RestPositional) + (self | other.drop_first)&.with_first_param(b) + when a.nil? && b.nil? + required_keywords = {} + optional_keywords = {} + + (Set.new(self.required_keywords.keys) & Set.new(other.required_keywords.keys)).each do |keyword| + required_keywords[keyword] = AST::Types::Union.build( + types: [ + self.required_keywords[keyword], + other.required_keywords[keyword] + ] + ) + end + + self.optional_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) + case + when s = other.required_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when s = other.optional_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when r = other.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) + else + optional_keywords[keyword] = t + end + end + end + other.optional_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) + case + when s = self.required_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when s = self.optional_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when r = self.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) + else + optional_keywords[keyword] = t + end + end + end + self.required_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) + case + when s = other.optional_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when r = other.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) + else + optional_keywords[keyword] = t + end + end + end + other.required_keywords.each do |keyword, t| + unless optional_keywords.key?(keyword) || required_keywords.key?(keyword) + case + when s = self.optional_keywords[keyword] + optional_keywords[keyword] = AST::Types::Union.build(types: [t, s]) + when r = self.rest_keywords + optional_keywords[keyword] = AST::Types::Union.build(types: [t, r]) + else + optional_keywords[keyword] = t + end + end + end + + rest = case + when self.rest_keywords && other.rest_keywords + AST::Types::Union.build(types: [self.rest_keywords, other.rest_keywords]) + when self.rest_keywords + if required_keywords.empty? && optional_keywords.empty? + self.rest_keywords + end + when other.rest_keywords + if required_keywords.empty? && optional_keywords.empty? + other.rest_keywords + end + else + nil end Params.new( @@ -567,6 +719,12 @@ def ==(other) other.is_a?(self.class) && other.type == type && other.optional == optional end + alias eql? == + + def hash + type.hash ^ optional.hash + end + def closed? type.closed? end @@ -600,9 +758,10 @@ def map_type(&block) def +(other) optional = self.optional? || other.optional? - type = AST::Types::Proc.new(params: self.type.params | other.type.params, - return_type: AST::Types::Union.build(types: [self.type.return_type, - other.type.return_type])) + type = AST::Types::Proc.new( + params: self.type.params + other.type.params, + return_type: AST::Types::Union.build(types: [self.type.return_type, other.type.return_type]) + ) self.class.new( type: type, optional: optional @@ -617,8 +776,6 @@ class MethodType attr_reader :return_type attr_reader :location - NONE = Object.new - def initialize(type_params:, params:, block:, return_type:, location:) @type_params = type_params @params = params @@ -636,6 +793,12 @@ def ==(other) (!other.location || !location || other.location == location) end + alias eql? == + + def hash + type_params.hash ^ params.hash ^ block.hash ^ return_type.hash + end + def free_variables @fvs ||= Set.new.tap do |set| set.merge(params.free_variables) @@ -676,23 +839,19 @@ def each_type(&block) end def instantiate(s) - self.class.new( - type_params: [], - params: params.subst(s), - block: block&.subst(s), - return_type: return_type.subst(s), - location: location, - ) + self.class.new(type_params: [], + params: params.subst(s), + block: block&.subst(s), + return_type: return_type.subst(s), + location: location) end - def with(type_params: NONE, params: NONE, block: NONE, return_type: NONE, location: NONE) - self.class.new( - type_params: type_params.equal?(NONE) ? self.type_params : type_params, - params: params.equal?(NONE) ? self.params : params, - block: block.equal?(NONE) ? self.block : block, - return_type: return_type.equal?(NONE) ? self.return_type : return_type, - location: location.equal?(NONE) ? self.location : location - ) + def with(type_params: self.type_params, params: self.params, block: self.block, return_type: self.return_type, location: self.location) + self.class.new(type_params: type_params, + params: params, + block: block, + return_type: return_type, + location: location) end def to_s @@ -704,16 +863,16 @@ def to_s end def map_type(&block) - self.class.new( - type_params: type_params, - params: params.map_type(&block), - block: self.block&.yield_self {|blk| blk.map_type(&block) }, - return_type: yield(return_type), - location: location - ) + self.class.new(type_params: type_params, + params: params.map_type(&block), + block: self.block&.yield_self {|blk| blk.map_type(&block) }, + return_type: yield(return_type), + location: location) end - def +(other) + # Returns a new method type which can be used for the method implementation type of both `self` and `other`. + # + def unify_overload(other) type_params = [] s1 = Substitution.build(self.type_params) type_params.push(*s1.dictionary.values.map(&:name)) @@ -731,7 +890,7 @@ def +(other) self.class.new( type_params: type_params, - params: params.subst(s1) | other.params.subst(s2), + params: params.subst(s1) + other.params.subst(s2), block: block, return_type: AST::Types::Union.build( types: [return_type.subst(s1),other.return_type.subst(s2)] @@ -739,6 +898,108 @@ def +(other) location: nil ) end + + def +(other) + unify_overload(other) + end + + # Returns a method type which is a super-type of both self and other. + # self <: (self | other) && other <: (self | other) + # + # Returns nil if self and other are incompatible. + # + def |(other) + self_type_params = Set.new(self.type_params) + other_type_params = Set.new(other.type_params) + + unless (common_type_params = (self_type_params & other_type_params).to_a).empty? + fresh_types = common_type_params.map {|name| AST::Types::Var.fresh(name) } + fresh_names = fresh_types.map(&:name) + subst = Substitution.build(common_type_params, fresh_types) + other = other.instantiate(subst) + type_params = (self_type_params + (other_type_params - common_type_params + Set.new(fresh_names))).to_a + else + type_params = (self_type_params + other_type_params).to_a + end + + params = self.params & other.params or return + block = case + when self.block && other.block + block_params = self.block.type.params | other.block.type.params + block_return_type = AST::Types::Intersection.build(types: [self.block.type.return_type, other.block.type.return_type]) + block_type = AST::Types::Proc.new(params: block_params, + return_type: block_return_type, + location: nil) + Block.new( + type: block_type, + optional: self.block.optional && other.block.optional + ) + when self.block && self.block.optional? + self.block + when other.block && other.block.optional? + other.block + when !self.block && !other.block + nil + else + return + end + return_type = AST::Types::Union.build(types: [self.return_type, other.return_type]) + + MethodType.new( + params: params, + block: block, + return_type: return_type, + type_params: type_params, + location: nil + ) + end + + # Returns a method type which is a sub-type of both self and other. + # (self & other) <: self && (self & other) <: other + # + # Returns nil if self and other are incompatible. + # + def &(other) + self_type_params = Set.new(self.type_params) + other_type_params = Set.new(other.type_params) + + unless (common_type_params = (self_type_params & other_type_params).to_a).empty? + fresh_types = common_type_params.map {|name| AST::Types::Var.fresh(name) } + fresh_names = fresh_types.map(&:name) + subst = Substitution.build(common_type_params, fresh_types) + other = other.subst(subst) + type_params = (self_type_params + (other_type_params - common_type_params + Set.new(fresh_names))).to_a + else + type_params = (self_type_params + other_type_params).to_a + end + + params = self.params | other.params + block = case + when self.block && other.block + block_params = self.block.type.params & other.block.type.params or return + block_return_type = AST::Types::Union.build(types: [self.block.type.return_type, other.block.type.return_type]) + block_type = AST::Types::Proc.new(params: block_params, + return_type: block_return_type, + location: nil) + Block.new( + type: block_type, + optional: self.block.optional || other.block.optional + ) + + else + self.block || other.block + end + + return_type = AST::Types::Intersection.build(types: [self.return_type, other.return_type]) + + MethodType.new( + params: params, + block: block, + return_type: return_type, + type_params: type_params, + location: nil + ) + end end end end diff --git a/lib/steep/subtyping/check.rb b/lib/steep/subtyping/check.rb index 02f84bf8e..bd44419ba 100644 --- a/lib/steep/subtyping/check.rb +++ b/lib/steep/subtyping/check.rb @@ -529,29 +529,24 @@ def check_interface(sub_interface, super_interface, self_type:, assumption:, tra def check_method(name, sub_method, super_method, self_type:, assumption:, trace:, constraints:) trace.method name, sub_method, super_method do - case - when sub_method.overload? && super_method.overload? - super_method.types.map do |super_type| - sub_method.types.map do |sub_type| - check_generic_method_type name, - sub_type, - super_type, - self_type: self_type, - assumption: assumption, - trace: trace, - constraints: constraints - end.yield_self do |results| - results.find(&:success?) || results[0] - end + super_method.method_types.map do |super_type| + sub_method.method_types.map do |sub_type| + check_generic_method_type name, + sub_type, + super_type, + self_type: self_type, + assumption: assumption, + trace: trace, + constraints: constraints end.yield_self do |results| - if results.all?(&:success?) || sub_method.incompatible? - success constraints: constraints - else - results.select(&:failure?).last - end + results.find(&:success?) || results[0] + end + end.yield_self do |results| + if results.all?(&:success?) + success constraints: constraints + else + results.select(&:failure?).last end - else - raise "aaaaaaaaaaaaaa" end end end diff --git a/lib/steep/type_construction.rb b/lib/steep/type_construction.rb index ee00089e8..fc848d67a 100644 --- a/lib/steep/type_construction.rb +++ b/lib/steep/type_construction.rb @@ -788,11 +788,10 @@ def synthesize(node, hint: nil) synthesize(child) end - super_method = Interface::Interface::Combination.overload( - method_context.super_method.method_types.map {|method_type| + super_method = Interface::Interface::Entry.new( + method_types: method_context.super_method.method_types.map {|method_type| checker.factory.method_type(method_type, self_type: self_type) - }, - incompatible: false + } ) args = TypeInference::SendArgs.from_nodes(node.children.dup) @@ -1714,7 +1713,7 @@ def synthesize(node, hint: nil) AST::Types::Any.new else each = checker.factory.interface(collection_type, private: true).methods[:each] - method_type = (each&.types || []).find {|type| type.block && type.block.type.params.first_param } + method_type = (each&.method_types || []).find {|type| type.block && type.block.type.params.first_param } method_type&.yield_self do |method_type| method_type.block.type.params.first_param&.type end @@ -1915,8 +1914,8 @@ def synthesize(node, hint: nil) param_type = hint.params.required[0] interface = checker.factory.interface(param_type, private: true) method = interface.methods[value.children[0]] - if method&.overload? - return_types = method.types.select {|method_type| + if method + return_types = method.method_types.select {|method_type| method_type.params.each_type.count == 0 }.map(&:return_type) @@ -2424,139 +2423,68 @@ def expand_self(type) def type_method_call(node, method_name:, receiver_type:, method:, args:, block_params:, block_body:, topdown_hint:) node_range = node.loc.expression.yield_self {|l| l.begin_pos..l.end_pos } - case - when method.union? - yield_self do - results = method.types.map do |method| + results = method.method_types.flat_map do |method_type| + Steep.logger.tagged method_type.to_s do + zips = args.zips(method_type.params, method_type.block&.type) + + zips.map do |arg_pairs| typing.new_child(node_range) do |child_typing| - with_new_typing(child_typing).type_method_call(node, - method_name: method_name, - receiver_type: receiver_type, - method: method, - args: args, - block_params: block_params, - block_body: block_body, - topdown_hint: false) - end - end + ret = self.with_new_typing(child_typing).try_method_type( + node, + receiver_type: receiver_type, + method_type: method_type, + args: args, + arg_pairs: arg_pairs, + block_params: block_params, + block_body: block_body, + child_typing: child_typing, + topdown_hint: topdown_hint + ) - if (type, constr, error = results.find {|_, _, error| error }) - constr.typing.save! - [type, - update_lvar_env { constr.context.lvar_env }, - error] - else - types = results.map(&:first) + raise unless ret.is_a?(Array) && ret[1].is_a?(TypeConstruction) - _, constr, _ = results.first - constr.typing.save! + result, constr = ret - [union_type(*types), - update_lvar_env { constr.context.lvar_env }, - nil] + [result, constr, method_type] + end end end + end - when method.intersection? - yield_self do - results = method.types.map do |method| - typing.new_child(node.loc.expression.yield_self {|l| l.begin_pos..l.end_pos }) do |child_typing| - with_new_typing(child_typing).type_method_call(node, + unless results.empty? + result, constr, method_type = results.find {|result, _, _| !result.is_a?(Errors::Base) } || results.last + else + method_type = method.method_types.last + constr = self.with_new_typing(typing.new_child(node_range)) + result = Errors::IncompatibleArguments.new(node: node, receiver_type: receiver_type, method_type: method_type) + end + constr.typing.save! + + case result + when Errors::Base + if method.method_types.size == 1 + typing.add_error result + type = case method_type.return_type + when AST::Types::Var + AST::Builtin.any_type + else + method_type.return_type + end + else + typing.add_error Errors::UnresolvedOverloading.new(node: node, + receiver_type: expand_self(receiver_type), method_name: method_name, - receiver_type: receiver_type, - method: method, - args: args, - block_params: block_params, - block_body: block_body, - topdown_hint: false) - end - end - - successes = results.reject {|_, _, error| error } - unless successes.empty? - types = successes.map(&:first) - constr = successes[0][1] - constr.typing.save! - - [AST::Types::Intersection.build(types: types), - update_lvar_env { constr.context.lvar_env }, - nil] - else - type, constr, error = results.first - constr.typing.save! - - [type, - update_lvar_env { constr.context.lvar_env }, - error] - end + method_types: method.method_types) + type = AST::Builtin.any_type end - when method.overload? - yield_self do - results = method.types.flat_map do |method_type| - Steep.logger.tagged method_type.to_s do - zips = args.zips(method_type.params, method_type.block&.type) - - zips.map do |arg_pairs| - typing.new_child(node_range) do |child_typing| - ret = self.with_new_typing(child_typing).try_method_type( - node, - receiver_type: receiver_type, - method_type: method_type, - args: args, - arg_pairs: arg_pairs, - block_params: block_params, - block_body: block_body, - child_typing: child_typing, - topdown_hint: topdown_hint - ) - - raise unless ret.is_a?(Array) && ret[1].is_a?(TypeConstruction) - - result, constr = ret - - [result, constr, method_type] - end - end - end - end - - unless results.empty? - result, constr, method_type = results.find {|result, _, _| !result.is_a?(Errors::Base) } || results.last - else - method_type = method.types.last - constr = self.with_new_typing(typing.new_child(node_range)) - result = Errors::IncompatibleArguments.new(node: node, receiver_type: receiver_type, method_type: method_type) - end - constr.typing.save! - - case result - when Errors::Base - if method.types.size == 1 - typing.add_error result - type = case method_type.return_type - when AST::Types::Var - AST::Builtin.any_type - else - method_type.return_type - end - else - typing.add_error Errors::UnresolvedOverloading.new(node: node, - receiver_type: expand_self(receiver_type), - method_name: method_name, - method_types: method.types) - type = AST::Builtin.any_type - end - - [type, - update_lvar_env { constr.context.lvar_env }, - result] - else # Type - [result, - update_lvar_env { constr.context.lvar_env }, - nil] - end - end + [type, + update_lvar_env { constr.context.lvar_env }, + result] + else # Type + [result, + update_lvar_env { constr.context.lvar_env }, + nil] end end @@ -2743,16 +2671,29 @@ def try_method_type(node, receiver_type:, method_type:, args:, arg_pairs:, block if block_params && method_type.block block_annotations = source.annotations(block: node, factory: checker.factory, current_module: current_namespace) block_params_ = TypeInference::BlockParams.from_node(block_params, annotations: block_annotations) - block_param_hint = block_params_.params_type( - hint: topdown_hint ? method_type.block.type.params : nil - ) + pairs = block_params_.zip(method_type.block.type.params) + + unless pairs + return [ + Errors::IncompatibleBlockParameters.new(node: node, method_type: method_type), + constr + ] + end - check_relation(sub_type: AST::Types::Proc.new(params: block_param_hint, return_type: AST::Types::Any.new), - super_type: method_type.block.type, - constraints: constraints).else do |result| - return [Errors::IncompatibleBlockParameters.new(node: node, - method_type: method_type), - constr] + pairs.each do |param, type| + if param.type + check_relation(sub_type: type, super_type: param.type, constraints: constraints).else do |result| + return [ + Errors::IncompatibleAssignment.new( + node: param.node, + lhs_type: param.type, + rhs_type: type, + result: result + ), + constr + ] + end + end end end diff --git a/smoke/alias/a.rb b/smoke/alias/a.rb index bfb3d4043..eb3835d76 100644 --- a/smoke/alias/a.rb +++ b/smoke/alias/a.rb @@ -1,7 +1,7 @@ # @type var x: foo x = "" -# !expects ArgumentTypeMismatch: receiver=(::Integer | ::String), expected=::string, actual=::Integer +# !expects* UnresolvedOverloading: receiver=(::String | ::Integer), method_name=+, x + 123 # @type var y: bar diff --git a/smoke/case/a.rb b/smoke/case/a.rb index 0fbc1a0b8..4aba3562f 100644 --- a/smoke/case/a.rb +++ b/smoke/case/a.rb @@ -1,6 +1,6 @@ # @type var a: Integer -# !expects IncompatibleAssignment: lhs_type=::Integer, rhs_type=(::Array[::String] | ::Integer | ::String | nil) +# !expects IncompatibleAssignment: lhs_type=::Integer, rhs_type=(::Integer | ::Array[::String] | nil | ::String) a = case 1 when 2 1 diff --git a/smoke/if/a.rb b/smoke/if/a.rb index 718d6bfca..cbfd5c08d 100644 --- a/smoke/if/a.rb +++ b/smoke/if/a.rb @@ -13,7 +13,7 @@ "baz" end -# !expects IncompatibleAssignment: lhs_type=::String, rhs_type=(::Integer | ::String) +# !expects IncompatibleAssignment: lhs_type=::String, rhs_type=(::String | ::Integer) a = if z "foofoo" else diff --git a/smoke/rescue/a.rb b/smoke/rescue/a.rb index 18df50be3..b852d0405 100644 --- a/smoke/rescue/a.rb +++ b/smoke/rescue/a.rb @@ -1,6 +1,6 @@ # @type var a: Integer -# !expects IncompatibleAssignment: lhs_type=::Integer, rhs_type=(::Integer | ::String) +# !expects IncompatibleAssignment: lhs_type=::Integer, rhs_type=(::String | ::Integer) a = begin 'foo' rescue @@ -9,12 +9,12 @@ # @type var b: Integer -# !expects IncompatibleAssignment: lhs_type=::Integer, rhs_type=(::Integer | ::String) +# !expects IncompatibleAssignment: lhs_type=::Integer, rhs_type=(::String | ::Integer) b = 'foo' rescue 1 # @type var c: Integer -# !expects IncompatibleAssignment: lhs_type=::Integer, rhs_type=(::Integer | ::String | ::Symbol) +# !expects IncompatibleAssignment: lhs_type=::Integer, rhs_type=(::String | ::Symbol | ::Integer) c = begin 'foo' rescue RuntimeError @@ -25,7 +25,7 @@ # @type var e: Integer -# !expects IncompatibleAssignment: lhs_type=::Integer, rhs_type=(::Array[::Integer] | ::Integer | ::Symbol) +# !expects IncompatibleAssignment: lhs_type=::Integer, rhs_type=(::Array[::Integer] | ::Symbol | ::Integer) e = begin 'foo' rescue RuntimeError diff --git a/test/block_params_test.rb b/test/block_params_test.rb index 006771636..b16b0a624 100644 --- a/test/block_params_test.rb +++ b/test/block_params_test.rb @@ -158,6 +158,46 @@ def test_zip4 end end + def test_zip_missing_required_params + with_factory do + type = Params.new( + required: [parse_type("::Integer"), parse_type("::Object")], + optional: [], + rest: nil, + required_keywords: {}, + optional_keywords: {}, + rest_keywords: nil + ) + + block_params("proc { }") do |params| + zip = params.zip(type) + + assert_empty zip + end + end + end + + def test_zip_with_extra_params + with_factory do + type = Params.new( + required: [parse_type("::Object")], + optional: [], + rest: nil, + required_keywords: {}, + optional_keywords: {}, + rest_keywords: nil + ) + + block_params("proc {|x, y| }") do |params| + zip = params.zip(type) + + assert_equal 2, zip.size + assert_equal [params.params[0], parse_type("::Object")], zip[0] + assert_equal [params.params[1], parse_type("nil")], zip[1] + end + end + end + def test_zip_expand_array with_factory do type = Params.new( diff --git a/test/interface_test.rb b/test/interface_test.rb index 2a7eed127..ff39318ac 100644 --- a/test/interface_test.rb +++ b/test/interface_test.rb @@ -4,59 +4,334 @@ class InterfaceTest < Minitest::Test include TestHelper include FactoryHelper - def test_method_type_params_union - with_factory do |factory| + def test_method_type_params_plus + with_factory do assert_equal parse_method_type("(String | Integer) -> untyped").params, - parse_method_type("(String) -> untyped").params | parse_method_type("(Integer) -> untyped").params + parse_method_type("(String) -> untyped").params + parse_method_type("(Integer) -> untyped").params assert_equal parse_method_type("(?String | Integer) -> untyped").params, - parse_method_type("(?String) -> untyped").params | parse_method_type("(Integer) -> untyped").params + parse_method_type("(?String) -> untyped").params + parse_method_type("(Integer) -> untyped").params assert_equal parse_method_type("(?String) -> untyped").params, - parse_method_type("(String) -> untyped").params | parse_method_type("() -> untyped").params + parse_method_type("(String) -> untyped").params + parse_method_type("() -> untyped").params assert_equal parse_method_type("(?String | Symbol, *Symbol) -> untyped").params, - parse_method_type("(String) -> untyped").params | parse_method_type("(*Symbol) -> untyped").params + parse_method_type("(String) -> untyped").params + parse_method_type("(*Symbol) -> untyped").params assert_equal parse_method_type("(?String | Symbol, *Symbol) -> void").params, - parse_method_type("(String) -> params").params | parse_method_type("(*Symbol) -> void").params + parse_method_type("(String) -> params").params + parse_method_type("(*Symbol) -> void").params assert_equal parse_method_type("(name: String | Symbol, ?email: String | Array, ?age: Integer | Object, **Array | Object) -> void").params, - parse_method_type("(name: String, email: String, **Object) -> void").params | parse_method_type("(name: Symbol, age: Integer, **Array) -> void").params + parse_method_type("(name: String, email: String, **Object) -> void").params + parse_method_type("(name: Symbol, age: Integer, **Array) -> void").params assert_equal parse_method_type("() ?{ (String | Integer) -> (Array | Hash) } -> void").params, - parse_method_type("() ?{ (String) -> Array } -> void").params | parse_method_type("() { (Integer) -> Hash } -> void").params + parse_method_type("() ?{ (String) -> Array } -> void").params + parse_method_type("() { (Integer) -> Hash } -> void").params end end def test_method_type_params_intersection - with_factory do |factory| + with_factory do + # req, none, opt, rest + + # required:required assert_equal parse_method_type("(String & Integer) -> untyped").params, parse_method_type("(String) -> untyped").params & parse_method_type("(Integer) -> untyped").params + # required:none + assert_nil parse_method_type("(String) -> untyped").params & parse_method_type("() -> untyped").params + + # required:optional assert_equal parse_method_type("(String & Integer) -> untyped").params, - parse_method_type("(?String) -> untyped").params & parse_method_type("(Integer) -> untyped").params + parse_method_type("(String) -> untyped").params & parse_method_type("(?Integer) -> untyped").params + # required:rest assert_equal parse_method_type("(String & Integer) -> untyped").params, parse_method_type("(String) -> untyped").params & parse_method_type("(*Integer) -> untyped").params - assert_equal parse_method_type("(bot) -> untyped").params, - (parse_method_type("(String) -> untyped").params & parse_method_type("() -> untyped").params) + # none:required + assert_nil parse_method_type("() -> untyped").params & parse_method_type("(String) -> void").params + # none:optional assert_equal parse_method_type("() -> untyped").params, - parse_method_type("(?String) -> untyped").params & parse_method_type("() -> untyped").params + parse_method_type("() -> untyped").params & parse_method_type("(?Integer) -> untyped").params + + # none:rest + assert_equal parse_method_type("() -> untyped").params, + parse_method_type("() -> untyped").params & parse_method_type("(*Integer) -> untyped").params + + # opt:required + assert_equal parse_method_type("(String & Integer) -> untyped").params, + parse_method_type("(?String) -> untyped").params & parse_method_type("(Integer) -> untyped").params - assert_equal parse_method_type("(String & Symbol) -> untyped").params, - parse_method_type("(String) -> untyped").params & parse_method_type("(*Symbol) -> untyped").params + # opt:none + assert_equal parse_method_type("() -> untyped").params, + parse_method_type("(?String) -> untyped").params & parse_method_type("() -> untyped").params + # opt:opt assert_equal parse_method_type("(?String & Integer) -> untyped").params, parse_method_type("(?String) -> untyped").params & parse_method_type("(?Integer) -> untyped").params - assert_equal parse_method_type("(String & Symbol) -> void").params, - parse_method_type("(String) -> params").params & parse_method_type("(*Symbol) -> void").params + # opt:rest + assert_equal parse_method_type("(?String & Integer) -> untyped").params, + parse_method_type("(?String) -> untyped").params & parse_method_type("(*Integer) -> untyped").params + + # rest:required + assert_equal parse_method_type("(String & Integer) -> untyped").params, + parse_method_type("(*String) -> untyped").params & parse_method_type("(Integer) -> untyped").params + + # rest:none + assert_equal parse_method_type("() -> untyped").params, + parse_method_type("(*String) -> untyped").params & parse_method_type("() -> untyped").params + + # rest:opt + assert_equal parse_method_type("(?String & Integer) -> untyped").params, + parse_method_type("(*String) -> untyped").params & parse_method_type("(?Integer) -> untyped").params + + # rest:rest + assert_equal parse_method_type("(*String & Integer) -> untyped").params, + parse_method_type("(*String) -> untyped").params & parse_method_type("(*Integer) -> untyped").params + + ## Keywords + + # req:req + assert_equal parse_method_type("(foo: String & Integer) -> untyped").params, + parse_method_type("(foo: String) -> untyped").params & parse_method_type("(foo: Integer) -> untyped").params + + # req:opt + assert_equal parse_method_type("(foo: Integer & String) -> untyped").params, + parse_method_type("(foo: String) -> untyped").params & parse_method_type("(?foo: Integer) -> untyped").params + + # req:none + assert_nil parse_method_type("(foo: String) -> untyped").params & parse_method_type("() -> untyped").params + + # req:rest + assert_equal parse_method_type("(foo: String & Integer) -> untyped").params, + parse_method_type("(foo: String) -> untyped").params & parse_method_type("(**Integer) -> untyped").params + + # opt:req + assert_equal parse_method_type("(foo: String & Integer) -> untyped").params, + parse_method_type("(?foo: String) -> untyped").params & parse_method_type("(foo: Integer) -> untyped").params + + # opt:opt + assert_equal parse_method_type("(?foo: String & Integer) -> untyped").params, + parse_method_type("(?foo: String) -> untyped").params & parse_method_type("(?foo: Integer) -> untyped").params + + # opt:none + assert_equal parse_method_type("() -> untyped").params, + parse_method_type("(?foo: String) -> untyped").params & parse_method_type("() -> untyped").params + + # opt:rest + assert_equal parse_method_type("(?foo: String & Integer) -> untyped").params, + parse_method_type("(?foo: String) -> untyped").params & parse_method_type("(**Integer) -> untyped").params + + # none:req + assert_nil parse_method_type("() -> untyped").params & parse_method_type("(foo: String) -> untyped").params + + # none:opt + assert_equal parse_method_type("() -> untyped").params, + parse_method_type("() -> untyped").params & parse_method_type("(?foo: Integer) -> untyped").params + + # none:rest + assert_equal parse_method_type("() -> untyped").params, + parse_method_type("() -> untyped").params & parse_method_type("(**Integer) -> untyped").params + + # rest:req + assert_equal parse_method_type("(foo: Integer & String) -> untyped").params, + parse_method_type("(**String) -> untyped").params & parse_method_type("(foo: Integer) -> untyped").params + + # rest:opt + assert_equal parse_method_type("(?foo: Integer & String) -> untyped").params, + parse_method_type("(**String) -> untyped").params & parse_method_type("(?foo: Integer) -> untyped").params + + # rest:none + assert_equal parse_method_type("() -> untyped").params, + parse_method_type("(**String) -> untyped").params & parse_method_type("() -> untyped").params + + # rest:rest + assert_equal parse_method_type("(**String & Integer) -> untyped").params, + parse_method_type("(**String) -> untyped").params & parse_method_type("(**Integer) -> untyped").params + end + end + + def test_method_type_params_union + with_factory do + # required:required + assert_equal parse_method_type("(String | Integer) -> untyped").params, + parse_method_type("(String) -> untyped").params | parse_method_type("(Integer) -> untyped").params + + # required:none + assert_equal parse_method_type("(?String) -> void").params, + parse_method_type("(String) -> untyped").params | parse_method_type("() -> untyped").params + + # required:optional + assert_equal parse_method_type("(?String | Integer) -> untyped").params, + parse_method_type("(String) -> untyped").params | parse_method_type("(?Integer) -> untyped").params + + # required:rest + assert_equal parse_method_type("(?String | Integer) -> untyped").params, + parse_method_type("(String) -> untyped").params | parse_method_type("(*Integer) -> untyped").params + + # none:required + assert_equal parse_method_type("(?String) -> untyped").params, + parse_method_type("() -> untyped").params | parse_method_type("(String) -> untyped").params + + # none:optional + assert_equal parse_method_type("(?Integer) -> untyped").params, + parse_method_type("() -> untyped").params | parse_method_type("(?Integer) -> untyped").params + + # none:rest + assert_equal parse_method_type("(*Integer) -> untyped").params, + parse_method_type("() -> untyped").params | parse_method_type("(*Integer) -> untyped").params + + # opt:required + assert_equal parse_method_type("(?String | Integer) -> untyped").params, + parse_method_type("(?String) -> untyped").params | parse_method_type("(Integer) -> untyped").params + + # opt:none + assert_equal parse_method_type("(?String) -> untyped").params, + parse_method_type("(?String) -> untyped").params | parse_method_type("() -> untyped").params + + # opt:opt + assert_equal parse_method_type("(?String | Integer) -> untyped").params, + parse_method_type("(?String) -> untyped").params | parse_method_type("(?Integer) -> untyped").params + + # opt:rest + assert_equal parse_method_type("(?String | Integer) -> untyped").params, + parse_method_type("(?String) -> untyped").params | parse_method_type("(*Integer) -> untyped").params + + # rest:required + assert_equal parse_method_type("(?String | Integer) -> untyped").params, + parse_method_type("(*String) -> untyped").params | parse_method_type("(Integer) -> untyped").params + + # rest:none + assert_equal parse_method_type("(*String) -> untyped").params, + parse_method_type("(*String) -> untyped").params | parse_method_type("() -> untyped").params + + # rest:opt + assert_equal parse_method_type("(?String | Integer, *String) -> untyped").params, + parse_method_type("(*String) -> untyped").params | parse_method_type("(?Integer) -> untyped").params + + # rest:rest + assert_equal parse_method_type("(*String | Integer) -> untyped").params, + parse_method_type("(*String) -> untyped").params | parse_method_type("(*Integer) -> untyped").params + + ## Keywords + + # req:req + assert_equal parse_method_type("(foo: String | Integer) -> untyped").params, + parse_method_type("(foo: String) -> untyped").params | parse_method_type("(foo: Integer) -> untyped").params + + # req:opt + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, + parse_method_type("(foo: String) -> untyped").params | parse_method_type("(?foo: Integer) -> untyped").params + + # req:none + assert_equal parse_method_type("(?foo: String) -> untyped").params, + parse_method_type("(foo: String) -> untyped").params | parse_method_type("() -> untyped").params + + # req:rest + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, + parse_method_type("(foo: String) -> untyped").params | parse_method_type("(**Integer) -> untyped").params + + # opt:req + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, + parse_method_type("(?foo: String) -> untyped").params | parse_method_type("(foo: Integer) -> untyped").params + + # opt:opt + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, + parse_method_type("(?foo: String) -> untyped").params | parse_method_type("(?foo: Integer) -> untyped").params + + # opt:none + assert_equal parse_method_type("(?foo: String) -> untyped").params, + parse_method_type("(?foo: String) -> untyped").params | parse_method_type("() -> untyped").params + + # opt:rest + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, + parse_method_type("(?foo: String) -> untyped").params | parse_method_type("(**Integer) -> untyped").params + + # none:req + assert_equal parse_method_type("(?foo: String) -> untyped").params, + parse_method_type("() -> untyped").params | parse_method_type("(foo: String) -> untyped").params + + # none:opt + assert_equal parse_method_type("(?foo: Integer) -> untyped").params, + parse_method_type("() -> untyped").params | parse_method_type("(?foo: Integer) -> untyped").params + + # none:rest + assert_equal parse_method_type("(**Integer) -> untyped").params, + parse_method_type("() -> untyped").params | parse_method_type("(**Integer) -> untyped").params + + # rest:req + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, + parse_method_type("(**String) -> untyped").params | parse_method_type("(foo: Integer) -> untyped").params + + # rest:opt + assert_equal parse_method_type("(?foo: String | Integer) -> untyped").params, + parse_method_type("(**String) -> untyped").params | parse_method_type("(?foo: Integer) -> untyped").params + + # rest:none + assert_equal parse_method_type("(**String) -> untyped").params, + parse_method_type("(**String) -> untyped").params | parse_method_type("() -> untyped").params + + # rest:rest + assert_equal parse_method_type("(**String | Integer) -> untyped").params, + parse_method_type("(**String) -> untyped").params | parse_method_type("(**Integer) -> untyped").params + end + end + + def test_method_type_union + with_factory do + assert_equal parse_method_type("(String & Integer) -> (String | Symbol)"), + parse_method_type("(String) -> String") | parse_method_type("(Integer) -> Symbol") + + assert_nil parse_method_type("() -> String") | parse_method_type("(Integer) -> untyped") + assert_equal parse_method_type("() -> bool"), + parse_method_type("() -> bot") | parse_method_type("() -> bool") + assert_equal parse_method_type("() -> untyped"), + parse_method_type("() -> untyped") | parse_method_type("() -> String") + + assert_equal parse_method_type("() { (String | Integer) -> (Integer & Float) } -> (String | Symbol)"), + parse_method_type("() { (String) -> Integer } -> String") | parse_method_type("() { (Integer) -> Float } -> Symbol") + + assert_equal parse_method_type("() { (String | Integer, ?String) -> void } -> void"), + parse_method_type("() { (String, String) -> void } -> void") | parse_method_type("() { (Integer) -> void } -> void") + + assert_equal parse_method_type("() { (String | Integer) -> (Integer & Float) } -> (String | Symbol)"), + parse_method_type("() ?{ (String) -> Integer } -> String") | parse_method_type("() { (Integer) -> Float } -> Symbol") + + assert_equal parse_method_type("() ?{ (String) -> Integer } -> (String | Symbol)"), + parse_method_type("() ?{ (String) -> Integer } -> String") | parse_method_type("() -> Symbol") + end + end + + def test_method_type_union_poly + skip + assert_equal parse_method_type("[A, A_1, B] (Array[A] & Hash[A_1, B]) -> (String | Symbol)"), + parse_method_type("[A] (Array[A]) -> String") | parse_method_type("[A, B] (Hash[A, B]) -> Symbol") + end + + def test_method_type_intersection + with_factory do + assert_equal parse_method_type("(String | Integer) -> (String & Symbol)"), + parse_method_type("(String) -> String") & parse_method_type("(Integer) -> Symbol") + + assert_equal parse_method_type("(?Integer) -> untyped"), + parse_method_type("() -> String") & parse_method_type("(Integer) -> untyped") + + assert_equal parse_method_type("() -> bot"), + parse_method_type("() -> bot") & parse_method_type("() -> bool") + assert_equal parse_method_type("() -> untyped"), + parse_method_type("() -> untyped") & parse_method_type("() -> String") + + assert_equal parse_method_type("() { (String & Integer) -> (Integer | Float) } -> (String & Symbol)"), + parse_method_type("() { (String) -> Integer } -> String") & parse_method_type("() { (Integer) -> Float } -> Symbol") + + assert_nil parse_method_type("() { (String, String) -> void } -> void") & parse_method_type("() { (Integer) -> void } -> void") + + assert_equal parse_method_type("() ?{ (String & Integer) -> (Integer | Float) } -> (String & Symbol)"), + parse_method_type("() ?{ (String) -> Integer } -> String") & parse_method_type("() { (Integer) -> Float } -> Symbol") + - assert_equal parse_method_type("(name: String & Symbol, email: String & Array, age: Integer & Object, **Array & Object) -> void").params, - (parse_method_type("(name: String, email: String, **Object) -> void").params & parse_method_type("(name: Symbol, age: Integer, **Array) -> void").params) end end diff --git a/test/test_helper.rb b/test/test_helper.rb index 57648e0ff..2d84cf490 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -22,8 +22,7 @@ def self.new_module(location: nil, name:, args: []) def self.new_class(location: nil, name:, constructor:, args: []) name = Steep::Names::Module.parse(name.to_s) unless name.is_a?(Steep::Names::Module) - Steep::AST::Types::Name::Singleton.new(location: location, - name: name) + Steep::AST::Types::Name::Singleton.new(location: location, name: name) end def self.new_instance(location: nil, name:, args: []) diff --git a/test/type_construction_test.rb b/test/type_construction_test.rb index e5fce7482..0d9dfd82e 100644 --- a/test/type_construction_test.rb +++ b/test/type_construction_test.rb @@ -648,6 +648,36 @@ def test_block_param_type end end + def test_block_extra_missing_params + with_checker(<<-RBS) do |checker| +class M1 + def foo: () { (Integer, String, bool) -> void } -> void +end + +class M2 + def foo: () { (String, Integer) -> void } -> void +end + RBS + + source = parse_ruby(<<-EOF) +x = [M1.new, M2.new][0] + +x.foo do |a| + nil +end + +x.foo do |a, b, c| + nil +end + EOF + + with_standard_construction(checker, source) do |construction, typing| + construction.synthesize(source.node) + assert_no_error typing + end + end + end + def test_block_value_type with_checker do |checker| source = parse_ruby(<<-EOF) @@ -1720,14 +1750,16 @@ def test_intersection_send source = parse_ruby(<<-RUBY) # @type var x: Integer & String x = (_ = nil) -y = x + "" +y = x.to_str +z = x.to_int RUBY with_standard_construction(checker, source) do |construction, typing| pair = construction.synthesize(source.node) - assert_empty typing.errors + assert_no_error typing assert_equal parse_type("::String"), pair.context.lvar_env[:y] + assert_equal parse_type("::Integer"), pair.context.lvar_env[:z] end end end diff --git a/test/type_factory_test.rb b/test/type_factory_test.rb index 4a99f5f51..8549c004e 100644 --- a/test/type_factory_test.rb +++ b/test/type_factory_test.rb @@ -10,13 +10,21 @@ def parse_method_type(str) end def assert_overload_with(c, *types) - assert_operator c, :overload? - assert_equal Set.new(types.map(&:to_s)), Set.new(c.types.map(&:to_s)), "Expected: { #{types.join(" | ")} }, Actual: { #{c.types.join(" | ")} }" + types = types.map do |s| + factory.method_type(parse_method_type(s), self_type: Steep::AST::Types::Self.new, subst2: nil) + end + + assert_equal Set.new(types), Set.new(c.method_types), "Expected: { #{types.join(" | ")} }, Actual: #{c.to_s}" end def assert_overload_including(c, *types) - assert_operator c, :overload? - assert_operator Set.new(types.map(&:to_s)), :subset?, Set.new(c.types.map(&:to_s)), "Expected: { #{types.join(" | ")} } is a subset of { #{c.types.join(" | ")} }" + types = types.map do |s| + factory.method_type(parse_method_type(s), self_type: Steep::AST::Types::Self.new, subst2: nil) + end + + assert_operator Set.new(types), + :subset?, Set.new(c.method_types), + "Expected: { #{types.join(" | ")} } is a subset of { #{c.method_types.join(" | ")} }" end Types = Steep::AST::Types @@ -361,8 +369,7 @@ def test_literal_type assert_overload_including interface.methods[:+], "(::Integer) -> ::Integer", "(::Float) -> ::Float" interface.methods[:yield_self].tap do |method| - assert_operator method, :overload? - x = method.types[0].type_params[0] + x = method.method_types[0].type_params[0] assert_includes method.to_s, " [#{x}] () { (3) -> #{x} } -> #{x} " end end @@ -396,13 +403,14 @@ def test_record_type factory.interface(type, private: false).yield_self do |interface| assert_instance_of Steep::Interface::Interface, interface - assert_operator interface.methods[:[]].to_s, - :start_with?, - "{ (1) -> ::Integer | (:foo) -> ::String | (\"baz\") -> bool" - assert_operator interface.methods[:[]=].to_s, - :start_with?, - "{ (1, ::Integer) -> ::Integer | (:foo, ::String) -> ::String | (\"baz\", bool) -> bool" - + assert_overload_including interface.methods[:[]], + "(1) -> ::Integer", + "(:foo) -> ::String", + "(\"baz\") -> bool" + assert_overload_including interface.methods[:[]=], + "(1, ::Integer) -> ::Integer", + "(:foo, ::String) -> ::String", + "(\"baz\", bool) -> bool" end end end @@ -414,21 +422,14 @@ def test_union_type factory.interface(type, private: false).yield_self do |interface| assert_instance_of Steep::Interface::Interface, interface - interface.methods[:to_s].yield_self do |combination| - assert_equal :union, combination.operator - assert_any! combination.types do |c| - assert_overload_with c, "() -> ::String" - end + interface.methods[:to_s].yield_self do |entry| + assert_overload_with entry, "() -> ::String" end - interface.methods[:+].yield_self do |combination| - assert_equal :union, combination.operator - assert_any! combination.types do |c| - assert_overload_including c, "(::Integer) -> ::Integer", "(::Float) -> ::Float" - end - assert_any! combination.types do |c| - assert_overload_with c, "(::string) -> ::String" - end + interface.methods[:+].yield_self do |entry| + assert_overload_including entry, + "((::Integer & ::string)) -> (::Integer | ::String)", + "((::Float & ::string)) -> (::Float | ::String)" end assert_nil interface.methods[:floor] @@ -438,36 +439,58 @@ def test_union_type end end - def test_intersection_type - with_factory() do |factory| - factory.type(parse_type("::Integer & ::String")).yield_self do |type| - factory.interface(type, private: false).yield_self do |interface| + def test_union_type_methods + with_factory("foo.rbs" => <<-RBS) do |factory| +interface _I1 + def f: () -> void + def g: () -> void + + def foo: (String) -> void + | (Symbol) -> String +end + +interface _I2 + def g: () -> void + def h: () -> void + + def foo: () -> void + | (Integer) -> Float +end + RBS + factory.type(parse_type("::_I1 | ::_I2")).tap do |type| + factory.interface(type, private: false).tap do |interface| assert_instance_of Steep::Interface::Interface, interface - interface.methods[:to_s].yield_self do |combination| - assert_equal :intersection, combination.operator - assert_any! combination.types do |c| - assert_overload_including c, "() -> ::String" - end - end + assert_equal Set[:g, :foo], Set.new(interface.methods.keys) + assert_overload_with interface.methods[:g], "() -> void" + assert_overload_with interface.methods[:foo], + "((::String & ::Integer)) -> (::Float | void)", + "((::Symbol & ::Integer)) -> (::Float | ::String)" + end + end + end + end - interface.methods[:+].yield_self do |combination| - assert_equal :intersection, combination.operator - assert_any! combination.types do |c| - assert_overload_including c, "(::Integer) -> ::Integer", "(::Float) -> ::Float" - end - assert_any! combination.types do |c| - assert_overload_including c, "(::string) -> ::String" - end - end + def test_intersection_type + with_factory("foo.rbs" => <<-RBS) do |factory| +interface _I1 + def f: () -> void + def g: () -> void +end - interface.methods[:floor].yield_self do |combination| - assert_overload_including combination, "(::int) -> (::Float | ::Integer)", "() -> ::Integer" - end +interface _I2 + def g: (String) -> Integer + def h: () -> void +end + RBS + factory.type(parse_type("::_I1 & ::_I2")).tap do |type| + factory.interface(type, private: false).tap do |interface| + assert_instance_of Steep::Interface::Interface, interface - interface.methods[:end_with?].yield_self do |combination| - assert_overload_including combination, "(*::string) -> bool" - end + assert_equal Set[:f, :g, :h], Set.new(interface.methods.keys) + assert_overload_with interface.methods[:f], "() -> void" + assert_overload_with interface.methods[:g], "(::String) -> ::Integer" + assert_overload_with interface.methods[:h], "() -> void" end end end @@ -479,14 +502,12 @@ def test_proc_type factory.interface(type, private: false).yield_self do |interface| assert_instance_of Steep::Interface::Interface, interface - interface.methods[:call].yield_self do |combination| - assert_equal :overload, combination.operator - assert_equal [parse_method_type("(String) -> Integer").to_s], combination.types.map(&:to_s) + interface.methods[:call].yield_self do |entry| + assert_overload_with entry, "(String) -> Integer" end - interface.methods[:[]].yield_self do |combination| - assert_equal :overload, combination.operator - assert_equal [parse_method_type("(String) -> Integer").to_s], combination.types.map(&:to_s) + interface.methods[:[]].yield_self do |entry| + assert_overload_with entry, "(String) -> Integer" end end end