Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate constant calls to new! #28284

Merged
merged 4 commits into from
Sep 17, 2018
Merged

Conversation

willow-ahrens
Copy link
Contributor

Please infer that calls to new with constant arguments are constant.

A MWE. Without this feature:

julia> fizzle(x) = Val(Base.OneTo(3))
fizzle (generic function with 1 method)

julia> @code_typed fizzle(1)
CodeInfo(
1 1 ─      Base.ifelse(false, 0, 3)     │╻╷╷ Type
  │   %2 = new(Base.OneTo{Int64}, 3)::Base.OneTo{Int64}     ││┃   Type
  │   %3 = invoke Main.Val(%2::Base.OneTo{Int64})::Val{_1} where _1     │
  └──      return %3     │
) => Val{_1} where _1

With this feature:

julia> fizzle(x) = Val(Base.OneTo(3))
fizzle (generic function with 1 method)

julia> @code_typed fizzle(1)
CodeInfo(
1 1 ─     Base.ifelse(false, 0, 3)                                        │╻╷╷ Type
  └──     return :($(QuoteNode(Val{Base.OneTo(3)}())))                    │
) => Val{Base.OneTo(3)}

@willow-ahrens
Copy link
Contributor Author

willow-ahrens commented Jul 26, 2018

Additionally, I request that we only merge this if it doesn't break Cassette. It could help Cassette IDK @jrevels? Anyways, I can't think of any other way to pass a constant value from the runtime to a generated function than to lift it to the type domain like this.

@Keno
Copy link
Member

Keno commented Jul 26, 2018

This direction seems generally fine with me, but I'm not sure I like the try/catch too much. I wonder if we should have a version of jl_struct_v that signals its error rather than raising it.

@willow-ahrens
Copy link
Contributor Author

willow-ahrens commented Jul 26, 2018

Yeah, @vtjnash and I couldn't really figure out which errors to rethrow. Any thoughts?

flds = Any[ argtypes[i].val for i = 2:length(argtypes) ]
try
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, flds, length(flds)))
catch ex
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to check for the error conditions explicitly; it's faster than try/catch and would allow returning Bottom for errors. I think all you need to check is that each value is of the corresponding field type. Elsewhere we assume that the number of arguments is correct.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that's correct. Plus it technically lets you do cool thinks like return a Const inferred value even if one of the values is only inferred as a union (as long as one of the union elements matches precisely - and they're both singletons at the moment).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was worried there might be another failure possibility (OOM? others?), but yeah, I guess it's just the type error, and we can check for that in the for-loop.

IIRC, the number of arguments is checked syntactically, so it would be invalid IR to discover otherwise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, I'm still a little new at this and I couldn't figure it out after looking around: anyone know which function would give the field types of the struct?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fieldtype(type, index) gives you the type of a field.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, thanks! All fixed up, I think.

@JeffBezanson JeffBezanson added the compiler:inference Type inference label Jul 26, 2018
@jrevels
Copy link
Member

jrevels commented Jul 26, 2018

Additionally, I request that we only merge this if it doesn't break Cassette. It could help Cassette IDK @jrevels?

IIUC, it shouldn't break Cassette since Cassette intercepts new calls for isbitstype types in the lowered IR pre-inference 🙂

end
isconst &= ae isa Const
isconst = isconst && (ae.val isa fieldtype(t, i - 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if !(ae.val isa fieldtype(t, i - 1)); t = Bottom; end

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, also fieldtype(Bottom, i) will error.

Copy link
Member

@Keno Keno left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add your test case, e.g. to test/core.jl

end
isconst &= ae isa Const
isconst = isconst && (ae.val isa fieldtype(t, i - 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, also fieldtype(Bottom, i) will error.

@JeffBezanson
Copy link
Member

Inference test cases should go in test/compiler/compiler.jl.

@JeffBezanson
Copy link
Member

This might be a good implementation:

        t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1]
        if isbitstype(t)
            args = Vector{Any}(undef, length(e.args)-1)
            isconst = true
            for i = 2:length(e.args)
                at = abstract_eval(e.args[i], vtypes, sv)
                if at === Bottom
                    t = Bottom
                    isconst = false
                    break
                elseif at isa Const
                    if !(at.val isa fieldtype(t, i - 1))
                        t = Bottom
                        isconst = false
                        break
                    end
                    args[i-1] = at.val
                else
                    isconst = false
                end
            end
            if isconst
                t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
            end
        end

The idea is:

  • When isbitstype(t), don't bother doing the rest of the work.
  • I'm guessing it's pretty rare for an argument to :new to be Bottom
  • It's also somewhat rare for an argument to be a constant, so try to do as little work as possible before discovering a non-Const argument.

In fact it's possible we should break out as soon as we hit non-Const argument.

@vtjnash
Copy link
Member

vtjnash commented Jul 26, 2018

so try to do as little work as possible before discovering a non-Const argument.

all of this work is extremely cheap, but seeing Union{} here is not supposed to be possible (it is, but it's not supposed to be)

JeffBezanson
JeffBezanson previously approved these changes Jul 27, 2018
@JeffBezanson
Copy link
Member

So it looks like on windows ccall is having trouble passing a constant struct from a QuoteNode. We might be missing code to generate an address for a constant in that case.

@JeffBezanson JeffBezanson dismissed their stale review July 27, 2018 19:12

seems to trigger a codegen issue

@willow-ahrens
Copy link
Contributor Author

willow-ahrens commented Jul 28, 2018

One of the appveyor tests fails. This makes me sad. The failing example is line 566 of test/ccall.c. I have created an MWE of the failing test case which should fail on appveyor on julia master:

using InteractiveUtils

struct S
  wild::Float32
  wacky::Float32
  waving::Float32
  inflatable::Float64
  armflailing::Float64
  tubeman::Float64
end

f_julia(x::S) = x
f_c = @cfunction(f_julia, S, (S,))

eval(quote
  function g()
      return ccall(f_c, S, (S,), $(QuoteNode(S(42.0f0, 42.0f0, 42.0f0, 42.0, 42.0, 42.0))))
  end
end)
println(g())
println(@code_typed(g()))
@code_llvm(g())

Executing the code on the master branch of julia produces:

S(42.0f0, 42.0f0, 42.0f0, 42.0, 42.0, 42.0)
CodeInfo(
17 1 ─ %1 = $(Expr(:foreigncall, :(Main.f_c), S, svec(S), :(:ccall), 1, :($(QuoteNode(S(42.0, 42.0, 42.0, 42.0, 42.0, 42.0)))), :($(QuoteNode(S(42.0, 42.0, 42.0, 42.0, 42.0, 42.0))))))::S
   └──      return %1
) => S

; Function g
; Location: /Users/Peter/Projects/julia/foo.jl:17
define void @julia_g_33425({ float, float, float, double, double, double }* noalias nocapture sret) {
top:
  %gcframe = alloca %jl_value_t addrspace(10)*, i32 3
  %1 = bitcast %jl_value_t addrspace(10)** %gcframe to i8*
  call void @llvm.memset.p0i8.i32(i8* %1, i8 0, i32 24, i32 0, i1 false)
  %2 = alloca { float, float, float, double, double, double }, align 8
  %3 = call %jl_value_t*** inttoptr (i64 4558284144 to %jl_value_t*** ()*)() #5
  %4 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 0
  %5 = bitcast %jl_value_t addrspace(10)** %4 to i64*
  store i64 2, i64* %5
  %6 = getelementptr %jl_value_t**, %jl_value_t*** %3, i32 0
  %7 = load %jl_value_t**, %jl_value_t*** %6
  %8 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 1
  %9 = bitcast %jl_value_t addrspace(10)** %8 to %jl_value_t***
  store %jl_value_t** %7, %jl_value_t*** %9
  %10 = bitcast %jl_value_t*** %6 to %jl_value_t addrspace(10)***
  store %jl_value_t addrspace(10)** %gcframe, %jl_value_t addrspace(10)*** %10
  %11 = load %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** inttoptr (i64 4767065832 to %jl_value_t addrspace(10)**), align 8
  %12 = bitcast %jl_value_t addrspace(10)* %11 to i64 addrspace(10)*
  %13 = getelementptr i64, i64 addrspace(10)* %12, i64 -1
  %14 = load i64, i64 addrspace(10)* %13
  %15 = and i64 %14, -16
  %16 = inttoptr i64 %15 to %jl_value_t*
  %17 = addrspacecast %jl_value_t* %16 to %jl_value_t addrspace(10)*
  %18 = bitcast %jl_value_t addrspace(10)* %17 to i64 addrspace(10)*
  %19 = getelementptr i64, i64 addrspace(10)* %18, i64 -1
  %20 = load i64, i64 addrspace(10)* %19
  %21 = and i64 %20, -16
  %22 = inttoptr i64 %21 to %jl_value_t*
  %23 = addrspacecast %jl_value_t* %22 to %jl_value_t addrspace(10)*
  %24 = icmp eq %jl_value_t addrspace(10)* %23, addrspacecast (%jl_value_t* inttoptr (i64 4762730512 to %jl_value_t*) to %jl_value_t addrspace(10)*)
  br i1 %24, label %pass, label %fail

fail:                                             ; preds = %top
  %25 = addrspacecast %jl_value_t addrspace(10)* %17 to %jl_value_t addrspace(12)*
  call void @jl_type_error_rt(i8* inttoptr (i64 140476800909664 to i8*), i8* inttoptr (i64 140476801052240 to i8*), %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 4762730512 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(12)* %25)
  unreachable

pass:                                             ; preds = %top
  %26 = addrspacecast %jl_value_t addrspace(10)* %17 to %jl_value_t addrspace(11)*
  %27 = bitcast %jl_value_t addrspace(11)* %26 to %jl_value_t addrspace(10)* addrspace(11)*
  %28 = load %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)* addrspace(11)* %27, align 8
  %29 = addrspacecast %jl_value_t addrspace(10)* %28 to %jl_value_t addrspace(12)*
  %30 = icmp eq %jl_value_t addrspace(12)* %29, addrspacecast (%jl_value_t* inttoptr (i64 4762754336 to %jl_value_t*) to %jl_value_t addrspace(12)*)
  br i1 %30, label %pass2, label %fail1

fail1:                                            ; preds = %pass
  %31 = addrspacecast %jl_value_t addrspace(10)* %11 to %jl_value_t addrspace(12)*
  %32 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
  store %jl_value_t addrspace(10)* %11, %jl_value_t addrspace(10)** %32
  call void @jl_type_error_rt(i8* inttoptr (i64 140476800909664 to i8*), i8* inttoptr (i64 140476801052240 to i8*), %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 4762814512 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(12)* %31)
  unreachable

pass2:                                            ; preds = %pass
  %33 = bitcast %jl_value_t addrspace(10)* %11 to i64 addrspace(10)*
  %34 = load i64, i64 addrspace(10)* %33, align 8
  %35 = icmp eq i64 %34, 0
  br i1 %35, label %fail3, label %pass4

fail3:                                            ; preds = %pass2
  call void @jl_throw(%jl_value_t addrspace(12)* addrspacecast (%jl_value_t* inttoptr (i64 4619787712 to %jl_value_t*) to %jl_value_t addrspace(12)*))
  unreachable

pass4:                                            ; preds = %pass2
  %36 = inttoptr i64 %34 to void ({ float, float, float, double, double, double }*, { float, float, float, double, double, double } addrspace(11)*)*
  call void %36({ float, float, float, double, double, double }* noalias nonnull sret %2, { float, float, float, double, double, double } addrspace(11)* byval addrspacecast ({ float, float, float, double, double, double }* @0 to { float, float, float, double, double, double } addrspace(11)*))
  %37 = bitcast { float, float, float, double, double, double }* %0 to i8*
  %38 = bitcast { float, float, float, double, double, double }* %2 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* %37, i8* %38, i64 40, i32 8, i1 false)
  %39 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 1
  %40 = load %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %39
  %41 = getelementptr %jl_value_t**, %jl_value_t*** %3, i32 0
  %42 = bitcast %jl_value_t*** %41 to %jl_value_t addrspace(10)**
  store %jl_value_t addrspace(10)* %40, %jl_value_t addrspace(10)** %42
  ret void

Which is a real sad time because if you look at the line 10th to the bottom or so we have:

  call void %36({ float, float, float, double, double, double }* noalias nonnull sret %2, { float, float, float, double, double, double } addrspace(11)* byval addrspacecast ({ float, float, float, double, double, double }* @0 to { float, float, float, double, double, double } addrspace(11)*))

And there's something wrong with that null @0 right? Anyways, that's what seems to break appveyor.

@vtjnash
Copy link
Member

vtjnash commented Jul 29, 2018

@0 is a reference to anonymous global variable #0. That’s mostly ok, but I think ccall expects that we are supposed to copy this to the stack (the mutation rules differ across the ccall boundary, which makes managing this interface a bit more special than normal codegen)

@Keno
Copy link
Member

Keno commented Sep 8, 2018

As in #28335, I can no longer reproduce the failure on windows. Rebased, so I think this is good to go if the AppVeyor failure does not come back.

@vchuravy
Copy link
Member

Ping! Would be good to squash this a bit.

@Keno
Copy link
Member

Keno commented Sep 17, 2018

I was just about to rebase this. We can squash this to one commit while merging.

@Keno
Copy link
Member

Keno commented Sep 17, 2018

Looks like there was an httpbin outage while this was running, but the tests passed ok. Fine to squash merge from my perspective, but I'd like @JeffBezanson or @vtjnash to take a final look at the changes since they've sat around for a while.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
compiler:inference Type inference
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants