Skip to content

Commit

Permalink
Merge pull request #277 from soutaro/type-case
Browse files Browse the repository at this point in the history
Type-case based on literal value
  • Loading branch information
soutaro authored Dec 23, 2020
2 parents d2e3b2a + 202392b commit 978199c
Show file tree
Hide file tree
Showing 7 changed files with 185 additions and 20 deletions.
8 changes: 8 additions & 0 deletions lib/steep/ast/types/factory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,14 @@ def setup_primitives(method_name, method_def, method_type)
return_type: AST::Types::Logic::ArgIsReceiver.new(location: method_type.type.return_type.location)
)
)
when RBS::BuiltinNames::Object.name, RBS::BuiltinNames::String.name, RBS::BuiltinNames::Integer.name, RBS::BuiltinNames::Symbol.name,
RBS::BuiltinNames::TrueClass.name, RBS::BuiltinNames::FalseClass.name, TypeName("::NilClass")
# Value based type-case works on literal types which is available for String, Integer, Symbol, TrueClass, FalseClass, and NilClass
return method_type.with(
type: method_type.type.with(
return_type: AST::Types::Logic::ArgEqualsReceiver.new(location: method_type.type.return_type.location)
)
)
end
end
end
Expand Down
6 changes: 6 additions & 0 deletions lib/steep/ast/types/logic.rb
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ def initialize(location: nil)
end
end

class ArgEqualsReceiver < Base
def initialize(location: nil)
@location = location
end
end

class Env < Base
attr_reader :truthy, :falsy, :type

Expand Down
1 change: 1 addition & 0 deletions lib/steep/type_construction.rb
Original file line number Diff line number Diff line change
Expand Up @@ -1685,6 +1685,7 @@ def synthesize(node, hint: nil, condition: false)
falsy_env = cond_vars.inject(falsy_env) do |env, var|
env.assign!(var, node: test_node, type: env[first_var])
end

test_envs << truthy_env
test_constr = test_constr.update_lvar_env { falsy_env }
end
Expand Down
103 changes: 84 additions & 19 deletions lib/steep/type_inference/logic_type_interpreter.rb
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def eval(env:, type:, node:)

if receiver
_, arg_vars = decompose_value(arg)
receiver_type = typing.type_of(node: receiver)
receiver_type = factory.deep_expand_alias(typing.type_of(node: receiver))

if receiver_type.is_a?(AST::Types::Name::Singleton)
arg_vars.each do |var_name|
Expand All @@ -109,6 +109,23 @@ def eval(env:, type:, node:)
end
end
end
when AST::Types::Logic::ArgEqualsReceiver
case value_node.type
when :send
receiver, _, arg = value_node.children

if receiver
_, arg_vars = decompose_value(arg)

arg_vars.each do |var_name|
var_type = factory.deep_expand_alias(env[var_name])
truthy_types, falsy_types = literal_var_type_case_select(receiver, var_type)

truthy_env = truthy_env.assign!(var_name, node: node, type: AST::Types::Union.build(types: truthy_types, location: nil))
falsy_env = falsy_env.assign!(var_name, node: node, type: AST::Types::Union.build(types: falsy_types, location: nil))
end
end
end
when AST::Types::Logic::Not
receiver, * = value_node.children
receiver_type = typing.type_of(node: receiver)
Expand Down Expand Up @@ -159,6 +176,60 @@ def decompose_value(node)
end
end

def literal_var_type_case_select(value_node, arg_type)
case arg_type
when AST::Types::Union
truthy_types = []
falsy_types = []

arg_type.types.each do |type|
ts, fs = literal_var_type_case_select(value_node, type)
truthy_types.push(*ts)
falsy_types.push(*fs)
end

[truthy_types, falsy_types]
else
value_type = typing.type_of(node: value_node)
types = [arg_type]

case value_node.type
when :nil
types.partition do |type|
type.is_a?(AST::Types::Nil) || AST::Builtin::NilClass.instance_type?(type)
end
when :true
types.partition do |type|
AST::Builtin::TrueClass.instance_type?(type) ||
(type.is_a?(AST::Types::Literal) && type.value == true)
end
when :false
types.partition do |type|
AST::Builtin::FalseClass.instance_type?(type) ||
(type.is_a?(AST::Types::Literal) && type.value == false)
end
when :int, :str, :sym
types.each.with_object([[], []]) do |type, pair|
true_types, false_types = pair

case
when type.is_a?(AST::Types::Literal)
if type.value == value_node.children[0]
true_types << type
else
false_types << type
end
else
true_types << AST::Types::Literal.new(value: value_node.children[0])
false_types << type
end
end
else
[[arg_type], [arg_type]]
end
end
end

def type_case_select(type, klass)
truth_types, false_types = type_case_select0(type, klass)

Expand Down Expand Up @@ -189,20 +260,6 @@ def type_case_select0(type, klass)

[truthy_types, falsy_types]

when AST::Types::Name::Instance
relation = Subtyping::Relation.new(sub_type: type, super_type: instance_type)
if subtyping.check(relation, constraints: Subtyping::Constraints.empty, self_type: AST::Types::Self.new).success?
[
[type],
[]
]
else
[
[],
[type]
]
end

when AST::Types::Name::Alias
ty = factory.expand_alias(type)
type_case_select0(ty, klass)
Expand All @@ -220,10 +277,18 @@ def type_case_select0(type, klass)
]

else
[
[],
[type]
]
relation = Subtyping::Relation.new(sub_type: type, super_type: instance_type)
if subtyping.check(relation, constraints: Subtyping::Constraints.empty, self_type: AST::Types::Self.new).success?
[
[type],
[]
]
else
[
[],
[type]
]
end
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion smoke/type_case/a.rb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

case x
when 1
# !expects NoMethodError: type=(::Integer | ::String | ::Symbol), method=foobar
# !expects NoMethodError: type=1, method=foobar
x.foobar
end

Expand Down
3 changes: 3 additions & 0 deletions test/test_helper.rb
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def to_s: -> String
def nil?: -> bool
def itself: -> self
def is_a?: (Module) -> bool
def ===: (untyped) -> bool
private
def require: (String) -> void
Expand Down Expand Up @@ -246,6 +247,8 @@ def `-@`: -> String
class Numeric
def `+`: (Numeric) -> Numeric
def to_int: -> Integer
def zero?: () -> bool
end
class Integer < Numeric
Expand Down
82 changes: 82 additions & 0 deletions test/type_construction_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -6961,4 +6961,86 @@ def ==(other)
end
end
end

def test_value_case
with_checker(<<RBS) do |checker|
type allowed_key = :foo | :bar | nil | Integer
RBS
source = parse_ruby(<<RUBY)
# @type var x: allowed_key
x = nil
# @type var y: nil
# @type var z: Symbol
case x
when nil
y = x
when :foo
z = x
when Symbol
z = x
when Integer
x + 1
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
type, _ = construction.synthesize(source.node)
assert_no_error typing
end
end
end

def test_value_case2
with_checker(<<RBS) do |checker|
type allowed_key = :foo | :bar
RBS
source = parse_ruby(<<RUBY)
# @type var x: allowed_key
x = _ = nil
# @type var a: bool
a = case x
when :foo
true
when :bar
false
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
type, _ = construction.synthesize(source.node)
assert_no_error typing
end
end
end

def test_value_case3
with_checker(<<RBS) do |checker|
type allowed_key = Integer | String
RBS
source = parse_ruby(<<RUBY)
# @type var x: allowed_key
x = _ = nil
# @type var a: bool
a = case x
when 1
(x + 1).zero?
when "2"
(x + "").size.zero?
when Integer, String
false
end
RUBY

with_standard_construction(checker, source) do |construction, typing|
type, _ = construction.synthesize(source.node)
assert_no_error typing
end
end
end
end

0 comments on commit 978199c

Please sign in to comment.