diff --git a/lib/steep/type_construction.rb b/lib/steep/type_construction.rb index 794bebaec..44a64bd13 100644 --- a/lib/steep/type_construction.rb +++ b/lib/steep/type_construction.rb @@ -1844,8 +1844,21 @@ def synthesize(node, hint: nil, condition: false) if cond branch_results = [] - cond_type, constr = constr.synthesize(cond).to_ary - cond_value_node, cond_vars = interpreter.decompose_value(cond) + cond_type, constr = constr.synthesize(cond) + _, cond_vars = interpreter.decompose_value(cond) + + var_name = :"_a#{SecureRandom.base64(4)}" + var_cond, value_node = extract_outermost_call(cond, var_name) + if value_node + unless constr.context.type_env[value_node] + constr = constr.update_type_env do |env| + env.assign_local_variable(var_name, cond_type, nil) + end + cond = var_cond + else + value_node = nil + end + end when_constr = constr whens.each do |clause| @@ -1893,7 +1906,11 @@ def synthesize(node, hint: nil, condition: false) types = branch_results.map(&:type) constrs = branch_results.map(&:constr) - cond_type = when_constr.context.type_env[cond_value_node] + cond_type = when_constr.context.type_env[var_name] + cond_type ||= when_constr.context.type_env[cond_vars.first || raise] unless cond_vars.empty? + cond_type ||= when_constr.context.type_env[value_node] if value_node + cond_type ||= typing.type_of(node: node.children[0]) + if cond_type.is_a?(AST::Types::Bot) # Exhaustive if els @@ -4235,5 +4252,35 @@ def save_typing typing.save! with_new_typing(typing.parent) end + + def extract_outermost_call(node, var_name) + case node.type + when :lvasgn + name, rhs = node.children + rhs, value_node = extract_outermost_call(rhs, var_name) + if value_node + [node.updated(nil, [name, rhs]), value_node] + else + [node, value_node] + end + when :begin + *children, last = node.children + last, value_node = extract_outermost_call(last, var_name) + if value_node + [node.updated(nil, children.push(last)), value_node] + else + [node, value_node] + end + when :lvar + [node, nil] + else + if value_node?(node) + [node, nil] + else + var_node = node.updated(:lvar, [var_name]) + [var_node, node] + end + end + end end end diff --git a/sig/steep/node_helper.rbs b/sig/steep/node_helper.rbs index b5491946a..198f71676 100644 --- a/sig/steep/node_helper.rbs +++ b/sig/steep/node_helper.rbs @@ -6,6 +6,8 @@ module Steep def each_descendant_node: (Parser::AST::Node) -> Enumerator[Parser::AST::Node, void] | (Parser::AST::Node) { (Parser::AST::Node) -> void } -> void + # Returns true if given node is a syntactic-value node + # def value_node?: (Parser::AST::Node) -> bool end end diff --git a/sig/steep/type_construction.rbs b/sig/steep/type_construction.rbs index aaa8dcecc..ca08c14db 100644 --- a/sig/steep/type_construction.rbs +++ b/sig/steep/type_construction.rbs @@ -283,5 +283,18 @@ module Steep # * All of the arguments are _pure_ # def pure_send?: (TypeInference::MethodCall::Typed call, Parser::AST::Node receiver, Array[Parser::AST::Node] arguments) -> bool + + # Transform given `node` to a node that has a local variable instead of the outer most call/non-value node + # + # Returns a pair of transformed node and the outer most call/non-value node if present. + # + # ```rb + # x = y = foo() # Call `#transform_value_node` with the node and var_name `:__foo__` + # # => Returns [x = y = __foo__, foo()] + # ``` + # + # This is typically used for transforming assginment node for case condition. + # + def extract_outermost_call: (Parser::AST::Node node, Symbol) -> [Parser::AST::Node, Parser::AST::Node?] end end diff --git a/test/type_construction_test.rb b/test/type_construction_test.rb index fbb56a230..64ddf4c8a 100644 --- a/test/type_construction_test.rb +++ b/test/type_construction_test.rb @@ -9758,12 +9758,36 @@ def test_lvar_special end -> (_) { 123 } +RUBY + with_standard_construction(checker, source) do |construction, typing| + type, _, context = construction.synthesize(source.node) + + assert_no_error typing + end + end + end + + def test_type_narrowing_assignment + with_checker(<<-RBS) do |checker| +class NarrowingAssignmentTest + def foo: () -> (Integer | String) +end + RBS + source = parse_ruby(<<-RUBY) +case value = NarrowingAssignmentTest.new.foo() +when Integer + "hello" +when String + value +end RUBY with_standard_construction(checker, source) do |construction, typing| type, _, context = construction.synthesize(source.node) assert_no_error typing + + assert_equal parse_type("::String"), type end end end