diff --git a/lib/steep.rb b/lib/steep.rb index 5ac22ad58..0aa750f28 100644 --- a/lib/steep.rb +++ b/lib/steep.rb @@ -60,6 +60,7 @@ require "steep/typing" require "steep/errors" require "steep/type_construction" +require "steep/type_inference/context" require "steep/type_inference/send_args" require "steep/type_inference/block_params" require "steep/type_inference/constant_env" @@ -72,6 +73,7 @@ require "steep/project/target" require "steep/project/dsl" require "steep/project/file_loader" +require "steep/project/hover_content" require "steep/drivers/utils/driver_helper" require "steep/drivers/check" require "steep/drivers/validate" diff --git a/lib/steep/drivers/langserver.rb b/lib/steep/drivers/langserver.rb index 64f5db9bb..99cf7e51b 100644 --- a/lib/steep/drivers/langserver.rb +++ b/lib/steep/drivers/langserver.rb @@ -187,6 +187,8 @@ def run_type_check() [diagnostics_raw(source.status.error.message, source.status.location)] when Project::SourceFile::ParseErrorStatus [] + when Project::SourceFile::TypeCheckErrorStatus + [] end report_diagnostics source.path, diagnostics @@ -264,22 +266,71 @@ def diagnostic_for_type_error(error) def response_to_hover(path:, line:, column:) Steep.logger.info { "path=#{path}, line=#{line}, column=#{column}" } - # line in LSP is zero-origin - project.type_of_node(path: path, line: line + 1, column: column) do |type, node| - Steep.logger.warn { "node = #{node.type}, type = #{type.to_s}" } - - start_position = { line: node.location.line - 1, character: node.location.column } - end_position = { line: node.location.last_line - 1, character: node.location.last_column } - range = { start: start_position, end: end_position } - - Steep.logger.warn { "range = #{range.inspect}" } + hover = Project::HoverContent.new(project: project) + content = hover.content_for(path: path, line: line+1, column: column+1) + if content + range = content.location.yield_self do |location| + start_position = { line: location.line - 1, character: location.column } + end_position = { line: location.last_line - 1, character: location.last_column } + { start: start_position, end: end_position } + end LanguageServer::Protocol::Interface::Hover.new( - contents: { kind: "markdown", value: "`#{type}`" }, + contents: { kind: "markdown", value: format_hover(content) }, range: range ) end end + + def format_hover(content) + case content + when Project::HoverContent::VariableContent + "`#{content.name}`: `#{content.type.to_s}`" + when Project::HoverContent::MethodCallContent + method_name = case content.method_name + when Project::HoverContent::InstanceMethodName + "#{content.method_name.class_name}##{content.method_name.method_name}" + when Project::HoverContent::SingletonMethodName + "#{content.method_name.class_name}.#{content.method_name.method_name}" + else + nil + end + + if method_name + string = < #{content.type} +``` +HOVER + if content.definition + if content.definition.comment + string << "\n----\n\n#{content.definition.comment.string}" + end + + string << "\n----\n\n#{content.definition.method_types.map {|x| "- `#{x}`\n" }.join()}" + end + else + "`#{content.type}`" + end + when Project::HoverContent::DefinitionContent + string = < 1 + string << "\n----\n\n#{content.definition.method_types.map {|x| "- `#{x}`\n" }.join()}" + end + + string + when Project::HoverContent::TypeContent + "`#{content.type}`" + end + end end end end diff --git a/lib/steep/interface/substitution.rb b/lib/steep/interface/substitution.rb index 51fe587e9..66a429ae7 100644 --- a/lib/steep/interface/substitution.rb +++ b/lib/steep/interface/substitution.rb @@ -1,6 +1,18 @@ module Steep module Interface class Substitution + class InvalidSubstitutionError < StandardError + attr_reader :vars_size + attr_reader :types_size + + def initialize(vars_size:, types_size:) + @var_size = vars_size + @types_size = types_size + + super "Invalid substitution: vars.size=#{vars_size}, types.size=#{types_size}" + end + end + attr_reader :dictionary attr_reader :instance_type attr_reader :module_type @@ -42,7 +54,7 @@ def key?(var) def self.build(vars, types = nil, instance_type: AST::Types::Instance.new, module_type: AST::Types::Class.new, self_type: AST::Types::Self.new) types ||= vars.map {|var| AST::Types::Var.fresh(var) } - raise "Invalid substitution: vars.size=#{vars.size}, types.size=#{types.size}" unless vars.size == types.size + raise InvalidSubstitutionError.new(vars_size: vars.size, types_size: types.size) unless vars.size == types.size dic = vars.zip(types).each.with_object({}) do |(var, type), d| d[var] = type diff --git a/lib/steep/project.rb b/lib/steep/project.rb index 91477361e..ca00063e3 100644 --- a/lib/steep/project.rb +++ b/lib/steep/project.rb @@ -24,6 +24,7 @@ def type_of_node(path:, line:, column:) source_file = targets.map {|target| target.source_files[path] }.compact[0] if source_file + case (status = source_file.status) when SourceFile::TypeCheckStatus node = status.source.find_node(line: line, column: column) diff --git a/lib/steep/project/file.rb b/lib/steep/project/file.rb index e0e3c5e3b..e986154c6 100644 --- a/lib/steep/project/file.rb +++ b/lib/steep/project/file.rb @@ -11,6 +11,7 @@ class SourceFile ParseErrorStatus = Struct.new(:error, keyword_init: true) AnnotationSyntaxErrorStatus = Struct.new(:error, :location, keyword_init: true) TypeCheckStatus = Struct.new(:typing, :source, :timestamp, keyword_init: true) + TypeCheckErrorStatus = Struct.new(:error, keyword_init: true) def initialize(path:) @path = path @@ -62,20 +63,22 @@ def type_check(subtyping, env_updated_at) checker: subtyping, annotations: annotations, source: source, - self_type: AST::Builtin::Object.instance_type, - block_context: nil, - module_context: TypeConstruction::ModuleContext.new( - instance_type: nil, - module_type: nil, - implement_name: nil, - current_namespace: AST::Namespace.root, - const_env: const_env, - class_name: nil + context: TypeInference::Context.new( + block_context: nil, + module_context: TypeInference::Context::ModuleContext.new( + instance_type: nil, + module_type: nil, + implement_name: nil, + current_namespace: AST::Namespace.root, + const_env: const_env, + class_name: nil + ), + method_context: nil, + break_context: nil, + self_type: AST::Builtin::Object.instance_type, + type_env: type_env ), - method_context: nil, - typing: typing, - break_context: nil, - type_env: type_env + typing: typing ) construction.synthesize(source.node) @@ -86,6 +89,8 @@ def type_check(subtyping, env_updated_at) source: source, timestamp: Time.now ) + rescue => exn + @status = TypeCheckErrorStatus.new(error: exn) end true diff --git a/lib/steep/project/hover_content.rb b/lib/steep/project/hover_content.rb new file mode 100644 index 000000000..ffa31c8ab --- /dev/null +++ b/lib/steep/project/hover_content.rb @@ -0,0 +1,128 @@ +module Steep + class Project + class HoverContent + TypeContent = Struct.new(:node, :type, :location, keyword_init: true) + VariableContent = Struct.new(:node, :name, :type, :location, keyword_init: true) + MethodCallContent = Struct.new(:node, :method_name, :type, :definition, :location, keyword_init: true) + DefinitionContent = Struct.new(:node, :method_name, :method_type, :definition, :location, keyword_init: true) + + InstanceMethodName = Struct.new(:class_name, :method_name) + SingletonMethodName = Struct.new(:class_name, :method_name) + + attr_reader :project + + def initialize(project:) + @project = project + end + + def method_definition_for(factory, module_name, singleton_method: nil, instance_method: nil) + type_name = factory.type_name_1(module_name) + + case + when instance_method + factory.definition_builder.build_instance(type_name).methods[instance_method] + when singleton_method + methods = factory.definition_builder.build_singleton(type_name).methods + + if singleton_method == :new + methods[:new] || methods[:initialize] + else + methods[singleton_method] + end + end + end + + def content_for(path:, line:, column:) + source_file = project.targets.map {|target| target.source_files[path] }.compact[0] + + if source_file + case (status = source_file.status) + when SourceFile::TypeCheckStatus + node, *parents = status.source.find_nodes(line: line, column: column) + + if node + case node.type + when :lvar, :lvasgn + var_name = node.children[0] + context = status.typing.context_of(node: node) + var_type = context.type_env.get(lvar: var_name.name) + + VariableContent.new(node: node, name: var_name.name, type: var_type, location: node.location.name) + when :send + receiver, method_name, *_ = node.children + + + result_node = if parents[0]&.type == :block + parents[0] + else + node + end + + context = status.typing.context_of(node: result_node) + + receiver_type = if receiver + status.typing.type_of(node: receiver) + else + context.self_type + end + + factory = context.type_env.subtyping.factory + method_name, definition = case receiver_type + when AST::Types::Name::Instance + method_definition = method_definition_for(factory, receiver_type.name, instance_method: method_name) + if method_definition&.defined_in + owner_name = factory.type_name(method_definition.defined_in.name.absolute!) + [ + InstanceMethodName.new(owner_name, method_name), + method_definition + ] + end + when AST::Types::Name::Class + method_definition = method_definition_for(factory, receiver_type.name, singleton_method: method_name) + if method_definition&.defined_in + owner_name = factory.type_name(method_definition.defined_in.name.absolute!) + [ + SingletonMethodName.new(owner_name, method_name), + method_definition + ] + end + else + nil + end + + MethodCallContent.new( + node: node, + method_name: method_name, + type: status.typing.type_of(node: result_node), + definition: definition, + location: result_node.location.expression + ) + when :def, :defs + context = status.typing.context_of(node: node) + method_context = context.method_context + + if method_context + DefinitionContent.new( + node: node, + method_name: method_context.name, + method_type: method_context.method_type, + definition: method_context.method, + location: node.loc.expression + ) + end + else + type = status.typing.type_of(node: node) + + TypeContent.new( + node: node, + type: type, + location: node.location.expression + ) + end + end + end + end + end + end + end +end diff --git a/lib/steep/source.rb b/lib/steep/source.rb index 9c56672e1..cf5801d3b 100644 --- a/lib/steep/source.rb +++ b/lib/steep/source.rb @@ -294,8 +294,7 @@ def each_annotation end end - # @type method find_node: (line: Integer, column: Integer, ?node: any, ?position: Integer?) -> any - def find_node(line:, column:, node: self.node, position: nil) + def find_nodes(line:, column:, node: self.node, position: nil, parents: []) position ||= (line-1).times.sum do |i| node.location.expression.source_buffer.source_line(i+1).size + 1 end + column @@ -306,14 +305,13 @@ def find_node(line:, column:, node: self.node, position: nil) if range if range === position + parents.unshift node + Source.each_child_node(node) do |child| - n = find_node(line: line, column: column, node: child, position: position) - if n - return n - end + ns = find_nodes(line: line, column: column, node: child, position: position, parents: parents) and return ns end - node + parents end end end diff --git a/lib/steep/type_construction.rb b/lib/steep/type_construction.rb index 34147a31d..980cee750 100644 --- a/lib/steep/type_construction.rb +++ b/lib/steep/type_construction.rb @@ -1,97 +1,43 @@ module Steep class TypeConstruction - class MethodContext - attr_reader :name - attr_reader :method - attr_reader :method_type - attr_reader :return_type - attr_reader :constructor - attr_reader :super_method - - def initialize(name:, method:, method_type:, return_type:, constructor:, super_method:) - @name = name - @method = method - @return_type = return_type - @method_type = method_type - @constructor = constructor - @super_method = super_method - end - - def block_type - method_type&.block - end - end + attr_reader :checker + attr_reader :source + attr_reader :annotations + attr_reader :typing + attr_reader :type_env - class BlockContext - attr_reader :body_type + attr_reader :context - def initialize(body_type:) - @body_type = body_type - end + def module_context + context.module_context end - class BreakContext - attr_reader :break_type - attr_reader :next_type + def method_context + context.method_context + end - def initialize(break_type:, next_type:) - @break_type = break_type - @next_type = next_type - end + def block_context + context.block_context end - class ModuleContext - attr_reader :instance_type - attr_reader :module_type - attr_reader :defined_instance_methods - attr_reader :defined_module_methods - attr_reader :const_env - attr_reader :implement_name - attr_reader :current_namespace - attr_reader :class_name - attr_reader :instance_definition - attr_reader :module_definition - - def initialize(instance_type:, module_type:, implement_name:, current_namespace:, const_env:, class_name:, instance_definition: nil, module_definition: nil) - @instance_type = instance_type - @module_type = module_type - @defined_instance_methods = Set.new - @defined_module_methods = Set.new - @implement_name = implement_name - @current_namespace = current_namespace - @const_env = const_env - @class_name = class_name - @instance_definition = instance_definition - @module_definition = module_definition - end + def break_context + context.break_context + end - def const_context - const_env.context - end + def self_type + context.self_type end - attr_reader :checker - attr_reader :source - attr_reader :annotations - attr_reader :typing - attr_reader :method_context - attr_reader :block_context - attr_reader :module_context - attr_reader :self_type - attr_reader :break_context - attr_reader :type_env + def type_env + context.type_env + end - def initialize(checker:, source:, annotations:, type_env:, typing:, self_type:, method_context:, block_context:, module_context:, break_context:) + def initialize(checker:, source:, annotations:, typing:, context:) @checker = checker @source = source @annotations = annotations @typing = typing - @self_type = self_type - @block_context = block_context - @method_context = method_context - @module_context = module_context - @break_context = break_context - @type_env = type_env + @context = context end def with_new_typing(typing) @@ -99,13 +45,8 @@ def with_new_typing(typing) checker: checker, source: source, annotations: annotations, - type_env: type_env, typing: typing, - self_type: self_type, - method_context: method_context, - block_context: block_context, - module_context: module_context, - break_context: break_context + context: context ) end @@ -173,7 +114,7 @@ def for_new_method(method_name, node, args:, self_type:, definition:) end end - method_context = MethodContext.new( + method_context = TypeInference::Context::MethodContext.new( name: method_name, method: definition && definition.methods[method_name], method_type: method_type, @@ -205,13 +146,15 @@ def for_new_method(method_name, node, args:, self_type:, definition:) checker: checker, source: source, annotations: annots, - type_env: type_env, - block_context: nil, - self_type: annots.self_type || self_type, - method_context: method_context, + context: TypeInference::Context.new( + method_context: method_context, + module_context: module_context, + block_context: nil, + break_context: nil, + self_type: annots.self_type || self_type, + type_env: type_env + ), typing: typing, - module_context: module_context, - break_context: nil ) end @@ -283,7 +226,7 @@ def for_module(node) end module_const_env = TypeInference::ConstantEnv.new(factory: checker.factory, context: const_context) - module_context_ = ModuleContext.new( + module_context_ = TypeInference::Context::ModuleContext.new( instance_type: instance_type, module_type: annots.self_type || module_type, implement_name: implement_module_name, @@ -303,13 +246,15 @@ def for_module(node) checker: checker, source: source, annotations: annots, - type_env: module_type_env, typing: typing, - method_context: nil, - block_context: nil, - module_context: module_context_, - self_type: module_context_.module_type, - break_context: nil + context: TypeInference::Context.new( + method_context: nil, + block_context: nil, + break_context: nil, + module_context: module_context_, + self_type: module_context_.module_type, + type_env: module_type_env + ) ) end @@ -380,7 +325,7 @@ def for_class(node) end class_const_env = TypeInference::ConstantEnv.new(factory: checker.factory, context: const_context) - module_context = ModuleContext.new( + module_context = TypeInference::Context::ModuleContext.new( instance_type: annots.instance_type || instance_type, module_type: annots.self_type || annots.module_type || module_type, implement_name: implement_module_name, @@ -401,21 +346,21 @@ def for_class(node) checker: checker, source: source, annotations: annots, - type_env: class_type_env, typing: typing, - method_context: nil, - block_context: nil, - module_context: module_context, - self_type: module_context.module_type, - break_context: nil + context: TypeInference::Context.new( + method_context: nil, + block_context: nil, + module_context: module_context, + break_context: nil, + self_type: module_context.module_type, + type_env: class_type_env + ) ) end - def for_branch(node, truthy_vars: Set.new, type_case_override: nil) + def for_branch(node, truthy_vars: Set.new, type_case_override: nil, break_context: context.break_context) annots = source.annotations(block: node, factory: checker.factory, current_module: current_namespace) - type_env = self.type_env - lvar_types = self.type_env.lvar_types.each.with_object({}) do |(var, type), env| if truthy_vars.member?(var) env[var] = unwrap(type) @@ -423,7 +368,8 @@ def for_branch(node, truthy_vars: Set.new, type_case_override: nil) env[var] = type end end - type_env = type_env.with_annotations(lvar_types: lvar_types, self_type: self_type) do |var, relation, result| + + type_env = self.type_env.with_annotations(lvar_types: lvar_types, self_type: self_type) do |var, relation, result| raise "Unexpected annotate failure: #{relation}" end @@ -453,23 +399,18 @@ def for_branch(node, truthy_vars: Set.new, type_case_override: nil) ) end - with(type_env: type_env) + with(context: context.with(type_env: type_env, break_context: break_context)) end NOTHING = ::Object.new - def with(annotations: NOTHING, type_env: NOTHING, method_context: NOTHING, block_context: NOTHING, module_context: NOTHING, self_type: NOTHING, break_context: NOTHING) + def with(annotations: NOTHING, context: NOTHING) self.class.new( checker: checker, source: source, annotations: annotations.equal?(NOTHING) ? self.annotations : annotations, - type_env: type_env.equal?(NOTHING) ? self.type_env : type_env, typing: typing, - method_context: method_context.equal?(NOTHING) ? self.method_context : method_context, - block_context: block_context.equal?(NOTHING) ? self.block_context : block_context, - module_context: module_context.equal?(NOTHING) ? self.module_context : module_context, - self_type: self_type.equal?(NOTHING) ? self.self_type : self_type, - break_context: break_context.equal?(NOTHING) ? self.break_context : break_context + context: context.equal?(NOTHING) ? self.context : context ) end @@ -489,7 +430,7 @@ def synthesize(node, hint: nil) type = AST::Builtin.nil_type end - typing.add_typing(node, type) + typing.add_typing(node, type, context) end when :lvasgn @@ -500,9 +441,9 @@ def synthesize(node, hint: nil) case var.name when :_, :__any__ synthesize(rhs, hint: AST::Builtin.any_type) - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, context) when :__skip__ - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, context) else type_assignment(var, rhs, node, hint: hint) end @@ -515,7 +456,7 @@ def synthesize(node, hint: nil) fallback_to_any node end - typing.add_typing node, type + typing.add_typing node, type, context end when :ivasgn @@ -530,7 +471,7 @@ def synthesize(node, hint: nil) type = type_env.get(ivar: name) do fallback_to_any node end - typing.add_typing(node, type) + typing.add_typing(node, type, context) end when :send @@ -543,7 +484,7 @@ def synthesize(node, hint: nil) module_type end - typing.add_typing(node, type) + typing.add_typing(node, type, context) else type_send(node, send_node: node, block_params: nil, block_body: nil) end @@ -558,7 +499,7 @@ def synthesize(node, hint: nil) else module_type end - typing.add_typing(node, type) + typing.add_typing(node, type, context) else type_send(node, send_node: node, block_params: nil, block_body: nil, unwrap: true) end @@ -570,7 +511,7 @@ def synthesize(node, hint: nil) each_child_node(node) do |child| synthesize(child) end - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, context) when :op_asgn yield_self do @@ -594,7 +535,7 @@ def synthesize(node, hint: nil) case when lhs_type == AST::Builtin.any_type - typing.add_typing(node, lhs_type) + typing.add_typing(node, lhs_type, context) when !lhs_type fallback_to_any(node) else @@ -627,7 +568,7 @@ def synthesize(node, hint: nil) typing.add_error Errors::NoMethod.new(node: node, method: op, type: expand_self(lhs_type)) end - typing.add_typing(node, lhs_type) + typing.add_typing(node, lhs_type, context) end end @@ -656,7 +597,7 @@ def synthesize(node, hint: nil) block_body: nil, topdown_hint: true) - typing.add_typing node, return_type + typing.add_typing node, return_type, context else fallback_to_any node do Errors::UnexpectedSuper.new(node: node, method: method_context.name) @@ -717,7 +658,7 @@ def synthesize(node, hint: nil) module_context.defined_instance_methods << node.children[0] end - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, new.context) when :defs synthesize(node.children[0]).tap do |self_type| @@ -762,7 +703,7 @@ def synthesize(node, hint: nil) end end - typing.add_typing(node, AST::Builtin::Symbol.instance_type) + typing.add_typing(node, AST::Builtin::Symbol.instance_type, context) when :return yield_self do @@ -791,7 +732,7 @@ def synthesize(node, hint: nil) end end - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, context) end when :break @@ -817,7 +758,7 @@ def synthesize(node, hint: nil) typing.add_error Errors::UnexpectedJump.new(node: node) end - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, context) when :next value = node.children[0] @@ -842,13 +783,13 @@ def synthesize(node, hint: nil) typing.add_error Errors::UnexpectedJump.new(node: node) end - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, context) when :retry unless break_context typing.add_error Errors::UnexpectedJump.new(node: node) end - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, context) when :arg, :kwarg, :procarg0 yield_self do @@ -856,7 +797,7 @@ def synthesize(node, hint: nil) type = type_env.get(lvar: var.name) do fallback_to_any node end - typing.add_typing(node, type) + typing.add_typing(node, type, context) end when :optarg, :kwoptarg @@ -874,7 +815,7 @@ def synthesize(node, hint: nil) AST::Builtin::Array.instance_type(AST::Builtin.any_type) end - typing.add_typing(node, type) + typing.add_typing(node, type, context) end when :kwrestarg @@ -885,23 +826,23 @@ def synthesize(node, hint: nil) AST::Builtin::Hash.instance_type(AST::Builtin::Symbol.instance_type, AST::Builtin.any_type) end - typing.add_typing(node, type) + typing.add_typing(node, type, context) end when :float - typing.add_typing(node, AST::Builtin::Float.instance_type) + typing.add_typing(node, AST::Builtin::Float.instance_type, context) when :nil - typing.add_typing(node, AST::Builtin.nil_type) + typing.add_typing(node, AST::Builtin.nil_type, context) when :int yield_self do literal_type = expand_alias(hint) {|hint_| test_literal_type(node.children[0], hint_)} if literal_type - typing.add_typing(node, literal_type) + typing.add_typing(node, literal_type, context) else - typing.add_typing(node, AST::Builtin::Integer.instance_type) + typing.add_typing(node, AST::Builtin::Integer.instance_type, context) end end @@ -910,9 +851,9 @@ def synthesize(node, hint: nil) literal_type = expand_alias(hint) {|hint_| test_literal_type(node.children[0], hint_)} if literal_type - typing.add_typing(node, literal_type) + typing.add_typing(node, literal_type, context) else - typing.add_typing(node, AST::Builtin::Symbol.instance_type) + typing.add_typing(node, AST::Builtin::Symbol.instance_type, context) end end @@ -921,14 +862,14 @@ def synthesize(node, hint: nil) literal_type = expand_alias(hint) {|hint_| test_literal_type(node.children[0], hint_)} if literal_type - typing.add_typing(node, literal_type) + typing.add_typing(node, literal_type, context) else - typing.add_typing(node, AST::Builtin::String.instance_type) + typing.add_typing(node, AST::Builtin::String.instance_type, context) end end when :true, :false - typing.add_typing(node, AST::Types::Boolean.new) + typing.add_typing(node, AST::Types::Boolean.new, context) when :hash yield_self do @@ -975,7 +916,7 @@ def synthesize(node, hint: nil) typing.add_error Errors::FallbackAny.new(node: node) end - typing.add_typing(node, AST::Builtin::Hash.instance_type(key_type, value_type)) + typing.add_typing(node, AST::Builtin::Hash.instance_type(key_type, value_type), context) end when :dstr, :xstr @@ -983,14 +924,14 @@ def synthesize(node, hint: nil) synthesize(child) end - typing.add_typing(node, AST::Builtin::String.instance_type) + typing.add_typing(node, AST::Builtin::String.instance_type, context) when :dsym each_child_node(node) do |child| synthesize(child) end - typing.add_typing(node, AST::Builtin::Symbol.instance_type) + typing.add_typing(node, AST::Builtin::Symbol.instance_type, context) when :class yield_self do @@ -1002,7 +943,7 @@ def synthesize(node, hint: nil) end end - typing.add_typing(node, AST::Builtin.nil_type) + typing.add_typing(node, AST::Builtin.nil_type, context) end when :module @@ -1015,11 +956,11 @@ def synthesize(node, hint: nil) end end - typing.add_typing(node, AST::Builtin.nil_type) + typing.add_typing(node, AST::Builtin.nil_type, context) end when :self - typing.add_typing node, AST::Types::Self.new + typing.add_typing node, AST::Types::Self.new, context when :const const_name = Names::Module.from_node(node) @@ -1027,7 +968,7 @@ def synthesize(node, hint: nil) type = type_env.get(const: const_name) do fallback_to_any node end - typing.add_typing node, type + typing.add_typing node, type, context else fallback_to_any node end @@ -1050,7 +991,7 @@ def synthesize(node, hint: nil) end end - typing.add_typing(node, type) + typing.add_typing(node, type, context) else synthesize(node.children.last) fallback_to_any(node) @@ -1072,7 +1013,7 @@ def synthesize(node, hint: nil) end end - typing.add_typing(node, block_type.type.return_type) + typing.add_typing(node, block_type.type.return_type, context) else typing.add_error(Errors::UnexpectedYield.new(node: node)) fallback_to_any node @@ -1095,7 +1036,7 @@ def synthesize(node, hint: nil) raise "Unexpected method_type: #{method_type.inspect}" end } - typing.add_typing(node, union_type(*types)) + typing.add_typing(node, union_type(*types), context) else typing.add_error(Errors::UnexpectedSuper.new(node: node, method: method_context.name)) fallback_to_any node @@ -1117,7 +1058,7 @@ def synthesize(node, hint: nil) end end - typing.add_typing(node, array_type || AST::Builtin::Array.instance_type(AST::Builtin.any_type)) + typing.add_typing(node, array_type || AST::Builtin::Array.instance_type(AST::Builtin.any_type), context) else is_tuple = nil @@ -1170,7 +1111,7 @@ def synthesize(node, hint: nil) array_type = AST::Builtin::Array.instance_type(AST::Types::Union.build(types: element_types)) end - typing.add_typing(node, array_type) + typing.add_typing(node, array_type, context) end end @@ -1189,9 +1130,9 @@ def synthesize(node, hint: nil) const_env: nil)]) if left_type.is_a?(AST::Types::Boolean) - typing.add_typing(node, union_type(left_type, right_type)) + typing.add_typing(node, union_type(left_type, right_type), context) else - typing.add_typing(node, union_type(right_type, AST::Builtin.nil_type)) + typing.add_typing(node, union_type(right_type, AST::Builtin.nil_type), context) end end @@ -1201,7 +1142,7 @@ def synthesize(node, hint: nil) t1 = synthesize(c1, hint: hint) t2 = synthesize(c2, hint: unwrap(t1)) type = union_type(unwrap(t1), t2) - typing.add_typing(node, type) + typing.add_typing(node, type, context) end when :if @@ -1224,7 +1165,7 @@ def synthesize(node, hint: nil) end type_env.join!([true_env, false_env].compact) - typing.add_typing(node, union_type(true_type, false_type)) + typing.add_typing(node, union_type(true_type, false_type), context) when :case yield_self do @@ -1306,7 +1247,7 @@ def synthesize(node, hint: nil) end type_env.join!(envs.compact) - typing.add_typing(node, union_type(*types)) + typing.add_typing(node, union_type(*types), context) end when :rescue @@ -1373,7 +1314,7 @@ def synthesize(node, hint: nil) type_env.join!([*resbody_envs, else_env].compact) types = [body_type, *resbody_types, else_type].compact - typing.add_typing(node, union_type(*types)) + typing.add_typing(node, union_type(*types), context) end when :resbody @@ -1382,7 +1323,7 @@ def synthesize(node, hint: nil) synthesize(klasses) if klasses synthesize(asgn) if asgn body_type = synthesize(body) if body - typing.add_typing(node, body_type) + typing.add_typing(node, body_type, context) end when :ensure @@ -1390,7 +1331,7 @@ def synthesize(node, hint: nil) body, ensure_body = node.children body_type = synthesize(body) if body synthesize(ensure_body) if ensure_body - typing.add_typing(node, union_type(body_type)) + typing.add_typing(node, union_type(body_type), context) end when :masgn @@ -1404,38 +1345,43 @@ def synthesize(node, hint: nil) truthy_vars = node.type == :while ? TypeConstruction.truthy_variables(cond) : Set.new if body - for_loop = for_branch(body, truthy_vars: truthy_vars).with(break_context: BreakContext.new(break_type: nil, next_type: nil)) + for_loop = for_branch(body, + truthy_vars: truthy_vars, + break_context: TypeInference::Context::BreakContext.new( + break_type: nil, + next_type: nil + )) for_loop.synthesize(body) type_env.join!([for_loop.type_env]) end - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, context) end when :irange, :erange types = node.children.map {|n| synthesize(n)} type = AST::Builtin::Range.instance_type(union_type(*types)) - typing.add_typing(node, type) + typing.add_typing(node, type, context) when :regexp each_child_node(node) do |child| synthesize(child) end - typing.add_typing(node, AST::Builtin::Regexp.instance_type) + typing.add_typing(node, AST::Builtin::Regexp.instance_type, context) when :regopt # ignore - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, context) when :nth_ref, :back_ref - typing.add_typing(node, AST::Builtin::String.instance_type) + typing.add_typing(node, AST::Builtin::String.instance_type, context) when :or_asgn, :and_asgn yield_self do _, rhs = node.children rhs_type = synthesize(rhs) - typing.add_typing(node, rhs_type) + typing.add_typing(node, rhs_type, context) end when :defined? @@ -1443,7 +1389,7 @@ def synthesize(node, hint: nil) synthesize(child) end - typing.add_typing(node, AST::Builtin.any_type) + typing.add_typing(node, AST::Builtin.any_type, context) when :gvasgn yield_self do @@ -1469,7 +1415,7 @@ def synthesize(node, hint: nil) typing.add_error Errors::FallbackAny.new(node: node) end - typing.add_typing(node, type) + typing.add_typing(node, type, context) end when :block_pass @@ -1499,7 +1445,7 @@ def synthesize(node, hint: nil) type ||= synthesize(node.children[0], hint: hint) - typing.add_typing node, type + typing.add_typing node, type, context end when :blockarg @@ -1508,7 +1454,7 @@ def synthesize(node, hint: nil) synthesize(child) end - typing.add_typing node, AST::Builtin.any_type + typing.add_typing node, AST::Builtin.any_type, context end when :splat, :sclass, :alias @@ -1519,7 +1465,7 @@ def synthesize(node, hint: nil) synthesize(child) end - typing.add_typing node, AST::Builtin.any_type + typing.add_typing node, AST::Builtin.any_type, context end else @@ -1541,14 +1487,14 @@ def type_assignment(var, rhs, node, hint: nil) if rhs expand_alias(synthesize(rhs, hint: type_env.lvar_types[var.name] || hint)) do |rhs_type| node_type = assign_type_to_variable(var, rhs_type, node) - typing.add_typing(node, node_type) + typing.add_typing(node, node_type, context) end else raise lhs_type = variable_type(var) if lhs_type - typing.add_typing(node, lhs_type) + typing.add_typing(node, lhs_type, context) else fallback_to_any node end @@ -1580,7 +1526,7 @@ def type_ivasgn(name, rhs, node) fallback_to_any node end end - typing.add_typing(node, ivar_type) + typing.add_typing(node, ivar_type, context) end def type_masgn(node) @@ -1600,7 +1546,7 @@ def type_masgn(node) end end - typing.add_typing(node, rhs_type) + typing.add_typing(node, rhs_type, context) when rhs_type.is_a?(AST::Types::Tuple) && lhs.children.all? {|a| a.type == :lvasgn || a.type == :ivasgn} lhs.children.each.with_index do |asgn, index| @@ -1628,7 +1574,7 @@ def type_masgn(node) end end - typing.add_typing(node, rhs_type) + typing.add_typing(node, rhs_type, context) when rhs_type.is_a?(AST::Types::Any) fallback_to_any(node) @@ -1658,7 +1604,7 @@ def type_masgn(node) end end - typing.add_typing node, rhs_type + typing.add_typing node, rhs_type, context when rhs_type.is_a?(AST::Types::Union) && rhs_type.types.all? {|type| AST::Builtin::Array.instance_type?(type)} @@ -1691,7 +1637,7 @@ def type_masgn(node) end end - typing.add_typing node, rhs_type + typing.add_typing node, rhs_type, context else Steep.logger.error("Unsupported masgn: #{rhs.type} (#{rhs_type})") @@ -1717,7 +1663,7 @@ def type_lambda(node, block_params:, block_body:, type_hint:) block_annotations: block_annotations, topdown_hint: true) - typing.add_typing node, block_type + typing.add_typing node, block_type, context end def type_send(node, send_node:, block_params:, block_body:, unwrap: false) @@ -1732,7 +1678,7 @@ def type_send(node, send_node:, block_params:, block_body:, unwrap: false) return_type = case receiver_type when AST::Types::Any - typing.add_typing node, AST::Builtin.any_type + typing.add_typing node, AST::Builtin.any_type, context when nil fallback_to_any node @@ -1768,7 +1714,7 @@ def type_send(node, send_node:, block_params:, block_body:, unwrap: false) receiver_type: receiver_type, topdown_hint: true) - typing.add_typing node, return_type + typing.add_typing node, return_type, context else fallback_to_any node do Errors::NoMethod.new(node: node, method: method_name, type: expanded_receiver_type) @@ -1793,7 +1739,7 @@ def type_send(node, send_node:, block_params:, block_body:, unwrap: false) unless typing.has_type?(arg) if arg.type == :splat type = synthesize(arg.children[0]) - typing.add_typing(arg, AST::Builtin::Array.instance_type(type)) + typing.add_typing(arg, AST::Builtin::Array.instance_type(type), context) else synthesize(arg) end @@ -1843,10 +1789,10 @@ def for_block(block_annotations:, param_pairs:, method_return_type:, typing:) end Steep.logger.debug("return_type = #{return_type}") - block_context = BlockContext.new(body_type: block_annotations.block_type) + block_context = TypeInference::Context::BlockContext.new(body_type: block_annotations.block_type) Steep.logger.debug("block_context { body_type: #{block_context.body_type} }") - break_context = BreakContext.new( + break_context = TypeInference::Context::BreakContext.new( break_type: block_annotations.break_type || method_return_type, next_type: block_annotations.block_type ) @@ -1856,13 +1802,15 @@ def for_block(block_annotations:, param_pairs:, method_return_type:, typing:) checker: checker, source: source, annotations: annotations.merge_block_annotations(block_annotations), - type_env: block_type_env, - block_context: block_context, typing: typing, - method_context: method_context, - module_context: module_context, - self_type: block_annotations.self_type || self_type, - break_context: break_context + context: TypeInference::Context.new( + block_context: block_context, + method_context: method_context, + module_context: module_context, + break_context: break_context, + self_type: block_annotations.self_type || self_type, + type_env: block_type_env + ) ), return_type] end @@ -1979,7 +1927,7 @@ def type_method_call(node, method_name:, receiver_type:, method:, args:, block_p this.synthesize(block_body) end - child_typing.add_typing node, AST::Builtin.any_type + child_typing.add_typing node, AST::Builtin.any_type, context [[AST::Builtin.any_type, child_typing, :any]] end @@ -2029,7 +1977,7 @@ def check_keyword_arg(receiver_type:, node:, method_type:, constraints:) when :hash keyword_hash_type = AST::Builtin::Hash.instance_type(AST::Builtin::Symbol.instance_type, AST::Builtin.any_type) - typing.add_typing node, keyword_hash_type + typing.add_typing node, keyword_hash_type, context given_keys = Set.new() @@ -2053,7 +2001,7 @@ def check_keyword_arg(receiver_type:, node:, method_type:, constraints:) params.rest_keywords end - typing.add_typing key_node, AST::Builtin::Symbol.instance_type + typing.add_typing key_node, AST::Builtin::Symbol.instance_type, context given_keys << key_symbol @@ -2159,13 +2107,8 @@ def try_method_type(node, receiver_type:, method_type:, args:, arg_pairs:, block checker: checker, source: source, annotations: annotations, - type_env: type_env, typing: child_typing, - self_type: self_type, - method_context: method_context, - block_context: block_context, - module_context: module_context, - break_context: break_context + context: context ) method_type.instantiate(instantiation).yield_self do |method_type| @@ -2182,7 +2125,7 @@ def try_method_type(node, receiver_type:, method_type:, args:, arg_pairs:, block arg_type = if arg_node.type == :splat type = construction.synthesize(arg_node.children[0]) - child_typing.add_typing(arg_node, type) + child_typing.add_typing(arg_node, type, context) else construction.synthesize(arg_node, hint: topdown_hint ? param_type : nil) end @@ -2364,10 +2307,10 @@ def type_block(block_param_hint:, block_type_hint:, node_type_hint:, block_param end Steep.logger.debug("return_type = #{break_type}") - block_context = BlockContext.new(body_type: block_annotations.block_type) + block_context = TypeInference::Context::BlockContext.new(body_type: block_annotations.block_type) Steep.logger.debug("block_context { body_type: #{block_context.body_type} }") - break_context = BreakContext.new( + break_context = TypeInference::Context::BreakContext.new( break_type: break_type, next_type: block_context.body_type ) @@ -2377,13 +2320,15 @@ def type_block(block_param_hint:, block_type_hint:, node_type_hint:, block_param checker: checker, source: source, annotations: annotations.merge_block_annotations(block_annotations), - type_env: block_type_env, - block_context: block_context, typing: typing, - method_context: method_context, - module_context: module_context, - self_type: block_annotations.self_type || self_type, - break_context: break_context + context: TypeInference::Context.new( + block_context: block_context, + method_context: method_context, + module_context: module_context, + break_context: break_context, + self_type: block_annotations.self_type || self_type, + type_env: block_type_env + ) ) if block_body @@ -2600,7 +2545,7 @@ def fallback_to_any(node) typing.add_error Errors::FallbackAny.new(node: node) end - typing.add_typing node, AST::Builtin.any_type + typing.add_typing node, AST::Builtin.any_type, context end def self_class?(node) @@ -2774,7 +2719,7 @@ def try_hash_type(node, hint) child_typing.save! hash = AST::Types::Record.new(elements: elements) - typing.add_typing(node, hash) + typing.add_typing(node, hash, context) end when AST::Types::Union hint.types.find do |type| diff --git a/lib/steep/type_inference/context.rb b/lib/steep/type_inference/context.rb new file mode 100644 index 000000000..eecfe236a --- /dev/null +++ b/lib/steep/type_inference/context.rb @@ -0,0 +1,107 @@ +module Steep + module TypeInference + class Context + class MethodContext + attr_reader :name + attr_reader :method + attr_reader :method_type + attr_reader :return_type + attr_reader :constructor + attr_reader :super_method + + def initialize(name:, method:, method_type:, return_type:, constructor:, super_method:) + @name = name + @method = method + @return_type = return_type + @method_type = method_type + @constructor = constructor + @super_method = super_method + end + + def block_type + method_type&.block + end + end + + class BlockContext + attr_reader :body_type + + def initialize(body_type:) + @body_type = body_type + end + end + + class BreakContext + attr_reader :break_type + attr_reader :next_type + + def initialize(break_type:, next_type:) + @break_type = break_type + @next_type = next_type + end + end + + class ModuleContext + attr_reader :instance_type + attr_reader :module_type + attr_reader :defined_instance_methods + attr_reader :defined_module_methods + attr_reader :const_env + attr_reader :implement_name + attr_reader :current_namespace + attr_reader :class_name + attr_reader :instance_definition + attr_reader :module_definition + + def initialize(instance_type:, module_type:, implement_name:, current_namespace:, const_env:, class_name:, instance_definition: nil, module_definition: nil) + @instance_type = instance_type + @module_type = module_type + @defined_instance_methods = Set.new + @defined_module_methods = Set.new + @implement_name = implement_name + @current_namespace = current_namespace + @const_env = const_env + @class_name = class_name + @instance_definition = instance_definition + @module_definition = module_definition + end + + def const_context + const_env.context + end + end + + attr_reader :method_context + attr_reader :block_context + attr_reader :break_context + attr_reader :module_context + attr_reader :self_type + attr_reader :type_env + + def initialize(method_context:, block_context:, break_context:, module_context:, self_type:, type_env:) + @method_context = method_context + @block_context = block_context + @break_context = break_context + @module_context = module_context + @self_type = self_type + @type_env = type_env + end + + def with(method_context: self.method_context, + block_context: self.block_context, + break_context: self.break_context, + module_context: self.module_context, + self_type: self.self_type, + type_env: self.type_env) + self.class.new( + method_context: method_context, + block_context: block_context, + break_context: break_context, + module_context: module_context, + self_type: self_type, + type_env: type_env + ) + end + end + end +end diff --git a/lib/steep/typing.rb b/lib/steep/typing.rb index f41c0e53b..7c9eae1f1 100644 --- a/lib/steep/typing.rb +++ b/lib/steep/typing.rb @@ -2,12 +2,11 @@ module Steep class Typing attr_reader :errors attr_reader :typing - attr_reader :nodes - attr_reader :var_typing attr_reader :parent attr_reader :parent_last_update attr_reader :last_update attr_reader :should_update + attr_reader :contexts def initialize(parent: nil, parent_last_update: parent&.last_update) @parent = parent @@ -16,18 +15,17 @@ def initialize(parent: nil, parent_last_update: parent&.last_update) @should_update = false @errors = [] - @nodes = {} - @var_typing = {} - @typing = {} + @typing = {}.compare_by_identity + @contexts = {}.compare_by_identity end def add_error(error) errors << error end - def add_typing(node, type) - typing[node.__id__] = type - nodes[node.__id__] = node + def add_typing(node, type, context) + typing[node] = type + contexts[node] = context if should_update @last_update += 1 @@ -38,11 +36,11 @@ def add_typing(node, type) end def has_type?(node) - typing.key?(node.__id__) + typing.key?(node) end def type_of(node:) - type = typing[node.__id__] + type = typing[node] if type type @@ -55,6 +53,20 @@ def type_of(node:) end end + def context_of(node:) + ctx = contexts[node] + + if ctx + ctx + else + if parent + parent.context_of(node: node) + else + raise "Unknown node for context: #{node.inspect}" + end + end + end + def dump(io) io.puts "Typing: " nodes.each_value do |node| @@ -86,10 +98,8 @@ def new_child end end - def each_typing - nodes.each do |id, node| - yield node, typing[id] - end + def each_typing(&block) + typing.each(&block) end def save! @@ -97,7 +107,7 @@ def save! raise "Parent modified since new_child" unless parent.last_update == parent_last_update each_typing do |node, type| - parent.add_typing(node, type) + parent.add_typing(node, type, contexts[node]) end errors.each do |error| diff --git a/test/langserver_test.rb b/test/langserver_test.rb index c10abffcc..13e1a0160 100644 --- a/test/langserver_test.rb +++ b/test/langserver_test.rb @@ -152,4 +152,160 @@ def foo end end end + + def test_hover + in_tmpdir do + path = current_dir.realpath + + (path + "Steepfile").write < String + | (String x, Symbol y) -> String +end +RBS + (path+"lib").mkdir + (path+"lib/example.rb").write < ::String +``` + +---- + +Returns the name or string corresponding to *sym*. + + :fred.id2name #=> \"fred\" + :ginger.to_s #=> \"ginger\" + + +---- + +- `() -> ::String` +MSG + assert_equal({ start: { line: 2, character: 4 }, end: { line: 2, character: 10 }}, response[:result][:range]) + end + + lsp.send_request( + method: "textDocument/hover", + params: { + textDocument: { + uri: "file://#{path}/lib/example.rb" + }, + position: { + line: 1, + character: 3 + } + } + ) do |response| + assert_equal < ::String +``` + +---- + +foo method processes given argument. + + +---- + +- `(::Integer x) -> ::String` +- `(::String x, ::Symbol y) -> ::String` +EOF + assert_equal({ start: { line: 1, character: 2 }, end: { line: 3, character: 5 }}, response[:result][:range]) + end + end + end + end + end end diff --git a/test/project_test.rb b/test/project_test.rb index 84aeff2f5..a1c01475f 100644 --- a/test/project_test.rb +++ b/test/project_test.rb @@ -5,6 +5,7 @@ class ProjectTest < Minitest::Test include ShellHelper include Steep + HoverContent = Steep::Project::HoverContent def dirs @dirs ||= [] @@ -41,4 +42,151 @@ class Foo assert_equal Set[Pathname("sig/foo.rbs")], Set.new(target.signature_files.keys) end end + + def test_hover_content + in_tmpdir do + project = Project.new(base_dir: current_dir) + Project::DSL.parse(project, < String + | (String x) -> String +end + EOF + + target.type_check + + hover = Project::HoverContent.new(project: project) + + hover.content_for(path: Pathname("hello.rb"), line: 2, column: 10).tap do |content| + assert_instance_of HoverContent::DefinitionContent, content + assert_equal [2,2]...[4, 5], [content.location.line,content.location.column]...[content.location.last_line, content.location.last_column] + assert_equal :do_something, content.method_name + assert_equal "((::Integer | ::String)) -> ::String", content.method_type.to_s + assert_equal ["(::Integer x) -> ::String", "(::String x) -> ::String"], content.definition.method_types.map(&:to_s) + assert_instance_of Ruby::Signature::Definition::Method, content.definition + assert_equal "Do something super for given argument `x`.\n", content.definition.comment.string + end + end + end end diff --git a/test/source_test.rb b/test/source_test.rb index 819a2182a..6df74fc1d 100644 --- a/test/source_test.rb +++ b/test/source_test.rb @@ -413,17 +413,20 @@ def self.foo(bar) end EOF - assert_equal source.node, source.find_node(line: 1, column: 2) # class - assert_equal dig(source.node, 0), source.find_node(line: 1, column: 6) # A - assert_equal dig(source.node, 0), source.find_node(line: 1, column: 7) # A - assert_equal dig(source.node, 2, 0), source.find_node(line: 2, column: 6) # self - assert_equal dig(source.node, 2), source.find_node(line: 2, column: 11) # def - assert_equal dig(source.node, 2, 2, 0), source.find_node(line: 2, column: 15) # bar - assert_equal dig(source.node, 2, 3), source.find_node(line: 4, column: 5) # x - assert_equal dig(source.node, 2, 3, 1), source.find_node(line: 4, column: 8) # 123 - assert_equal dig(source.node, 2, 3, 1), source.find_node(line: 4, column: 9) # 123 - assert_equal dig(source.node, 2, 3, 1), source.find_node(line: 4, column: 10) # 123 - assert_equal dig(source.node, 2, 3, 1), source.find_node(line: 4, column: 11) # 123 + assert_equal [source.node], + source.find_nodes(line: 1, column: 2) # class + assert_equal [dig(source.node, 0), source.node], + source.find_nodes(line: 1, column: 6) # A + assert_equal [dig(source.node, 0), source.node], + source.find_nodes(line: 1, column: 7) # A + assert_equal [dig(source.node, 2, 0), dig(source.node, 2), source.node], + source.find_nodes(line: 2, column: 6) # self + assert_equal [dig(source.node, 2), source.node], + source.find_nodes(line: 2, column: 11) # def + assert_equal [dig(source.node, 2, 2, 0), dig(source.node, 2, 2), dig(source.node, 2), source.node], + source.find_nodes(line: 2, column: 15) # bar + assert_equal [dig(source.node, 2, 3), dig(source.node, 2), source.node], + source.find_nodes(line: 4, column: 5) # x end end end diff --git a/test/type_construction_test.rb b/test/type_construction_test.rb index ab2025199..c60a5e2a8 100644 --- a/test/type_construction_test.rb +++ b/test/type_construction_test.rb @@ -13,6 +13,7 @@ class TypeConstructionTest < Minitest::Test TypeConstruction = Steep::TypeConstruction Annotation = Steep::AST::Annotation Names = Steep::Names + Context = Steep::TypeInference::Context DEFAULT_SIGS = <<-EOS interface _A @@ -70,13 +71,15 @@ def with_standard_construction(checker, source) construction = TypeConstruction.new(checker: checker, source: source, annotations: annotations, - type_env: type_env, - self_type: parse_type("::Object"), - block_context: nil, - method_context: nil, - typing: typing, - module_context: nil, - break_context: nil) + context: Context.new( + block_context: nil, + method_context: nil, + module_context: nil, + break_context: nil, + self_type: parse_type("::Object"), + type_env: type_env + ), + typing: typing) yield construction, typing end @@ -935,7 +938,7 @@ class Steep::Names::Module end const_env: const_env, signatures: checker.factory.env) - module_context = TypeConstruction::ModuleContext.new( + module_context = Context::ModuleContext.new( instance_type: parse_type("::Steep"), module_type: parse_type("singleton(::Steep)"), implement_name: nil, @@ -949,13 +952,15 @@ class Steep::Names::Module end construction = TypeConstruction.new(checker: checker, source: source, annotations: annotations, - type_env: type_env, - self_type: nil, - block_context: nil, - method_context: nil, - typing: typing, - module_context: module_context, - break_context: nil) + context: Context.new( + block_context: nil, + method_context: nil, + module_context: module_context, + break_context: nil, + self_type: nil, + type_env: type_env + ), + typing: typing) for_module = construction.for_class(module_name_class_node) @@ -1024,7 +1029,7 @@ module Steep::Printable end const_env: const_env, signatures: checker.factory.env) - module_context = TypeConstruction::ModuleContext.new( + module_context = Context::ModuleContext.new( instance_type: parse_type("::Steep"), module_type: parse_type("singleton(::Steep)"), implement_name: nil, @@ -1038,13 +1043,15 @@ module Steep::Printable end construction = TypeConstruction.new(checker: checker, source: source, annotations: annotations, - type_env: type_env, - self_type: nil, - block_context: nil, - method_context: nil, - typing: typing, - module_context: module_context, - break_context: nil) + context: Context.new( + block_context: nil, + method_context: nil, + module_context: module_context, + break_context: nil, + self_type: nil, + type_env: type_env + ), + typing: typing) for_module = construction.for_module(module_node) @@ -4007,4 +4014,62 @@ def get: -> A end end end + + def test_context_toplevel + with_checker <<-EOF do |checker| + EOF + source = parse_ruby(<<-EOF) +a = "Hello" +b = 123 + EOF + + with_standard_construction(checker, source) do |construction, typing| + construction.synthesize(source.node) + assert_empty typing.errors + + # a = ... + typing.context_of(node: dig(source.node, 0)).tap do |ctx| + assert_instance_of Context, ctx + assert_nil ctx.module_context + assert_nil ctx.method_context + assert_nil ctx.block_context + assert_nil ctx.break_context + assert_equal parse_type("::Object"), ctx.self_type + end + end + end + end + + def test_context_class + with_checker <<-EOF do |checker| +class Hello +end + EOF + source = parse_ruby(<<-EOF) +class Hello < Object + a = "foo" + b = :bar +end + +b = 123 + EOF + + with_standard_construction(checker, source) do |construction, typing| + construction.synthesize(source.node) + assert_empty typing.errors + + # class Hello + typing.context_of(node: dig(source.node, 0, 2)).tap do |ctx| + assert_instance_of Context, ctx + assert_equal "::Hello", ctx.module_context.class_name.to_s + assert_nil ctx.method_context + assert_nil ctx.block_context + assert_nil ctx.break_context + assert_equal parse_type("singleton(::Hello)"), ctx.self_type + assert_equal parse_type("::String"), ctx.type_env.get(lvar: :a) + assert_equal parse_type("::Symbol"), ctx.type_env.get(lvar: :b) + end + end + end + end end diff --git a/test/typing_test.rb b/test/typing_test.rb index 95bf18ca1..f22d2997b 100644 --- a/test/typing_test.rb +++ b/test/typing_test.rb @@ -2,6 +2,8 @@ class TypingTest < Minitest::Test Typing = Steep::Typing + TypeEnv = Steep::TypeInference::TypeEnv + Context = Steep::TypeInference::Context include TestHelper include FactoryHelper @@ -13,15 +15,25 @@ def around end end + def context + @context ||= Context.new(method_context: nil, + block_context: nil, + break_context: nil, + module_context: nil, + self_type: parse_type("::Object"), + type_env: nil) + end + def test_1 typing = Steep::Typing.new node = parse_ruby("123").node type = parse_method_type("() -> String").return_type - typing.add_typing(node, type) + typing.add_typing(node, type, context) assert_equal type, typing.type_of(node: node) + assert_equal context, typing.context_of(node: node) end def test_new_child_with_save @@ -30,13 +42,13 @@ def test_new_child_with_save node = parse_ruby("123 + 456").node type = parse_method_type("() -> String").return_type - typing.add_typing(node, type) + typing.add_typing(node, type, context) typing.new_child do |typing_| assert_equal type, typing.type_of(node: node) - typing_.add_typing(node.children[0], type) - typing_.add_typing(node.children[1], type) + typing_.add_typing(node.children[0], type, context) + typing_.add_typing(node.children[1], type, context) typing_.save! end @@ -52,13 +64,13 @@ def test_new_child_without_save node = parse_ruby("123 + 456").node type = parse_method_type("() -> String").return_type - typing.add_typing(node, type) + typing.add_typing(node, type, context) typing.new_child do |typing_| assert_equal type, typing.type_of(node: node) - typing_.add_typing(node.children[0], type) - typing_.add_typing(node.children[1], type) + typing_.add_typing(node.children[0], type, context) + typing_.add_typing(node.children[1], type, context) end assert_equal type, typing.type_of(node: node) @@ -72,12 +84,12 @@ def test_new_child_check node = parse_ruby("123 + 456").node type = parse_method_type("() -> String").return_type - typing.add_typing(node, type) + typing.add_typing(node, type, context) child1 = typing.new_child() - child1.add_typing(node.children[0], type) + child1.add_typing(node.children[0], type, context) - typing.add_typing(node.children[1], type) + typing.add_typing(node.children[1], type, context) assert_raises do child1.save! @@ -91,10 +103,10 @@ def test_new_child_check2 type = parse_method_type("() -> String").return_type child1 = typing.new_child() - child1.add_typing(node.children[0], type) + child1.add_typing(node.children[0], type, context) child2 = typing.new_child() - child2.add_typing(node.children[1], type) + child2.add_typing(node.children[1], type, context) child1.save!