Skip to content

Commit

Permalink
Merge pull request #182 from soutaro/improvements
Browse files Browse the repository at this point in the history
Support for and class variables
  • Loading branch information
soutaro authored Aug 15, 2020
2 parents 019cc7c + 8149d84 commit 6754152
Show file tree
Hide file tree
Showing 7 changed files with 269 additions and 12 deletions.
6 changes: 4 additions & 2 deletions lib/steep/project/target.rb
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,10 @@ def run_type_check(env, check, timestamp, target_sources: source_files.values)
type_check_sources = []

target_sources.each do |file|
if file.type_check(check, timestamp)
type_check_sources << file
Steep.logger.tagged("path=#{file.path}") do
if file.type_check(check, timestamp)
type_check_sources << file
end
end
end

Expand Down
108 changes: 102 additions & 6 deletions lib/steep/type_construction.rb
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,15 @@ def synthesize(node, hint: nil)

constr.add_typing(node, type: type)

when :cvasgn
var_node = lhs.updated(:cvar)
send_node = rhs.updated(:send, [var_node, op, rhs])
new_node = node.updated(:cvasgn, [lhs.children[0], send_node])

type, constr = synthesize(new_node, hint: hint)

constr.add_typing(node, type: type)

else
Steep.logger.error("Unexpected op_asgn lhs: #{lhs.type}")

Expand Down Expand Up @@ -1001,7 +1010,9 @@ def synthesize(node, hint: nil)
type = context.lvar_env[var.name]
unless type
type = AST::Builtin.any_type
Steep.logger.error { "Unknown arg type: #{node}" }
if context&.method_context&.method_type
Steep.logger.error { "Unknown arg type: #{node}" }
end
end
add_typing(node, type: type)
end
Expand Down Expand Up @@ -1034,7 +1045,9 @@ def synthesize(node, hint: nil)
var = node.children[0]
type = context.lvar_env[var.name]
unless type
Steep.logger.error { "Unknown variable: #{node}" }
if context&.method_context&.method_type
Steep.logger.error { "Unknown variable: #{node}" }
end
typing.add_error Errors::FallbackAny.new(node: node)
type = AST::Builtin::Array.instance_type(AST::Builtin.any_type)
end
Expand All @@ -1047,7 +1060,9 @@ def synthesize(node, hint: nil)
var = node.children[0]
type = context.lvar_env[var.name]
unless type
Steep.logger.error { "Unknown variable: #{node}" }
if context&.method_context&.method_type
Steep.logger.error { "Unknown variable: #{node}" }
end
typing.add_error Errors::FallbackAny.new(node: node)
type = AST::Builtin::Hash.instance_type(AST::Builtin::Symbol.instance_type, AST::Builtin.any_type)
end
Expand Down Expand Up @@ -1647,6 +1662,49 @@ def synthesize(node, hint: nil)
when :masgn
type_masgn(node)

when :for
yield_self do
asgn, collection, body = node.children

collection_type, constr = synthesize(collection)
collection_type = expand_self(collection_type)

var_type = case collection_type
when AST::Types::Any
AST::Types::Any.new
else
each = checker.factory.interface(collection_type, private: true).methods[:each]
method_type = (each&.types || []).find {|type| type.block && type.block.type.params.first_param }
method_type&.yield_self do |method_type|
method_type.block.type.params.first_param&.type
end
end

if var_type
if body
body_constr = constr.with_updated_context(
lvar_env: constr.context.lvar_env.assign(asgn.children[0].name, node: asgn, type: var_type)
)

typing.add_context_for_body(node, context: body_constr.context)
_, _, body_context = body_constr.synthesize(body)

constr = constr.update_lvar_env {|env| env.join(constr.context.lvar_env, body_context.lvar_env) }
else
constr = self
end

add_typing(node, type: collection_type, constr: constr)
else
fallback_to_any(node) do
Errors::NoMethod.new(
node: node,
method: :each,
type: collection_type
)
end
end
end
when :while, :until
yield_self do
cond, body = node.children
Expand Down Expand Up @@ -1806,9 +1864,47 @@ def synthesize(node, hint: nil)
add_typing node, type: AST::Builtin.any_type
end

when :cvasgn
name, rhs = node.children

type, constr = synthesize(rhs, hint: hint)

var_type = if module_context&.class_variables
module_context.class_variables[name]&.yield_self {|ty| checker.factory.type(ty) }
end

if var_type
result = constr.check_relation(sub_type: type, super_type: var_type)

if result.success?
add_typing node, type: type, constr: constr
else
fallback_to_any node do
Errors::IncompatibleAssignment.new(node: node,
lhs_type: var_type,
rhs_type: type,
result: result)
end
end
else
fallback_to_any(node)
end

when :cvar
name = node.children[0]
var_type = if module_context&.class_variables
module_context.class_variables[name]&.yield_self {|ty| checker.factory.type(ty) }
end

if var_type
add_typing node, type: var_type
else
fallback_to_any node
end

when :splat, :sclass, :alias
yield_self do
Steep.logger.error "Unsupported node #{node.type} (#{node.location.expression.source_buffer.name}:#{node.location.expression.line})"
Steep.logger.warn { "Unsupported node #{node.type} (#{node.location.expression.source_buffer.name}:#{node.location.expression.line})" }

each_child_node node do |child|
synthesize(child)
Expand Down Expand Up @@ -1885,7 +1981,7 @@ def type_masgn(node)
end
add_typing(lhs,
type: type,
constr: ctr.with_updated_context(lvar_env: env))
constr: ctr.with_updated_context(lvar_env: env)).constr
when :ivasgn
type_ivasgn(lhs.children.first, rhs, lhs)
constr
Expand Down Expand Up @@ -2038,7 +2134,7 @@ def type_send(node, send_node:, block_params:, block_body:, unwrap: false)
else
case expanded_receiver_type = expand_self(receiver_type)
when AST::Types::Self
Steep.logger.error "`self` type cannot be resolved to concrete type"
Steep.logger.debug { "`self` type cannot be resolved to concrete type" }
fallback_to_any node do
Errors::NoMethod.new(node: node, method: method_name, type: receiver_type)
end
Expand Down
2 changes: 1 addition & 1 deletion lib/steep/type_inference/constant_env.rb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def lookup(name)
factory.type(constant.type)
end
rescue => exn
Steep.logger.error "Looking up a constant failed: name=#{name}, context=[#{context.join(", ")}], error=#{exn.inspect}"
Steep.logger.debug "Looking up a constant failed: name=#{name}, context=[#{context.join(", ")}], error=#{exn.inspect}"
nil
end
end
Expand Down
8 changes: 8 additions & 0 deletions lib/steep/type_inference/context.rb
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ def initialize(instance_type:, module_type:, implement_name:, current_namespace:
def const_context
const_env.context
end

def class_variables
if module_definition
@class_variables ||= module_definition.class_variables.transform_values do |var_def|
var_def.type
end
end
end
end

attr_reader :method_context
Expand Down
7 changes: 4 additions & 3 deletions lib/steep/type_inference/context_array.rb
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ def self.from_source(source:, range: nil, context: nil)

def insert_context(range, context:, entry: self.root)
entry.sub_entries.each do |sub|
next if sub.range.begin < range.begin && range.end <= sub.range.end
next if range.begin < sub.range.begin && sub.range.end <= range.end
next if sub.range.begin <= range.begin && range.end <= sub.range.end
next if range.begin <= sub.range.begin && sub.range.end <= range.end
next if range.end <= sub.range.begin
next if sub.range.end <= range.begin

raise "Range crossing: sub range=#{sub.range}, new range=#{range}"
Steep.logger.error { "Range crossing: sub range=#{sub.range}, new range=#{range}" }
raise
end

sup = entry.sub_entries.find do |sub|
Expand Down
7 changes: 7 additions & 0 deletions lib/steep/typing.rb
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,13 @@ def add_context_for_body(node, context:)
end_pos = node.loc.end.begin_pos
add_context(begin_pos..end_pos, context: context)

when :for
_, collection, _ = node.children

begin_pos = collection.loc.expression.end_pos
end_pos = node.loc.end.begin_pos

add_context(begin_pos..end_pos, context: context)
else
raise "Unexpected node for insert_context: #{node.type}"
end
Expand Down
143 changes: 143 additions & 0 deletions test/type_construction_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,63 @@ def test_post_loop
end
end

def test_for_0
with_checker <<-'RBS' do |checker|
RBS
source = parse_ruby(<<-'RUBY')
for x in [1,2,3]
y = x + 1
end
puts y.to_s
RUBY

with_standard_construction(checker, source) do |construction, typing|
_, _, context = construction.synthesize(source.node)

assert_no_error typing

assert_equal parse_type("::Integer?"), context.lvar_env[:y]
end
end
end

def test_for_1
with_checker <<-'RBS' do |checker|
RBS
source = parse_ruby(<<-'RUBY')
for x in [1]
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
_, _, context = construction.synthesize(source.node)

assert_no_error typing
end
end
end

def test_for_2
with_checker <<-'RBS' do |checker|
RBS
source = parse_ruby(<<-'RUBY')
for x in self
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
_, _, context = construction.synthesize(source.node)

assert_equal 1, typing.errors.size

assert_any!(typing.errors) do |error|
assert_instance_of Steep::Errors::NoMethod, error
end
end
end
end

def test_range
with_checker do |checker|
source = parse_ruby(<<-EOF)
Expand Down Expand Up @@ -4830,4 +4887,90 @@ def test_tuple_typing
end
end
end

def test_class_variables
with_checker <<-'RBS' do |checker|
class Object
def ==: (untyped) -> bool
end
class TypeVariable
@@index: Integer
attr_reader name: String
def initialize: (String name) -> void
def self.fresh: () -> instance
def last?: () -> bool
end
RBS
source = parse_ruby(<<-'RUBY')
class TypeVariable
@@index = 0
def name
@name
end
def initialize(name)
@name = name
end
def last?
name == "#{@@index}"
end
def self.fresh
@@index += 1
new("#{@@index}")
end
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
construction.synthesize(source.node)

assert_no_error typing
end
end
end

def test_class_variables_error
with_checker <<-'RBS' do |checker|
class TypeVariable
@@index: Integer
end
RBS
source = parse_ruby(<<-'RUBY')
class TypeVariable
@@no_error = @@unknown_error2
@@index = ""
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
construction.synthesize(source.node)

assert_equal 3, typing.errors.size

assert_any!(typing.errors) do |error|
assert_instance_of Steep::Errors::FallbackAny, error
assert_equal :cvasgn, error.node.type
end

assert_any!(typing.errors) do |error|
assert_instance_of Steep::Errors::FallbackAny, error
assert_equal :cvar, error.node.type
end

assert_any!(typing.errors) do |error|
assert_instance_of Steep::Errors::IncompatibleAssignment, error
assert_equal :cvasgn, error.node.type
end
end
end
end
end

0 comments on commit 6754152

Please sign in to comment.