Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better union typing #204

Merged
merged 3 commits into from
Sep 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 50 additions & 44 deletions lib/steep/ast/types/factory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,10 @@ def interface(type, private:, self_type: type)
definition.methods.each do |name, method|
next if method.private? && !private

interface.methods[name] = Interface::Interface::Combination.overload(
method.method_types.map do |type|
interface.methods[name] = Interface::Interface::Entry.new(
method_types: method.method_types.map do |type|
method_type(type, self_type: self_type, subst2: subst)
end,
incompatible: name == :initialize || name == :new
end
)
end
end
Expand All @@ -363,11 +362,10 @@ def interface(type, private:, self_type: type)
)

definition.methods.each do |name, method|
interface.methods[name] = Interface::Interface::Combination.overload(
method.method_types.map do |type|
interface.methods[name] = Interface::Interface::Entry.new(
method_types: method.method_types.map do |type|
method_type(type, self_type: self_type, subst2: subst)
end,
incompatible: false
end
)
end
end
Expand All @@ -389,11 +387,10 @@ def interface(type, private:, self_type: type)
definition.methods.each do |name, method|
next if !private && method.private?

interface.methods[name] = Interface::Interface::Combination.overload(
method.method_types.map do |type|
interface.methods[name] = Interface::Interface::Entry.new(
method_types: method.method_types.map do |type|
method_type(type, self_type: self_type, subst2: subst)
end,
incompatible: false
end
)
end
end
Expand All @@ -416,7 +413,25 @@ def interface(type, private:, self_type: type)
Interface::Interface.new(type: self_type, private: private).tap do |interface|
common_methods = Set.new(interface1.methods.keys) & Set.new(interface2.methods.keys)
common_methods.each do |name|
interface.methods[name] = Interface::Interface::Combination.union([interface1.methods[name], interface2.methods[name]])
types1 = interface1.methods[name].method_types
types2 = interface2.methods[name].method_types

if types1 == types2
interface.methods[name] = interface1.methods[name]
else
method_types = {}

types1.each do |type1|
types2.each do |type2|
type = type1 | type2 or next
method_types[type] = true
end
end

unless method_types.empty?
interface.methods[name] = Interface::Interface::Entry.new(method_types: method_types.keys)
end
end
end
end
end
Expand All @@ -427,11 +442,8 @@ def interface(type, private:, self_type: type)
interfaces = type.types.map {|ty| interface(ty, private: private, self_type: self_type) }
interfaces.inject do |interface1, interface2|
Interface::Interface.new(type: self_type, private: private).tap do |interface|
all_methods = Set.new(interface1.methods.keys) + Set.new(interface2.methods.keys)
all_methods.each do |name|
methods = [interface1.methods[name], interface2.methods[name]].compact
interface.methods[name] = Interface::Interface::Combination.intersection(methods)
end
interface.methods.merge!(interface1.methods)
interface.methods.merge!(interface2.methods)
end
end
end
Expand All @@ -442,8 +454,8 @@ def interface(type, private:, self_type: type)
array_type = Builtin::Array.instance_type(element_type)
interface(array_type, private: private, self_type: self_type).tap do |array_interface|
array_interface.methods[:[]] = array_interface.methods[:[]].yield_self do |aref|
Interface::Interface::Combination.overload(
type.types.map.with_index {|elem_type, index|
Interface::Interface::Entry.new(
method_types: type.types.map.with_index {|elem_type, index|
Interface::MethodType.new(
type_params: [],
params: Interface::Params.new(required: [AST::Types::Literal.new(value: index)],
Expand All @@ -456,14 +468,13 @@ def interface(type, private:, self_type: type)
return_type: elem_type,
location: nil
)
} + aref.types,
incompatible: false
} + aref.method_types
)
end

array_interface.methods[:[]=] = array_interface.methods[:[]=].yield_self do |update|
Interface::Interface::Combination.overload(
type.types.map.with_index {|elem_type, index|
Interface::Interface::Entry.new(
method_types: type.types.map.with_index {|elem_type, index|
Interface::MethodType.new(
type_params: [],
params: Interface::Params.new(required: [AST::Types::Literal.new(value: index), elem_type],
Expand All @@ -476,38 +487,35 @@ def interface(type, private:, self_type: type)
return_type: elem_type,
location: nil
)
} + update.types,
incompatible: false
} + update.method_types
)
end

array_interface.methods[:first] = array_interface.methods[:first].yield_self do |first|
Interface::Interface::Combination.overload(
[
Interface::Interface::Entry.new(
method_types: [
Interface::MethodType.new(
type_params: [],
params: Interface::Params.empty,
block: nil,
return_type: type.types[0] || AST::Builtin.nil_type,
location: nil
)
],
incompatible: false
]
)
end

array_interface.methods[:last] = array_interface.methods[:last].yield_self do |last|
Interface::Interface::Combination.overload(
[
Interface::Interface::Entry.new(
method_types: [
Interface::MethodType.new(
type_params: [],
params: Interface::Params.empty,
block: nil,
return_type: type.types.last || AST::Builtin.nil_type,
location: nil
)
],
incompatible: false
]
)
end
end
Expand All @@ -523,8 +531,8 @@ def interface(type, private:, self_type: type)

interface(hash_type, private: private, self_type: self_type).tap do |hash_interface|
hash_interface.methods[:[]] = hash_interface.methods[:[]].yield_self do |ref|
Interface::Interface::Combination.overload(
type.elements.map {|key_value, value_type|
Interface::Interface::Entry.new(
method_types: type.elements.map {|key_value, value_type|
key_type = Literal.new(value: key_value, location: nil)
Interface::MethodType.new(
type_params: [],
Expand All @@ -538,14 +546,13 @@ def interface(type, private:, self_type: type)
return_type: value_type,
location: nil
)
} + ref.types,
incompatible: false
} + ref.method_types
)
end

hash_interface.methods[:[]=] = hash_interface.methods[:[]=].yield_self do |update|
Interface::Interface::Combination.overload(
type.elements.map {|key_value, value_type|
Interface::Interface::Entry.new(
method_types: type.elements.map {|key_value, value_type|
key_type = Literal.new(value: key_value, location: nil)
Interface::MethodType.new(
type_params: [],
Expand All @@ -559,8 +566,7 @@ def interface(type, private:, self_type: type)
return_type: value_type,
location: nil
)
} + update.types,
incompatible: false
} + update.method_types
)
end
end
Expand All @@ -576,8 +582,8 @@ def interface(type, private:, self_type: type)
location: nil
)

interface.methods[:[]] = Interface::Interface::Combination.overload([method_type], incompatible: false)
interface.methods[:call] = Interface::Interface::Combination.overload([method_type], incompatible: false)
interface.methods[:[]] = Interface::Interface::Entry.new(method_types: [method_type])
interface.methods[:call] = Interface::Interface::Entry.new(method_types: [method_type])
end

else
Expand Down
21 changes: 12 additions & 9 deletions lib/steep/ast/types/intersection.rb
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,36 @@ def self.build(types:, location: nil)
else
type
end
end.compact.uniq.yield_self do |tys|
if tys.size == 1
end.compact.yield_self do |tys|
dups = Set.new(tys)

case dups.size
when 0
AST::Types::Top.new(location: location)
when 1
tys.first
else
new(types: tys.sort_by(&:hash), location: location)
new(types: dups.to_a, location: location)
end
end
end

def ==(other)
other.is_a?(Intersection) &&
other.types == types
other.is_a?(Intersection) && other.types == types
end

def hash
self.class.hash ^ types.hash
@hash ||= self.class.hash ^ types.hash
end

alias eql? ==

def subst(s)
self.class.build(location: location,
types: types.map {|ty| ty.subst(s) })
self.class.build(location: location, types: types.map {|ty| ty.subst(s) })
end

def to_s
"(#{types.map(&:to_s).sort.join(" & ")})"
"(#{types.map(&:to_s).join(" & ")})"
end

def free_variables()
Expand Down
11 changes: 5 additions & 6 deletions lib/steep/ast/types/union.rb
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,28 @@ def self.build(types:, location: nil)
when 1
tys.first
else
new(types: tys.sort_by(&:hash), location: location)
new(types: tys, location: location)
end
end
end

def ==(other)
other.is_a?(Union) &&
other.types == types
Set.new(other.types) == Set.new(types)
end

def hash
self.class.hash ^ types.hash
@hash ||= self.class.hash ^ types.sort_by(&:to_s).hash
end

alias eql? ==

def subst(s)
self.class.build(location: location,
types: types.map {|ty| ty.subst(s) })
self.class.build(location: location, types: types.map {|ty| ty.subst(s) })
end

def to_s
"(#{types.map(&:to_s).sort.join(" | ")})"
"(#{types.map(&:to_s).join(" | ")})"
end

def free_variables
Expand Down
67 changes: 5 additions & 62 deletions lib/steep/interface/interface.rb
Original file line number Diff line number Diff line change
@@ -1,72 +1,15 @@
module Steep
module Interface
class Interface
class Combination
attr_reader :operator
attr_reader :types
class Entry
attr_reader :method_types

def initialize(operator:, types:)
@types = types
@operator = operator
@incompatible = false
end

def overload?
operator == :overload
end

def union?
operator == :union
end

def intersection?
operator == :intersection
end

def self.overload(types, incompatible:)
new(operator: :overload, types: types).incompatible!(incompatible)
end

def incompatible?
@incompatible
end

def incompatible!(value)
@incompatible = value
self
end

def self.union(types)
case types.size
when 0
raise "Combination.union called with zero types"
when 1
types.first
else
new(operator: :union, types: types)
end
end

def self.intersection(types)
case types.size
when 0
raise "Combination.intersection called with zero types"
when 1
types.first
else
new(operator: :intersection, types: types)
end
def initialize(method_types:)
@method_types = method_types
end

def to_s
case operator
when :overload
"{ #{types.map(&:to_s).join(" | ")} }"
when :union
"[#{types.map(&:to_s).join(" | ")}]"
when :intersection
"[#{types.map(&:to_s).join(" & ")}]"
end
"{ #{method_types.join(" || ")} }"
end
end

Expand Down
Loading