Skip to content

Commit

Permalink
Merge pull request #622 from soutaro/narrowing-assignment
Browse files Browse the repository at this point in the history
Fix type narrowing on assignment
  • Loading branch information
soutaro committed Jul 31, 2022
1 parent d6ea36e commit b0beb62
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 3 deletions.
53 changes: 50 additions & 3 deletions lib/steep/type_construction.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1844,8 +1844,21 @@ def synthesize(node, hint: nil, condition: false)
if cond
branch_results = []

cond_type, constr = constr.synthesize(cond).to_ary
cond_value_node, cond_vars = interpreter.decompose_value(cond)
cond_type, constr = constr.synthesize(cond)
_, cond_vars = interpreter.decompose_value(cond)

var_name = :"_a#{SecureRandom.base64(4)}"
var_cond, value_node = extract_outermost_call(cond, var_name)
if value_node
unless constr.context.type_env[value_node]
constr = constr.update_type_env do |env|
env.assign_local_variable(var_name, cond_type, nil)
end
cond = var_cond
else
value_node = nil
end
end

when_constr = constr
whens.each do |clause|
Expand Down Expand Up @@ -1893,7 +1906,11 @@ def synthesize(node, hint: nil, condition: false)
types = branch_results.map(&:type)
constrs = branch_results.map(&:constr)

cond_type = when_constr.context.type_env[cond_value_node]
cond_type = when_constr.context.type_env[var_name]
cond_type ||= when_constr.context.type_env[cond_vars.first || raise] unless cond_vars.empty?
cond_type ||= when_constr.context.type_env[value_node] if value_node
cond_type ||= typing.type_of(node: node.children[0])

if cond_type.is_a?(AST::Types::Bot)
# Exhaustive
if els
Expand Down Expand Up @@ -4235,5 +4252,35 @@ def save_typing
typing.save!
with_new_typing(typing.parent)
end

def extract_outermost_call(node, var_name)
case node.type
when :lvasgn
name, rhs = node.children
rhs, value_node = extract_outermost_call(rhs, var_name)
if value_node
[node.updated(nil, [name, rhs]), value_node]
else
[node, value_node]
end
when :begin
*children, last = node.children
last, value_node = extract_outermost_call(last, var_name)
if value_node
[node.updated(nil, children.push(last)), value_node]
else
[node, value_node]
end
when :lvar
[node, nil]
else
if value_node?(node)
[node, nil]
else
var_node = node.updated(:lvar, [var_name])
[var_node, node]
end
end
end
end
end
2 changes: 2 additions & 0 deletions sig/steep/node_helper.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ module Steep
def each_descendant_node: (Parser::AST::Node) -> Enumerator[Parser::AST::Node, void]
| (Parser::AST::Node) { (Parser::AST::Node) -> void } -> void

# Returns true if given node is a syntactic-value node
#
def value_node?: (Parser::AST::Node) -> bool
end
end
13 changes: 13 additions & 0 deletions sig/steep/type_construction.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -283,5 +283,18 @@ module Steep
# * All of the arguments are _pure_
#
def pure_send?: (TypeInference::MethodCall::Typed call, Parser::AST::Node receiver, Array[Parser::AST::Node] arguments) -> bool

# Transform given `node` to a node that has a local variable instead of the outer most call/non-value node
#
# Returns a pair of transformed node and the outer most call/non-value node if present.
#
# ```rb
# x = y = foo() # Call `#transform_value_node` with the node and var_name `:__foo__`
# # => Returns [x = y = __foo__, foo()]
# ```
#
# This is typically used for transforming assginment node for case condition.
#
def extract_outermost_call: (Parser::AST::Node node, Symbol) -> [Parser::AST::Node, Parser::AST::Node?]
end
end
24 changes: 24 additions & 0 deletions test/type_construction_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9758,12 +9758,36 @@ def test_lvar_special
end
-> (_) { 123 }
RUBY
with_standard_construction(checker, source) do |construction, typing|
type, _, context = construction.synthesize(source.node)

assert_no_error typing
end
end
end

def test_type_narrowing_assignment
with_checker(<<-RBS) do |checker|
class NarrowingAssignmentTest
def foo: () -> (Integer | String)
end
RBS
source = parse_ruby(<<-RUBY)
case value = NarrowingAssignmentTest.new.foo()
when Integer
"hello"
when String
value
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
type, _, context = construction.synthesize(source.node)

assert_no_error typing

assert_equal parse_type("::String"), type
end
end
end
Expand Down

0 comments on commit b0beb62

Please sign in to comment.