Skip to content

Commit

Permalink
Choose the best type on Any result
Browse files Browse the repository at this point in the history
  • Loading branch information
soutaro committed Mar 8, 2022
1 parent 8142fa2 commit e0383cd
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 2 deletions.
22 changes: 20 additions & 2 deletions lib/steep/subtyping/check.rb
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ def check_type0(relation)

when relation.super_type.is_a?(AST::Types::Union)
Any(relation) do |result|
relation.super_type.types.each do |super_type|
relation.super_type.types.sort_by {|ty| (path = hole_path(ty)) ? -path.size : 1 }.each do |super_type|
rel = Relation.new(sub_type: relation.sub_type, super_type: super_type)
result.add(rel) do
check_type(rel)
Expand All @@ -357,7 +357,7 @@ def check_type0(relation)

when relation.sub_type.is_a?(AST::Types::Intersection)
Any(relation) do |result|
relation.sub_type.types.each do |sub_type|
relation.sub_type.types.sort_by {|ty| (path = hole_path(ty)) ? -path.size : 1 }.each do |sub_type|
rel = Relation.new(sub_type: sub_type, super_type: relation.super_type)
result.add(rel) do
check_type(rel)
Expand Down Expand Up @@ -974,6 +974,24 @@ def match_params(name, relation)
def expand_alias(type, &block)
factory.expand_alias(type, &block)
end

# Returns the shortest type paths for one of the _unknown_ type variables.
# Returns nil if there is no path.
def hole_path(type, path = [])
case type
when AST::Types::Var
if constraints.unknown?(type.name)
[type]
else
nil
end
else
paths = type.each_child.map do |ty|
hole_path(ty, path)&.unshift(ty)
end
paths.compact.min_by(&:size)
end
end
end
end
end
22 changes: 22 additions & 0 deletions test/type_construction_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8419,4 +8419,26 @@ def foo(&block)
end
end
end

def test_flat_map
with_checker(<<-RBS) do |checker|
class FlatMap
def flat_map: [A] () { (String) -> (A | Array[A]) } -> Array[A]
end
RBS

source = parse_ruby(<<-'RUBY')
# @type var a: FlatMap
a = _ = nil
a.flat_map {|s| [s] }
RUBY

with_standard_construction(checker, source) do |construction, typing|
type, _, _ = construction.synthesize(source.node)

assert_no_error typing
assert_equal parse_type("::Array[::String]"), type
end
end
end
end

0 comments on commit e0383cd

Please sign in to comment.