From 7339e82d29b189ab48c697935131c8ce5d4e0d6a Mon Sep 17 00:00:00 2001 From: Soutaro Matsumoto Date: Mon, 18 Jul 2022 23:14:44 +0900 Subject: [PATCH] Add MultipleAssignment --- lib/steep.rb | 1 + lib/steep/diagnostic/ruby.rb | 16 ++ .../type_inference/multiple_assignment.rb | 189 ++++++++++++++++++ sig/steep/diagnostic/ruby.rbs | 17 ++ .../type_inference/multiple_assignment.rbs | 76 +++++++ test/multiple_assignment_test.rb | 147 ++++++++++++++ 6 files changed, 446 insertions(+) create mode 100644 lib/steep/type_inference/multiple_assignment.rb create mode 100644 sig/steep/type_inference/multiple_assignment.rbs create mode 100644 test/multiple_assignment_test.rb diff --git a/lib/steep.rb b/lib/steep.rb index c30a41cda..5d8da8585 100644 --- a/lib/steep.rb +++ b/lib/steep.rb @@ -88,6 +88,7 @@ require "steep/type_inference/type_env" require "steep/type_inference/type_env_builder" require "steep/type_inference/logic_type_interpreter" +require "steep/type_inference/multiple_assignment" require "steep/type_inference/method_call" require "steep/ast/types" diff --git a/lib/steep/diagnostic/ruby.rb b/lib/steep/diagnostic/ruby.rb index 09a95fcbe..2a23bf64a 100644 --- a/lib/steep/diagnostic/ruby.rb +++ b/lib/steep/diagnostic/ruby.rb @@ -717,6 +717,22 @@ def header_line end end + class MultipleAssignmentConversionError < Base + attr_reader :original_type, :returned_type + + def initialize(node:, original_type:, returned_type:) + super(node: node) + + @node = node + @original_type = original_type + @returned_type = returned_type + end + + def header_line + "Cannot convert `#{original_type}` to Array or tuple (`#to_ary` returns `#{returned_type}`)" + end + end + class UnsupportedSyntax < Base attr_reader :message diff --git a/lib/steep/type_inference/multiple_assignment.rb b/lib/steep/type_inference/multiple_assignment.rb new file mode 100644 index 000000000..9d85072cc --- /dev/null +++ b/lib/steep/type_inference/multiple_assignment.rb @@ -0,0 +1,189 @@ +module Steep + module TypeInference + class MultipleAssignment + Assignments = _ = Struct.new(:rhs_type, :optional, :leading_assignments, :trailing_assignments, :splat_assignment, keyword_init: true) do + # @implements Assignments + + def each(&block) + if block + leading_assignments.each(&block) + if sp = splat_assignment + yield sp + end + trailing_assignments.each(&block) + else + enum_for :each + end + end + end + + def expand(mlhs, rhs_type, optional) + lhss = mlhs.children + + case rhs_type + when AST::Types::Tuple + expand_tuple(lhss.dup, rhs_type, rhs_type.types.dup, optional) + when AST::Types::Name::Instance + if AST::Builtin::Array.instance_type?(rhs_type) + expand_array(lhss.dup, rhs_type, optional) + end + when AST::Types::Any + expand_any(lhss, rhs_type, AST::Builtin.any_type, optional) + end + end + + def expand_tuple(lhss, rhs_type, tuples, optional) + # @type var leading_assignments: Array[node_type_pair] + leading_assignments = [] + # @type var trailing_assignments: Array[node_type_pair] + trailing_assignments = [] + # @type var splat_assignment: node_type_pair? + splat_assignment = nil + + while !lhss.empty? + first = lhss.first or raise + + case + when first.type == :splat + break + else + leading_assignments << [first, tuples.first || AST::Builtin.nil_type] + lhss.shift + tuples.shift + end + end + + while !lhss.empty? + last = lhss.last or raise + + case + when last.type == :splat + break + else + trailing_assignments << [last, tuples.last || AST::Builtin.nil_type] + lhss.pop + tuples.pop + end + end + + case lhss.size + when 0 + # nop + when 1 + splat_assignment = [lhss.first || raise, AST::Types::Tuple.new(types: tuples)] + else + raise + end + + Assignments.new( + rhs_type: rhs_type, + optional: optional, + leading_assignments: leading_assignments, + trailing_assignments: trailing_assignments, + splat_assignment: splat_assignment + ) + end + + def expand_array(lhss, rhs_type, optional) + element_type = rhs_type.args[0] or raise + + # @type var leading_assignments: Array[node_type_pair] + leading_assignments = [] + # @type var trailing_assignments: Array[node_type_pair] + trailing_assignments = [] + # @type var splat_assignment: node_type_pair? + splat_assignment = nil + + while !lhss.empty? + first = lhss.first or raise + + case + when first.type == :splat + break + else + leading_assignments << [first, AST::Builtin.optional(element_type)] + lhss.shift + end + end + + while !lhss.empty? + last = lhss.last or raise + + case + when last.type == :splat + break + else + trailing_assignments << [last, AST::Builtin.optional(element_type)] + lhss.pop + end + end + + case lhss.size + when 0 + # nop + when 1 + splat_assignment = [ + lhss.first || raise, + AST::Builtin::Array.instance_type(element_type) + ] + else + raise + end + + Assignments.new( + rhs_type: rhs_type, + optional: optional, + leading_assignments: leading_assignments, + trailing_assignments: trailing_assignments, + splat_assignment: splat_assignment + ) + end + + def expand_any(nodes, rhs_type, element_type, optional) + # @type var leading_assignments: Array[node_type_pair] + leading_assignments = [] + # @type var trailing_assignments: Array[node_type_pair] + trailing_assignments = [] + # @type var splat_assignment: node_type_pair? + splat_assignment = nil + + array = leading_assignments + + nodes.each do |node| + case node.type + when :splat + splat_assignment = [node, AST::Builtin::Array.instance_type(element_type)] + array = trailing_assignments + else + array << [node, element_type] + end + end + + Assignments.new( + rhs_type: rhs_type, + optional: optional, + leading_assignments: leading_assignments, + trailing_assignments: trailing_assignments, + splat_assignment: splat_assignment + ) + end + + def hint_for_mlhs(mlhs, env) + case mlhs.type + when :mlhs + types = mlhs.children.map do |node| + hint_for_mlhs(node, env) or return + end + AST::Types::Tuple.new(types: types) + when :lvasgn, :ivasgn, :gvasgn + name = mlhs.children[0] + env[name] || AST::Builtin.any_type + when :splat + return + else + return + end + end + end + end +end diff --git a/sig/steep/diagnostic/ruby.rbs b/sig/steep/diagnostic/ruby.rbs index 065154eca..b55c3d143 100644 --- a/sig/steep/diagnostic/ruby.rbs +++ b/sig/steep/diagnostic/ruby.rbs @@ -437,6 +437,23 @@ module Steep def header_line: () -> ::String end + # The `#to_ary` of RHS of multiple assignment is called, but returns not tuple nor Array. + # + # ```ruby + # a, b = foo() + # ^^^^^ + # ``` + # + class MultipleAssignmentConversionError < Base + attr_reader original_type: AST::Types::t + + attr_reader returned_type: AST::Types::t + + def initialize: (node: Parser::AST::Node, original_type: AST::Types::t, returned_type: AST::Types::t) -> void + + def header_line: () -> ::String + end + class UnsupportedSyntax < Base attr_reader message: untyped diff --git a/sig/steep/type_inference/multiple_assignment.rbs b/sig/steep/type_inference/multiple_assignment.rbs new file mode 100644 index 000000000..cfa89c67f --- /dev/null +++ b/sig/steep/type_inference/multiple_assignment.rbs @@ -0,0 +1,76 @@ +module Steep + module TypeInference + # This class provides an abstraction for multiple assignments. + # + class MultipleAssignment + type node_type_pair = [Parser::AST::Node, AST::Types::t] + + # Encapsulate assignments included in one `masgn` node + # + # ```ruby + # a, *b, c = rhs + # # ^ Leading assignments + # # ^^ Splat assignment + # # ^ Trailing assignments + # ``` + # + class Assignments + attr_reader rhs_type: AST::Types::t + + attr_reader optional: bool + + # Assignments before `*` assignment + attr_reader leading_assignments: Array[node_type_pair] + + # Assignments after `*` assignment + # + # Empty if there is no splat assignment. + # + attr_reader trailing_assignments: Array[node_type_pair] + + # Splat assignment if present + attr_reader splat_assignment: node_type_pair? + + def initialize: ( + rhs_type: AST::Types::t, + optional: bool, + leading_assignments: Array[node_type_pair], + trailing_assignments: Array[node_type_pair], + splat_assignment: node_type_pair? + ) -> void + + def each: () { (node_type_pair) -> void } -> void + | () -> Enumerator[node_type_pair, void] + end + + def initialize: () -> void + + # Receives multiple assignment left hand side, right hand side type, and `optional` flag, and returns Assignments object + # + # This implements a case analysis on `rhs_type`: + # + # 1. If `rhs_type` is tuple, it returns an Assignments object with corresponding assignments + # 2. If `rhs_type` is an array, it returns an Assignments object with corresponding assignments + # 3. If `rhs_type` is `untyped`, it returns an Assignments with `untyped` type + # 4. It returns `nil` otherwise + # + def expand: (Parser::AST::Node mlhs, AST::Types::t rhs_type, bool optional) -> Assignments? + + # Returns a type hint for multiple assignment right hand side + # + # It constructs a structure of tuple types, based on the assignment lhs, and variable types. + # + def hint_for_mlhs: (Parser::AST::Node mlhs, TypeEnv env) -> AST::Types::t? + + private + + def expand_tuple: (Array[Parser::AST::Node] assignments, AST::Types::t rhs_type, Array[AST::Types::t] types, bool optional) -> Assignments + + def expand_array: (Array[Parser::AST::Node] assignments, AST::Types::Name::Instance rhs_type, bool optional) -> Assignments + + def expand_any: (Array[Parser::AST::Node] assignments, AST::Types::t rhs_type, AST::Types::t element_type, bool optional) -> Assignments + + def expand_else: (Array[Parser::AST::Node] assignments, AST::Types::t rhs_type, bool optional) -> Assignments + end + end +end diff --git a/test/multiple_assignment_test.rb b/test/multiple_assignment_test.rb new file mode 100644 index 000000000..4f1dfafb5 --- /dev/null +++ b/test/multiple_assignment_test.rb @@ -0,0 +1,147 @@ +require_relative "test_helper" + +class MultipleAssignmentTest < Minitest::Test + include TestHelper + include FactoryHelper + include SubtypingHelper + + include Steep + + MultipleAssignment = TypeInference::MultipleAssignment + TypeEnv = TypeInference::TypeEnv + ConstantEnv = TypeInference::ConstantEnv + + def node(type, *children) + Parser::AST::Node.new(type, children) + end + + def constant_env(context: nil) + ConstantEnv.new( + factory: factory, + context: context, + resolver: RBS::Resolver::ConstantResolver.new(builder: factory.definition_builder) + ) + end + + def test_tuple_assignment + with_checker do + source = parse_ruby("a, *b, c = _") + mlhs, rhs = source.node.children + + masgn = MultipleAssignment.new() + asgns = masgn.expand(mlhs, parse_type("[::Integer, ::String, ::Symbol]"), false) + + assert_equal( + MultipleAssignment::Assignments.new( + rhs_type: parse_type("[::Integer, ::String, ::Symbol]"), + optional: false, + leading_assignments: [[node(:lvasgn, :a), parse_type("::Integer")]], + trailing_assignments: [[node(:lvasgn, :c), parse_type("::Symbol")]], + splat_assignment: [node(:splat, node(:lvasgn, :b)), parse_type("[::String]")] + ), + asgns + ) + end + end + + def test_tuple_assignment_optional + with_checker do + source = parse_ruby("a, *b, c = _") + mlhs, rhs = source.node.children + + masgn = MultipleAssignment.new() + asgns = masgn.expand(mlhs, parse_type("[::Integer, ::String, ::Symbol]"), true) + + assert_equal( + MultipleAssignment::Assignments.new( + rhs_type: parse_type("[::Integer, ::String, ::Symbol]"), + optional: true, + leading_assignments: [[node(:lvasgn, :a), parse_type("::Integer")]], + trailing_assignments: [[node(:lvasgn, :c), parse_type("::Symbol")]], + splat_assignment: [node(:splat, node(:lvasgn, :b)), parse_type("[::String]")] + ), + asgns + ) + end + end + + def test_array_assignment + with_checker do + source = parse_ruby("a, *b, c = _") + mlhs, rhs = source.node.children + + masgn = MultipleAssignment.new() + asgns = masgn.expand(mlhs, parse_type("::Array[::Integer]"), false) + + assert_equal( + MultipleAssignment::Assignments.new( + rhs_type: parse_type("::Array[::Integer]"), + optional: false, + leading_assignments: [[node(:lvasgn, :a), parse_type("::Integer?")]], + trailing_assignments: [[node(:lvasgn, :c), parse_type("::Integer?")]], + splat_assignment: [node(:splat, node(:lvasgn, :b)), parse_type("::Array[::Integer]")] + ), + asgns + ) + end + end + + def test_array_assignment_optional + with_checker do + source = parse_ruby("a, *b, c = _") + mlhs, rhs = source.node.children + + masgn = MultipleAssignment.new() + asgns = masgn.expand(mlhs, parse_type("::Array[::Integer]"), true) + + assert_equal( + MultipleAssignment::Assignments.new( + rhs_type: parse_type("::Array[::Integer]"), + optional: true, + leading_assignments: [[node(:lvasgn, :a), parse_type("::Integer?")]], + trailing_assignments: [[node(:lvasgn, :c), parse_type("::Integer?")]], + splat_assignment: [node(:splat, node(:lvasgn, :b)), parse_type("::Array[::Integer]")] + ), + asgns + ) + end + end + + def test_hint_for_mlhs + with_checker do + env = + + masgn = MultipleAssignment.new() + + masgn.hint_for_mlhs( + parse_ruby("a, b = _").node.children[0], + TypeEnv.new(constant_env) + ).tap do |hint| + assert_equal parse_type("[untyped, untyped]"), hint + end + + masgn.hint_for_mlhs( + parse_ruby("a, b, *c = _").node.children[0], + TypeEnv.new(constant_env) + ).tap do |hint| + assert_nil hint + end + + masgn.hint_for_mlhs( + parse_ruby("a, (b, c) = _").node.children[0], + TypeEnv.new(constant_env) + ).tap do |hint| + assert_equal parse_type("[untyped, [untyped, untyped]]"), hint + end + + masgn.hint_for_mlhs( + parse_ruby("a, @b, $c = _").node.children[0], + TypeEnv.new(constant_env) + .assign_local_variables({ a: parse_type("::String") }) + .update(instance_variable_types: { :"@b" => parse_type("::Symbol") }, global_types: { :"$c" => parse_type("::Integer") }) + ).tap do |hint| + assert_equal parse_type("[::String, ::Symbol, ::Integer]"), hint + end + end + end +end