Skip to content

Commit

Permalink
Merge pull request #643 from soutaro/unfold-alias
Browse files Browse the repository at this point in the history
Unfold type alias before splatting on block parameters
  • Loading branch information
soutaro authored Sep 5, 2022
2 parents 21150f4 + 34e0eab commit 32d2ae4
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 86 deletions.
4 changes: 2 additions & 2 deletions lib/steep/type_construction.rb
Original file line number Diff line number Diff line change
Expand Up @@ -3420,7 +3420,7 @@ def try_method_type(node, receiver_type:, method_name:, method_type:, arguments:

if method_type.block
# Method accepts block
pairs = block_params_&.zip(method_type.block.type.params, nil)
pairs = block_params_&.zip(method_type.block.type.params, nil, factory: checker.factory)

if block_params_ && pairs
# Block parameters are compatible with the block type
Expand Down Expand Up @@ -3797,7 +3797,7 @@ def set_up_block_mlhs_params_env(node, type, hash, &block)
end

def for_block(body_node, block_params:, block_param_hint:, block_type_hint:, block_block_hint:, block_annotations:, node_type_hint:, block_self_hint:)
block_param_pairs = block_param_hint && block_params.zip(block_param_hint, block_block_hint)
block_param_pairs = block_param_hint && block_params.zip(block_param_hint, block_block_hint, factory: checker.factory)

# @type var param_types_hash: Hash[Symbol, AST::Types::t]
param_types_hash = {}
Expand Down
130 changes: 66 additions & 64 deletions lib/steep/type_inference/block_params.rb
Original file line number Diff line number Diff line change
Expand Up @@ -209,99 +209,101 @@ def params_type0(hint:)
)
end

def zip(params_type, block)
def zip(params_type, block, factory:)
if trailing_params.any?
Steep.logger.error "Block definition with trailing required parameters are not supported yet"
end

[].tap do |zip|
if expandable_params?(params_type) && expandable?
type = params_type.required[0]

case
when AST::Builtin::Array.instance_type?(type)
type.is_a?(AST::Types::Name::Instance) or raise

type_arg = type.args[0]
params.each do |param|
unless param == rest_param
zip << [param, AST::Types::Union.build(types: [type_arg, AST::Builtin.nil_type])]
else
zip << [param, AST::Builtin::Array.instance_type(type_arg)]
end
end
when type.is_a?(AST::Types::Tuple)
types = type.types.dup
(leading_params + optional_params).each do |param|
ty = types.shift
if ty
zip << [param, ty]
else
zip << [param, AST::Types::Nil.new]
end
end
# @type var zip: Array[[Param | MultipleParam, AST::Types::t]]
zip = []

if expandable? && (type = expandable_params?(params_type, factory))
case
when AST::Builtin::Array.instance_type?(type)
type.is_a?(AST::Types::Name::Instance) or raise

if rest_param
if types.any?
union = AST::Types::Union.build(types: types)
zip << [rest_param, AST::Builtin::Array.instance_type(union)]
else
zip << [rest_param, AST::Types::Nil.new]
end
type_arg = type.args[0]
params.each do |param|
unless param == rest_param
zip << [param, AST::Types::Union.build(types: [type_arg, AST::Builtin.nil_type])]
else
zip << [param, AST::Builtin::Array.instance_type(type_arg)]
end
end
else
types = params_type.flat_unnamed_params

when type.is_a?(AST::Types::Tuple)
types = type.types.dup
(leading_params + optional_params).each do |param|
type = types.shift&.last || params_type.rest

if type
zip << [param, type]
ty = types.shift
if ty
zip << [param, ty]
else
zip << [param, AST::Builtin.nil_type]
zip << [param, AST::Types::Nil.new]
end
end

if rest_param
if types.empty?
array = AST::Builtin::Array.instance_type(params_type.rest || AST::Builtin.any_type)
zip << [rest_param, array]
if types.any?
union = AST::Types::Union.build(types: types)
zip << [rest_param, AST::Builtin::Array.instance_type(union)]
else
union = AST::Types::Union.build(types: types.map(&:last) + [params_type.rest])
array = AST::Builtin::Array.instance_type(union)
zip << [rest_param, array]
zip << [rest_param, AST::Types::Nil.new]
end
end
end
else
types = params_type.flat_unnamed_params

if block_param
if block
proc_type = AST::Types::Proc.new(type: block.type, block: nil, self_type: block.self_type)
if block.optional?
proc_type = AST::Types::Union.build(types: [proc_type, AST::Builtin.nil_type])
end
(leading_params + optional_params).each do |param|
type = types.shift&.last || params_type.rest

if type
zip << [param, type]
else
zip << [param, AST::Builtin.nil_type]
end
end

zip << [block_param, proc_type]
if rest_param
if types.empty?
array = AST::Builtin::Array.instance_type(params_type.rest || AST::Builtin.any_type)
zip << [rest_param, array]
else
zip << [block_param, AST::Builtin.nil_type]
union = AST::Types::Union.build(types: types.map(&:last) + [params_type.rest])
array = AST::Builtin::Array.instance_type(union)
zip << [rest_param, array]
end
end
end

if block_param
if block
proc_type = AST::Types::Proc.new(type: block.type, block: nil, self_type: block.self_type)
if block.optional?
proc_type = AST::Types::Union.build(types: [proc_type, AST::Builtin.nil_type])
end

zip << [block_param, proc_type]
else
zip << [block_param, AST::Builtin.nil_type]
end
end

zip
end

def expandable_params?(params_type)
def expandable_params?(params_type, factory)
if params_type.flat_unnamed_params.size == 1
case (type = params_type.required.first)
type = params_type.required.first or raise
type = factory.deep_expand_alias(type) || type

case type
when AST::Types::Tuple
true
type
when AST::Types::Name::Base
AST::Builtin::Array.instance_type?(type)
else
false
if AST::Builtin::Array.instance_type?(type)
type
end
end
else
false
end
end

Expand Down
15 changes: 10 additions & 5 deletions sig/steep/type_inference/block_params.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ module Steep
#
class MultipleParam
attr_reader node: Parser::AST::Node

attr_reader params: Array[Param | MultipleParam]

def initialize: (node: Parser::AST::Node, params: Array[Param | MultipleParam]) -> void
Expand Down Expand Up @@ -120,16 +120,21 @@ module Steep
def params_type0: (hint: nil) -> Interface::Function::Params
| (hint: Interface::Function::Params?) -> Interface::Function::Params?

def zip: (Interface::Function::Params params_type, Interface::Block? block) -> Array[[Param | MultipleParam, AST::Types::t]]
def zip: (
Interface::Function::Params params_type,
Interface::Block? block,
factory: AST::Types::Factory
) -> Array[[Param | MultipleParam, AST::Types::t]]

# Returns true if given possible block yields are subject to auto expand/splat
#
# ```rbs
# { (Array[String]) -> void } # true
# { ([Integer, String]) -> void } # true
# { (Array[String]) -> void } # Array[String]
# { ([Integer, String]) -> void } # [Integer, String]
# { (String) -> void } # nil
# ```
#
def expandable_params?: (Interface::Function::Params params_type) -> bool
def expandable_params?: (Interface::Function::Params params_type, AST::Types::Factory) -> AST::Types::t?

# Returns true if the block is defined to expand/splat automatically
#
Expand Down
7 changes: 7 additions & 0 deletions smoke/regression/block_param_split.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# @type var a: Array[BlockParamSplit::pair[Integer, String]]
a = [[1, "a"]]

a.each do |x, y|
x + 1
y + "2"
end
3 changes: 3 additions & 0 deletions smoke/regression/block_param_split.rbs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
module BlockParamSplit
type pair[T, S] = [T, S]
end
34 changes: 19 additions & 15 deletions test/block_params_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_zip1
)

block_params("proc {|a, b=1, *c| }") do |params, args|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)
assert_equal [params.params[0], parse_type("Integer")], zip[0]
assert_equal [params.params[1], parse_type("nil")], zip[1]
assert_equal [params.params[2], parse_type("::Array[untyped]")], zip[2]
Expand All @@ -122,7 +122,7 @@ def test_zip2
)

block_params("proc {|a, b, *c| }") do |params, args|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)

assert_equal [params.params[0], parse_type("::Integer")], zip[0]
assert_equal [params.params[1], parse_type("::String")], zip[1]
Expand All @@ -143,7 +143,7 @@ def test_zip3
)

block_params("proc {|x, *y| }") do |params|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)

assert_equal [params.params[0], parse_type("::Integer")], zip[0]
assert_equal [params.params[1], parse_type("::Array[::Object | ::String]")], zip[1]
Expand All @@ -163,7 +163,7 @@ def test_zip4
)

block_params("proc {|x| }") do |params|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)

assert_equal 1, zip.size
assert_equal [params.params[0], parse_type("::Integer")], zip[0]
Expand All @@ -183,7 +183,7 @@ def test_zip_missing_required_params
)

block_params("proc { }") do |params|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)

assert_empty zip
end
Expand All @@ -202,7 +202,7 @@ def test_zip_with_extra_params
)

block_params("proc {|x, y| }") do |params|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)

assert_equal 2, zip.size
assert_equal [params.params[0], parse_type("::Object")], zip[0]
Expand All @@ -223,15 +223,15 @@ def test_zip_expand_array
)

block_params("proc {|x,y,*z| }") do |params|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)

assert_equal [params.params[0], parse_type("::Integer | nil")], zip[0]
assert_equal [params.params[1], parse_type("::Integer | nil")], zip[1]
assert_equal [params.params[2], parse_type("::Array[::Integer]")], zip[2]
end

block_params("proc {|x,| }") do |params|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)

assert_equal [params.params[0], parse_type("::Integer | nil")], zip[0]
end
Expand All @@ -250,21 +250,21 @@ def test_zip_expand_tuple
)

block_params("proc {|x,y,*z| }") do |params|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)

assert_equal [params.params[0], parse_type("::Symbol")], zip[0]
assert_equal [params.params[1], parse_type("::Integer")], zip[1]
assert_equal [params.params[2], parse_type("nil")], zip[2]
end

block_params("proc {|x,| }") do |params|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)

assert_equal [params.params[0], parse_type("::Symbol")], zip[0]
end

block_params("proc {|x, *y| }") do |params|
zip = params.zip(type, nil)
zip = params.zip(type, nil, factory: factory)

assert_equal [params.params[0], parse_type("::Symbol")], zip[0]
assert_equal [params.params[1], parse_type("::Array[::Integer]")], zip[1]
Expand Down Expand Up @@ -402,7 +402,8 @@ def test_zip_block
type: parse_type("^() -> void").type,
optional: false,
self_type: nil
)
),
factory: factory
).tap do |zip|
assert_equal 1, zip.size
assert_equal [params.block_param, parse_type("^() -> void")], zip[0]
Expand All @@ -414,15 +415,17 @@ def test_zip_block
type: parse_type("^() -> void").type,
optional: true,
self_type: nil
)
),
factory: factory
).tap do |zip|
assert_equal 1, zip.size
assert_equal [params.block_param, parse_type("^() -> void | nil")], zip[0]
end

params.zip(
Params.empty,
nil
nil,
factory: factory
).tap do |zip|
assert_equal 1, zip.size
assert_equal [params.block_param, parse_type("nil")], zip[0]
Expand Down Expand Up @@ -476,7 +479,8 @@ def test_multiple_param_zip
optional_keywords: {},
rest_keywords: nil
),
nil
nil,
factory: factory
).tap do |zip|
assert_equal 1, zip.size
zip[0].tap do |pair|
Expand Down
22 changes: 22 additions & 0 deletions test/type_construction_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -10058,4 +10058,26 @@ def bar(&block)
end
end
end

def test_block_splat_alias
with_checker(<<-RBS) do |checker|
type foo = [Integer, String]
RBS
source = parse_ruby(<<-RUBY)
# @type var array: Array[foo]
array = []
array.each do |x, y|
x + 1
y + ""
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
type, _, context = construction.synthesize(source.node)

assert_no_error typing
end
end
end
end

0 comments on commit 32d2ae4

Please sign in to comment.