From 81687f306781189bd00dcd611eb3d2cf465fd128 Mon Sep 17 00:00:00 2001 From: Simeon Schaub Date: Fri, 12 Feb 2021 22:25:57 +0100 Subject: [PATCH] improve constraint propagation with multiple || (#39618) fixes #39611 (cherry picked from commit abd56cdfbdd3ecf3806f4b4117479cdeff26a3c2) --- src/julia-syntax.scm | 40 +++++++++++++++++++++++++------------- test/broadcast.jl | 19 +++++------------- test/compiler/inference.jl | 8 ++++++++ 3 files changed, 39 insertions(+), 28 deletions(-) diff --git a/src/julia-syntax.scm b/src/julia-syntax.scm index 9b4b6abc27f710..4debe7324598e5 100644 --- a/src/julia-syntax.scm +++ b/src/julia-syntax.scm @@ -1926,8 +1926,8 @@ (blk? (and (pair? test) (eq? (car test) 'block))) (stmts (if blk? (cdr (butlast test)) '())) (test (if blk? (last test) test))) - (if (and (pair? test) (eq? (car test) '&&)) - (let ((clauses `(&& ,@(map expand-forms (cdr (flatten-ex '&& test)))))) + (if (and (pair? test) (memq (car test) '(&& |\|\||))) + (let ((clauses `(,(car test) ,@(map expand-forms (cdr (flatten-ex (car test) test)))))) `(if ,(if blk? `(block ,@(map expand-forms stmts) ,clauses) clauses) @@ -4086,18 +4086,30 @@ f(x) = yt(x) (compile (cadr e) break-labels value tail) #f)) ((if elseif) - (let ((tests (let* ((cond (cadr e)) - (cond (if (and (pair? cond) (eq? (car cond) 'block)) - (begin (if (length> cond 2) (compile (butlast cond) break-labels #f #f)) - (last cond)) - cond))) - (map (lambda (clause) - (emit `(gotoifnot ,(compile-cond clause break-labels) _))) - (if (and (pair? cond) (eq? (car cond) '&&)) - (cdr cond) - (list cond))))) - (end-jump `(goto _)) - (val (if (and value (not tail)) (new-mutable-var) #f))) + (let* ((cnd (cadr e)) + (cnd (if (and (pair? cnd) (eq? (car cnd) 'block)) + (begin (if (length> cnd 2) (compile (butlast cnd) break-labels #f #f)) + (last cnd)) + cnd)) + (or? (and (pair? cnd) (eq? (car cnd) '|\|\||))) + (tests (if or? + (let ((short-circuit `(goto _))) + (for-each + (lambda (clause) + (let ((jmp (emit `(gotoifnot ,(compile-cond clause break-labels) _)))) + (emit short-circuit) + (set-car! (cddr jmp) (make&mark-label)))) + (butlast (cdr cnd))) + (let ((last-jmp (emit `(gotoifnot ,(compile-cond (last (cdr cnd)) break-labels) _)))) + (set-car! (cdr short-circuit) (make&mark-label)) + (list last-jmp))) + (map (lambda (clause) + (emit `(gotoifnot ,(compile-cond clause break-labels) _))) + (if (and (pair? cnd) (eq? (car cnd) '&&)) + (cdr cnd) + (list cnd))))) + (end-jump `(goto _)) + (val (if (and value (not tail)) (new-mutable-var) #f))) (let ((v1 (compile (caddr e) break-labels value tail))) (if val (emit-assignment val v1)) (if (and (not tail) (or (length> e 3) val)) diff --git a/test/broadcast.jl b/test/broadcast.jl index dff306ee27c113..0cfede78afadff 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -516,7 +516,7 @@ Base.BroadcastStyle(::Type{T}) where {T<:AD2Dim} = AD2DimStyle() @test a .+ 1 .* 2 == @inferred(fadd2(aa)) @test a .* a' == @inferred(fprod(aa)) @test isequal(a .+ [missing; 1:9], fadd3(aa)) - @test_broken Core.Compiler.return_type(fadd3, (typeof(aa),)) <: Array19745{<:Union{Float64, Missing}} + @test Core.Compiler.return_type(fadd3, (typeof(aa),)) <: Array19745{<:Union{Float64, Missing}} @test isa(aa .+ 1, Array19745) @test isa(aa .+ 1 .* 2, Array19745) @test isa(aa .* aa', Array19745) @@ -953,29 +953,20 @@ p0 = copy(p) @testset "Issue #28382: inferrability of broadcast with Union eltype" begin @test isequal([1, 2] .+ [3.0, missing], [4.0, missing]) - @test_broken Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int}, - Vector{Union{Float64, Missing}}}) == - Vector{<:Union{Float64, Missing}} @test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int}, Vector{Union{Float64, Missing}}}) == - AbstractVector{<:Union{Float64, Missing}} + Vector{<:Union{Float64, Missing}} @test isequal([1, 2] + [3.0, missing], [4.0, missing]) - @test_broken Core.Compiler.return_type(+, Tuple{Vector{Int}, - Vector{Union{Float64, Missing}}}) == + @test Core.Compiler.return_type(+, Tuple{Vector{Int}, + Vector{Union{Float64, Missing}}}) == Vector{<:Union{Float64, Missing}} @test Core.Compiler.return_type(+, Tuple{Vector{Int}, Vector{Union{Float64, Missing}}}) == - AbstractVector{<:Union{Float64, Missing}} - @test_broken Core.Compiler.return_type(+, Tuple{Vector{Int}, - Vector{Union{Float64, Missing}}}) == Vector{<:Union{Float64, Missing}} @test isequal(tuple.([1, 2], [3.0, missing]), [(1, 3.0), (2, missing)]) - @test_broken Core.Compiler.return_type(broadcast, Tuple{typeof(tuple), Vector{Int}, - Vector{Union{Float64, Missing}}}) == - Vector{<:Tuple{Int, Any}} @test Core.Compiler.return_type(broadcast, Tuple{typeof(tuple), Vector{Int}, Vector{Union{Float64, Missing}}}) == - AbstractVector{<:Tuple{Int, Any}} + Vector{<:Tuple{Int, Any}} # Check that corner cases do not throw an error @test isequal(broadcast(x -> x === 1 ? nothing : x, [1, 2, missing]), [nothing, 2, missing]) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index a2cc39e98e9b2e..1cbf36810f4d56 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -2994,3 +2994,11 @@ end # issue #40804 @test Base.return_types(()) do; ===(); end == Any[Union{}] @test Base.return_types(()) do; typeassert(); end == Any[Union{}] + +# issue #39611 +Base.return_types((Union{Int,Nothing},)) do x + if x === nothing || x < 0 + return 0 + end + x +end == [Int]