Rewrite rank 0 elemwise ops and push scalar constants into elemwise #107

Draft · wants to merge 3 commits into main from the elemwise-rewrites branch
Conversation

@aseyboldt (Member) commented Dec 11, 2022:

A bit more groundwork for #92, removing some cases where Elemwise ops are not actually needed.

This adds two rewrites:

  • local_elemwise_lift_scalars is meant to remove Elemwise ops that do not actually vectorize anything because all of their inputs are rank 0 tensors, and replaces them with a TensorFromScalar(Composite) (see the sketch after this list). I put this into the "specialize" phase, because I think it interacts badly with stabilization rewrites (which run before the specialize phase). I'm not really sure I like the way this works yet. It seems that currently most rewrites assume that during canonicalization we always deal with Elemwise ops instead of scalar ops, and many rewrites then only apply to those. So for instance if we avoid all Elemwise ops and only use scalars, even basic rewrites are not applied at all:

    import pytensor
    import pytensor.scalar

    x = pytensor.scalar.ScalarVariable(pytensor.scalar.ScalarType("float64"), None)
    out = pytensor.scalar.log(pytensor.scalar.as_scalar(1) + x)
    pytensor.dprint(out)

    is not rewritten to use log1p:

    func = pytensor.function([x], out, mode="NUMBA")
    pytensor.dprint(func)
    # log [id A] 1
    # |add [id B] 0
    #   |ScalarConstant{1} [id C]
    #   |<float64> [id D]

    because those rewrites only target tensor log and add ops, not the scalar versions.
    So if we put this new rewrite in canonicalize (where it maybe should belong?), it would break a lot of other rewrites that assume that everything is wrapped in an Elemwise.
    I wonder if we can change the pattern rewriter so that it figures this out automatically and also applies rewrites to matching scalar ops. I think this would give us quite a bit more flexibility, because then we could turn more things into scalar ops, which should usually compile and execute faster.

  • push_elemwise_constants only applies to numba for now (see below) and runs after Elemwise fusion. We often have graphs that use scalar constants in an Elemwise, but those constants are provided as rank 0 inputs to the Elemwise op and then broadcast, when they could just be ScalarConstants in the inner Composite op (a small illustration of this follows after the list). So we rewrite

      Elemwise{Composite{(i0 + (i1 * i2))}} [id A] 0
       |TensorConstant{(1,) of 1.0} [id B]
       |TensorConstant{(1,) of 2.0} [id C]
       |x [id D]

    to

    Elemwise{Composite{(1.0 + (2.0 * i0))}} [id A] 0
     |x [id B]

    I ran into segfaults when I tried this with the C backend, so it only applies to numba for now.

  • I also made some small changes to the rewrite setup: previously the elemwise fusion db was created in a different file than all the other basic rewrite dbs (mode.py), so I moved it next to the others to make this more consistent. I also created a new rewrite phase, post_fusion, that is executed right after the elemwise_fusion rewrites and currently only contains push_elemwise_constants.
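
To make these two more concrete, here is roughly what the first rewrite boils down to (a simplified sketch only; the real implementation in the diff handles more cases, and the exact imports here are approximate):

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.basic import as_tensor_variable, scalar_from_tensor
from pytensor.tensor.elemwise import Elemwise


@node_rewriter([Elemwise])
def local_elemwise_lift_scalars(fgraph, node):
    op = node.op
    # Only fire when the Elemwise does not vectorize anything at all.
    if not all(inp.type.ndim == 0 for inp in node.inputs):
        return False
    # Feed the scalar_op directly and wrap its outputs back into rank 0 tensors,
    # which yields the TensorFromScalar(Composite) form described above.
    scalars = [scalar_from_tensor(inp) for inp in node.inputs]
    return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]

And a minimal hand-built illustration of the form push_elemwise_constants aims for (just an illustration, not the rewrite itself): a Composite can hold the scalar constants in its inner graph, so the fused Elemwise only needs the non-constant inputs:

import pytensor.scalar as aes
from pytensor.tensor.elemwise import Elemwise

x = aes.float64("x")
# Bake the constants 1.0 and 2.0 into the inner scalar graph ...
inner_out = aes.add(aes.constant(1.0), aes.mul(aes.constant(2.0), x))
composite = aes.Composite([x], [inner_out])
# ... so that the Elemwise needs only one input: Elemwise{Composite{(1.0 + (2.0 * i0))}}
fused = Elemwise(composite)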

Update

  • I also removed local_subtensor_merge from the canonicalize pass. It is supposed to simplify chained Subtensor ops, but I'm not sure this rewrite should really be so aggressive. There are cases where I'd say it makes things more complicated instead of simpler, which was especially apparent in combination with local_elemwise_lift_scalars. For instance it would end up rewriting this:
    import pytensor
    import pytensor.tensor as pt

    x = pt.dvector("x")
    a = x.shape[0]
    b = x[0:][:a]
    pytensor.dprint([a, b])
    func = pytensor.function([x], b, mode="NUMBA")
    into
    DeepCopyOp [id A] 5
     |Subtensor{int64:int64:int8} [id B] 4
       |x [id C]
       |Switch [id D] 3
       | |LE [id E] 2
       | | |ScalarFromTensor [id F] 1
       | | | |Shape_i{0} [id G] 0
       | | |   |x [id C]
       | | |ScalarConstant{0} [id H]
       | |ScalarConstant{0} [id H]
       | |ScalarConstant{0} [id I]
       |ScalarFromTensor [id F] 1
       |ScalarConstant{1} [id J]
    while without it pytensor would notice that b is just x. Maybe we should change this rewrite so that it only fires when it can statically tell that start and end are non-negative (a rough sketch of such a check follows below)? (I guess this also suggests we could use a rewrite that teaches graphs that shapes are non-negative, or shapes could just be returned as unsigned ints?)
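
    A rough sketch of what such a static check could look like (a hypothetical helper, not part of this PR; it only recognizes the two obvious cases):

    from pytensor.tensor.basic import get_scalar_constant_value
    from pytensor.tensor.exceptions import NotScalarConstantError
    from pytensor.tensor.shape import Shape_i


    def is_known_non_negative(var):
        # Shape components are non-negative by construction.
        if var.owner is not None and isinstance(var.owner.op, Shape_i):
            return True
        # Otherwise only accept non-negative constants.
        try:
            return bool(get_scalar_constant_value(var) >= 0)
        except NotScalarConstantError:
            return False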

Update
A bit more motivation for this:
I think the numba backend at least can really profit from turning more things into scalar ops, both for compile time and run time. This clashes a bit with what theano thinks of as the "canonical form", where pretty much everything is a tensor.
Maybe a nice compromise might be to leave the canonicalize and specialize phases exactly as they are, so that the tensor form stays the canonical way of representing everything, and all the rewrites in those phases keep working with that.
But at some later stage (not sure exactly when...) we could add a phase (possibly numba specific) that tries to turn as much as possible into scalars. So elemwise ops, shape_i, sum and a bunch of others could be rewritten to return scalars.
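
As a sketch of what registering such a phase could look like in terms of the rewrite databases (the phase name, position, tags and helper below are assumptions, not settled API):

from pytensor.compile import optdb
from pytensor.graph.rewriting.db import EquilibriumDB

# A dedicated database that runs late, after the elemwise fusion pass
# (the numeric position is a placeholder).
scalarize_db = EquilibriumDB()
optdb.register("scalarize", scalarize_db, "fast_run", "numba", position=49.5)


def register_scalarize(rewrite, *tags, **kwargs):
    # Small helper so individual rewrites can opt in with a decorator.
    scalarize_db.register(rewrite.__name__, rewrite, "fast_run", *tags, **kwargs)
    return rewrite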

The reason I think scalars are better (where possible) in numba is that tensors produce a lot of code, which slows down compilation, adds lots of allocations, and makes the code much less transparent to llvm, which I think leads to lots of missed optimizations there.

I benchmarked a small radon model a bit, and in that model we spend about 5% of the time in allocation code. That might not sound like much, but I think it is way too much for comfort, given that most of the cost of those allocations (missed optimizations, cache misses, ref counting etc.) will be hidden.

After the rewrites here I see a lot of code like this in logp graphs:

Sum{acc_dtype=float64} [id A] '__logp' 120
 |MakeVector{dtype='float64'} [id B] 114
   |TensorFromScalar [id C] 'intercept_logprob' 73
   | |sub [id D] 66
   |   |mul [id E] 59
   |   | |ScalarConstant{-0.5} [id F]
   |   | |sqr [id G] 47
   |   |   |mul [id H] 32
   |   |     |ScalarConstant{0.1} [id I]
   |   |     |ScalarFromTensor [id J] 23
   |   |       |Reshape{0} [id K] 14
   |   |         |Subtensor{int64:int64:} [id L] 6
   |   |         | |__joined_variables [id M]
   |   |         | |ScalarConstant{0} [id N]
   |   |         | |ScalarConstant{1} [id O]
   |   |         |TensorConstant{[]} [id P]
   |   |ScalarConstant{3.2215236261987186} [id Q]
   |Sum{acc_dtype=float64} [id R] 30
   | |Elemwise{Composite{((-0.5 * sqr(i0)) - 0.9189385332046727)}} [id S] 'county_raw_logprob' 21
   |   |SpecifyShape [id T] 13
   |     |Subtensor{int64:int64:} [id U] 5
   |     | |__joined_variables [id M]
   |     | |ScalarConstant{1} [id O]
   |     | |ScalarConstant{86} [id V]
   |     |TensorConstant{85} [id W]
   |TensorFromScalar [id X] 'county_sd_log___logprob' 82
   | |add [id Y] 77
   |   |Switch [id Z] 72
   |   | |GE [id BA] 45
...

There is quite a bit of potential to rewrite this further: Sum of a MakeVector, for instance, should just be the sum of its elements. And those sums can easily be represented as a scalar add node, so that we never have to allocate those tensors.

So why do I think that rank 0 tensors are something to avoid in numba? Take this addition for instance:

import numba
from numba import types
import numpy as np

@numba.njit("float64(float64, float64)", no_cpython_wrapper=True)
def add_scalar(x, y):
    return x + y

ty_rank0 = types.Array(numba.float64, 0, "C")

@numba.njit(ty_rank0(ty_rank0, ty_rank0), no_cpython_wrapper=True)
def add_tensor(x, y):
    return np.asarray(x + y)
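
For reference, the optimized IR shown below can be dumped with numba's inspect_llvm (one way to reproduce it; formatting may differ slightly between numba versions):

for sig, llvm_ir in add_scalar.inspect_llvm().items():
    print(f"=== add_scalar {sig} ===\n{llvm_ir}")
for sig, llvm_ir in add_tensor.inspect_llvm().items():
    print(f"=== add_tensor {sig} ===\n{llvm_ir}")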

The first is compiled to this:

; ModuleID = 'add_scalar'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@_ZN08NumbaEnv8__main__10add_scalarB3v14B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEdd = common local_unnamed_addr global i8* null

; Function Attrs: nofree norecurse nounwind writeonly
define i32 @_ZN8__main__10add_scalarB3v14B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEdd(double* noalias nocapture %retptr, { i8*, i32, i8* }** noalias nocapture readnone %excinfo, double %arg.x, double %arg.y) local_unnamed_addr #0 {
entry:
  %.6 = fadd double %arg.x, %arg.y
  store double %.6, double* %retptr, align 8
  ret i32 0
}

; Function Attrs: nofree norecurse nounwind writeonly
define double @cfunc._ZN8__main__10add_scalarB3v14B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEdd(double %.1, double %.2) local_unnamed_addr #0 {
entry:
  %.4 = alloca double, align 8
  store double 0.000000e+00, double* %.4, align 8
  %.8 = call i32 @_ZN8__main__10add_scalarB3v14B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEdd(double* nonnull %.4, { i8*, i32, i8* }** undef, double %.1, double %.2) #2
  %.18 = load double, double* %.4, align 8
  ret double %.18
}

; Function Attrs: nounwind
declare void @llvm.stackprotector(i8*, i8**) #1

attributes #0 = { nofree norecurse nounwind writeonly }
attributes #1 = { nounwind }
attributes #2 = { noinline }

The addition is turned into a single fadd instruction that is really easy to reason about for llvm:

  %.6 = fadd double %arg.x, %arg.y

In contrast, the second looks like this (after optimization):

; ModuleID = 'add_tensor'
source_filename = "<string>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@_ZN08NumbaEnv8__main__10add_tensorB3v13B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dE5ArrayIdLi0E1C7mutable7alignedE5ArrayIdLi0E1C7mutable7alignedE = common local_unnamed_addr global i8* null
@".const.<numba.core.cpu.CPUContext object at 0x7f44de3b8640>" = internal constant [53 x i8] c"<numba.core.cpu.CPUContext object at 0x7f44de3b8640>\00"
@PyExc_SystemError = external global i8
@".const.unknown error when calling native function" = internal constant [43 x i8] c"unknown error when calling native function\00"
@_ZN08NumbaEnv5numba2np9arraymath10np_asarray12_3clocals_3e4implB2v7B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEd27omitted_28default_3dNone_29 = common local_unnamed_addr global i8* null
@_ZN08NumbaEnv5numba2np8arrayobj13impl_np_array12_3clocals_3e4implB2v3B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEd18dtype_28float64_29 = common local_unnamed_addr global i8* null
@_ZN08NumbaEnv5numba2np8arrayobj15_call_allocatorB2v4B44c8tJTC_2fWQA9wW1DkAz0Pj1skAdT4gkkUlYBZmgA_3dEN29typeref_5b_3cclass_20_27numba4core5types8npytypes14Array_27_3e_5dExj = common local_unnamed_addr global i8* null
@_ZN08NumbaEnv5numba2np8arrayobj18_ol_array_allocate12_3clocals_3e4implB2v5B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEN29typeref_5b_3cclass_20_27numba4core5types8npytypes14Array_27_3e_5dExj = common local_unnamed_addr global i8* null
@.const.picklebuf.139933779134528 = internal constant { i8*, i32, i8* } { i8* getelementptr inbounds ([86 x i8], [86 x i8]* @.const.pickledata.139933779134528, i32 0, i32 0), i32 86, i8* getelementptr inbounds ([20 x i8], [20 x i8]* @.const.pickledata.139933779134528.sha1, i32 0, i32 0) }
@.const.pickledata.139933779134528 = internal constant [86 x i8] c"\80\04\95K\00\00\00\00\00\00\00\8C\08builtins\94\8C\0BMemoryError\94\93\94\8C'Allocation failed (probably too large).\94\85\94N\87\94."
@.const.pickledata.139933779134528.sha1 = internal constant [20 x i8] c"\BA(\9D\81\F0\\p \F3G|\15sH\04\DFe\AB\E2\09"

define i32 @_ZN8__main__10add_tensorB3v13B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dE5ArrayIdLi0E1C7mutable7alignedE5ArrayIdLi0E1C7mutable7alignedE({ i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* noalias nocapture %retptr, { i8*, i32, i8* }** noalias nocapture %excinfo, i8* nocapture readnone %arg.x.0, i8* nocapture readnone %arg.x.1, i64 %arg.x.2, i64 %arg.x.3, double* nocapture readonly %arg.x.4, i8* nocapture readnone %arg.y.0, i8* nocapture readnone %arg.y.1, i64 %arg.y.2, i64 %arg.y.3, double* nocapture readonly %arg.y.4) local_unnamed_addr {
entry:
  %.43.le = load double, double* %arg.x.4, align 8
  %.45.le = load double, double* %arg.y.4, align 8
  %.7.i.i.i.i = tail call i8* @NRT_MemInfo_alloc_aligned(i64 8, i32 32), !noalias !0
  %.8.i.i.i.i = icmp eq i8* %.7.i.i.i.i, null
  br i1 %.8.i.i.i.i, label %afterloop.if, label %afterloop.endif, !prof !13

afterloop.if:                                     ; preds = %entry
  store { i8*, i32, i8* }* @.const.picklebuf.139933779134528, { i8*, i32, i8* }** %excinfo, align 8
  ret i32 1, !ret_is_raise !14

afterloop.endif:                                  ; preds = %entry
  %.46.le = fadd double %.43.le, %.45.le
  %.5.i.i.i = getelementptr i8, i8* %.7.i.i.i.i, i64 24
  %0 = bitcast i8* %.5.i.i.i to double**
  %.6.i2.i.i = load double*, double** %0, align 8, !noalias !15
  store double %.46.le, double* %.6.i2.i.i, align 8, !noalias !15
  %1 = ptrtoint i8* %.7.i.i.i.i to i64
  %2 = ptrtoint double* %.6.i2.i.i to i64
  %3 = bitcast { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* %retptr to i64*
  store i64 %1, i64* %3, align 8
  %retptr.repack2 = getelementptr inbounds { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }, { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* %retptr, i64 0, i32 1
  %4 = bitcast i8** %retptr.repack2 to <2 x i64>*
  store <2 x i64> <i64 0, i64 1>, <2 x i64>* %4, align 8
  %retptr.repack6 = getelementptr inbounds { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }, { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* %retptr, i64 0, i32 3
  store i64 8, i64* %retptr.repack6, align 8
  %retptr.repack8 = getelementptr inbounds { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }, { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* %retptr, i64 0, i32 4
  %5 = bitcast double** %retptr.repack8 to i64*
  store i64 %2, i64* %5, align 8
  ret i32 0
}

define { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } @cfunc._ZN8__main__10add_tensorB3v13B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dE5ArrayIdLi0E1C7mutable7alignedE5ArrayIdLi0E1C7mutable7alignedE({ i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.1, { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.2) local_unnamed_addr {
entry:
  %.4 = alloca { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }, align 8
  %.fca.0.gep1 = bitcast { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* %.4 to i8**
  %.fca.1.gep = getelementptr inbounds { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }, { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* %.4, i64 0, i32 1
  %.fca.2.gep = getelementptr inbounds { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }, { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* %.4, i64 0, i32 2
  %.fca.3.gep = getelementptr inbounds { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }, { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* %.4, i64 0, i32 3
  %.fca.4.gep = getelementptr inbounds { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }, { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* %.4, i64 0, i32 4
  %excinfo = alloca { i8*, i32, i8* }*, align 8
  %0 = bitcast { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* %.4 to i8*
  call void @llvm.memset.p0i8.i64(i8* nonnull align 8 dereferenceable(40) %0, i8 0, i64 40, i1 false)
  store { i8*, i32, i8* }* null, { i8*, i32, i8* }** %excinfo, align 8
  %extracted.meminfo = extractvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.1, 0
  %extracted.parent = extractvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.1, 1
  %extracted.nitems = extractvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.1, 2
  %extracted.itemsize = extractvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.1, 3
  %extracted.data = extractvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.1, 4
  %extracted.meminfo.1 = extractvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.2, 0
  %extracted.parent.1 = extractvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.2, 1
  %extracted.nitems.1 = extractvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.2, 2
  %extracted.itemsize.1 = extractvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.2, 3
  %extracted.data.1 = extractvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.2, 4
  %.8 = call i32 @_ZN8__main__10add_tensorB3v13B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dE5ArrayIdLi0E1C7mutable7alignedE5ArrayIdLi0E1C7mutable7alignedE({ i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] }* nonnull %.4, { i8*, i32, i8* }** nonnull %excinfo, i8* %extracted.meminfo, i8* %extracted.parent, i64 %extracted.nitems, i64 %extracted.itemsize, double* %extracted.data, i8* %extracted.meminfo.1, i8* %extracted.parent.1, i64 %extracted.nitems.1, i64 %extracted.itemsize.1, double* %extracted.data.1) #2
  %.9 = load { i8*, i32, i8* }*, { i8*, i32, i8* }** %excinfo, align 8
  %.10.not = icmp eq i32 %.8, 0
  %.18.fca.0.load = load i8*, i8** %.fca.0.gep1, align 8
  %.18.fca.1.load = load i8*, i8** %.fca.1.gep, align 8
  %.18.fca.2.load = load i64, i64* %.fca.2.gep, align 8
  %.18.fca.3.load = load i64, i64* %.fca.3.gep, align 8
  %.18.fca.4.load = load double*, double** %.fca.4.gep, align 8
  %.27 = alloca i32, align 4
  store i32 0, i32* %.27, align 4
  br i1 %.10.not, label %entry.endif, label %entry.if, !prof !16

entry.if:                                         ; preds = %entry
  %.16 = icmp sgt i32 %.8, 0
  call void @numba_gil_ensure(i32* nonnull %.27)
  br i1 %.16, label %entry.if.if, label %entry.if.endif.endif.endif

entry.endif:                                      ; preds = %entry, %.30
  %.18.fca.0.insert = insertvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } undef, i8* %.18.fca.0.load, 0
  %.18.fca.1.insert = insertvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.18.fca.0.insert, i8* %.18.fca.1.load, 1
  %.18.fca.2.insert = insertvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.18.fca.1.insert, i64 %.18.fca.2.load, 2
  %.18.fca.3.insert = insertvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.18.fca.2.insert, i64 %.18.fca.3.load, 3
  %.18.fca.4.insert = insertvalue { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.18.fca.3.insert, double* %.18.fca.4.load, 4
  ret { i8*, i8*, i64, i64, double*, [0 x i64], [0 x i64] } %.18.fca.4.insert

.30:                                              ; preds = %entry.if.if, %entry.if.if.if, %entry.if.endif.endif.endif
  %.52 = call i8* @PyUnicode_FromString(i8* getelementptr inbounds ([53 x i8], [53 x i8]* @".const.<numba.core.cpu.CPUContext object at 0x7f44de3b8640>", i64 0, i64 0))
  call void @PyErr_WriteUnraisable(i8* %.52)
  call void @Py_DecRef(i8* %.52)
  call void @numba_gil_release(i32* nonnull %.27)
  br label %entry.endif

entry.if.if:                                      ; preds = %entry.if
  call void @PyErr_Clear()
  %.33 = load { i8*, i32, i8* }, { i8*, i32, i8* }* %.9, align 8
  %.34 = extractvalue { i8*, i32, i8* } %.33, 0
  %.36 = extractvalue { i8*, i32, i8* } %.33, 1
  %.38 = extractvalue { i8*, i32, i8* } %.33, 2
  %.39 = call i8* @numba_unpickle(i8* %.34, i32 %.36, i8* %.38)
  %.40.not = icmp eq i8* %.39, null
  br i1 %.40.not, label %.30, label %entry.if.if.if, !prof !13

entry.if.if.if:                                   ; preds = %entry.if.if
  call void @numba_do_raise(i8* nonnull %.39)
  br label %.30

entry.if.endif.endif.endif:                       ; preds = %entry.if
  call void @PyErr_SetString(i8* nonnull @PyExc_SystemError, i8* getelementptr inbounds ([43 x i8], [43 x i8]* @".const.unknown error when calling native function", i64 0, i64 0))
  br label %.30
}

declare void @numba_gil_ensure(i32*) local_unnamed_addr

declare i8* @PyUnicode_FromString(i8*) local_unnamed_addr

declare void @PyErr_WriteUnraisable(i8*) local_unnamed_addr

declare void @Py_DecRef(i8*) local_unnamed_addr

declare void @numba_gil_release(i32*) local_unnamed_addr

declare void @PyErr_Clear() local_unnamed_addr

declare i8* @numba_unpickle(i8*, i32, i8*) local_unnamed_addr

declare void @numba_do_raise(i8*) local_unnamed_addr

declare void @PyErr_SetString(i8*, i8*) local_unnamed_addr

declare noalias i8* @NRT_MemInfo_alloc_aligned(i64, i32) local_unnamed_addr

; Function Attrs: argmemonly nounwind willreturn writeonly
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg) #0

; Function Attrs: nounwind
declare void @llvm.stackprotector(i8*, i8**) #1

attributes #0 = { argmemonly nounwind willreturn writeonly }
attributes #1 = { nounwind }
attributes #2 = { noinline }

!0 = !{!1, !3, !4, !6, !7, !9, !10, !12}
!1 = distinct !{!1, !2, !"_ZN5numba2np8arrayobj18_ol_array_allocate12_3clocals_3e4implB2v5B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEN29typeref_5b_3cclass_20_27numba4core5types8npytypes14Array_27_3e_5dExj: %retptr"}
!2 = distinct !{!2, !"_ZN5numba2np8arrayobj18_ol_array_allocate12_3clocals_3e4implB2v5B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEN29typeref_5b_3cclass_20_27numba4core5types8npytypes14Array_27_3e_5dExj"}
!3 = distinct !{!3, !2, !"_ZN5numba2np8arrayobj18_ol_array_allocate12_3clocals_3e4implB2v5B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEN29typeref_5b_3cclass_20_27numba4core5types8npytypes14Array_27_3e_5dExj: %excinfo"}
!4 = distinct !{!4, !5, !"_ZN5numba2np8arrayobj15_call_allocatorB2v4B44c8tJTC_2fWQA9wW1DkAz0Pj1skAdT4gkkUlYBZmgA_3dEN29typeref_5b_3cclass_20_27numba4core5types8npytypes14Array_27_3e_5dExj: %retptr"}
!5 = distinct !{!5, !"_ZN5numba2np8arrayobj15_call_allocatorB2v4B44c8tJTC_2fWQA9wW1DkAz0Pj1skAdT4gkkUlYBZmgA_3dEN29typeref_5b_3cclass_20_27numba4core5types8npytypes14Array_27_3e_5dExj"}
!6 = distinct !{!6, !5, !"_ZN5numba2np8arrayobj15_call_allocatorB2v4B44c8tJTC_2fWQA9wW1DkAz0Pj1skAdT4gkkUlYBZmgA_3dEN29typeref_5b_3cclass_20_27numba4core5types8npytypes14Array_27_3e_5dExj: %excinfo"}
!7 = distinct !{!7, !8, !"_ZN5numba2np8arrayobj13impl_np_array12_3clocals_3e4implB2v3B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEd18dtype_28float64_29: %retptr"}
!8 = distinct !{!8, !"_ZN5numba2np8arrayobj13impl_np_array12_3clocals_3e4implB2v3B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEd18dtype_28float64_29"}
!9 = distinct !{!9, !8, !"_ZN5numba2np8arrayobj13impl_np_array12_3clocals_3e4implB2v3B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEd18dtype_28float64_29: %excinfo"}
!10 = distinct !{!10, !11, !"_ZN5numba2np9arraymath10np_asarray12_3clocals_3e4implB2v7B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEd27omitted_28default_3dNone_29: %retptr"}
!11 = distinct !{!11, !"_ZN5numba2np9arraymath10np_asarray12_3clocals_3e4implB2v7B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEd27omitted_28default_3dNone_29"}
!12 = distinct !{!12, !11, !"_ZN5numba2np9arraymath10np_asarray12_3clocals_3e4implB2v7B42c8tJTIcFHzwl2ILiXkcBV0KBSmNGHkyiCKJEEwA_3dEd27omitted_28default_3dNone_29: %excinfo"}
!13 = !{!"branch_weights", i32 1, i32 99}
!14 = !{i1 true}
!15 = !{!7, !9, !10, !12}
!16 = !{!"branch_weights", i32 99, i32 1}

Including a call to @NRT_MemInfo_alloc_aligned(i64 8, i32 32), which allocates memory and sets up refcounting. llvm doesn't know what this function is doing, so its presence prevents a lot of optimizations. For instance, if we do something like this:

@numba.njit(types.void(types.float64), no_cpython_wrapper=True)
def useless_asarray(x):
    np.asarray(x)

where we just call np.asarray but never use the result in any way, llvm still can't optimize the call away.
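
For contrast (a hedged illustration, reusing the same njit options as above): the pure scalar equivalent involves no opaque allocation call, so llvm can trivially remove the dead computation:

@numba.njit(types.void(types.float64), no_cpython_wrapper=True)
def useless_scalar(x):
    x + 0.0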

@ricardoV94 (Member) left a comment:
Not sure about the Elemwise related changes. Indeed Theano rewriting is built all around tensor nodes.

I am very interested in the inlining of scalar Constants in Composite. I don't see why it would fail with the C backend...

pytensor/tensor/rewriting/elemwise.py (outdated)
    return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]


compile.optdb["specialize"].register(
ricardoV94 (Member):
specialize is optional; if Numba will fail without this rewrite we should add a new non-optional rewrite phase at the end.

aseyboldt (Member Author):
I think we should probably make it so that elemwise doesn't fail in numba either way.
To be honest, I'm having a hard time seeing how we could expect everything to work nicely if users pick and choose rewrites. I think that is a testing nightmare...

aseyboldt (Member Author):
I'll have to rethink the phases a bit for sure though. Maybe this actually belongs in "uncanonicalize" or so? Or we could have a new phase "scalarize"...

aseyboldt (Member Author):
I moved it to a new "scalarize" phase

    if not isinstance(op, Elemwise):
        return False

    if any(op.inplace_pattern):
ricardoV94 (Member):
It's fine to be inplace because constants are never inplaced. But to avoid having to deal with it, just register this rewrite before the inplace rewrites.

@aseyboldt (Member Author) commented Dec 12, 2022:
I was worried that maybe some downstream op assumes that one of the inputs has in fact changed? It should be running before the inplace passes anyway though...

ricardoV94 (Member):
That shouldn't happen. Inplace rewrites are myopic; they only look at one node at a time. I've never seen a rewrite checking inplace patterns elsewhere.

aseyboldt (Member Author):
Hm, what about examples like this?

import numpy as np
import pytensor.tensor as pt
import pytensor

x = pt.dvector("x")
y = 2 * x + 1

val = np.ones(3)
input = pytensor.In(x, update=y)
func = pytensor.function([input], [])
pytensor.dprint(func)
# Elemwise{Composite{(1.0 + (2.0 * i0))}}[(0, 0)] [id A] 0
#  |x [id B]

Replacing the inplace Elemwise with a non-inplace Elemwise would be incorrect here.
Still not a problem because the rewrite is registered before the inplace pass, but still...

@ricardoV94 (Member) commented Dec 13, 2022:
Why is the update supposed to create a problem? If you are worried about this rewrite ignoring inplacing, you would have to be worried about every other rewrite we have in the library. What is special about your push constants rewrite?

aseyboldt (Member Author):
I'm not 100% sure there would be...
But it looks like we need this final Elemwise in the graph not for its output but only for its side effect of changing the first input. If we were to replace this node with an Elemwise without the inplace flag, but the same output, wouldn't the update break? But maybe there is a feature somewhere that prevents this?

        input for i, input in enumerate(inner_graph.inputs) if i not in push_idxs
    ]
    return (
        Elemwise(scalar_op=aes.Composite(new_inputs, inner_graph.outputs))
ricardoV94 (Member):
Hmm... I am curious why it would fail. I can have a look at the generated C code.

aseyboldt (Member Author):
Yeah, figuring out what exactly is going wrong here would be good I think.

aseyboldt (Member Author):
Somehow I can't reproduce the segfaults anymore...
I'm getting compilation errors however.

@ricardoV94 (Member) commented Dec 13, 2022:
I tested one simple example locally and it seems to work in the C-backend. Can you share the problematic example? The more I look at this PR the more it seems it shouldn't be made Numba specific!

@@ -380,6 +380,99 @@ def is_dimshuffle_useless(new_order, input):
    return is_useless


@node_rewriter([Elemwise])
def local_elemwise_lift_scalars(fgraph, node):
@ricardoV94 (Member) commented Dec 12, 2022:
This rewrite is not really lifting scalars (they are sandwiched between the same inputs and outputs), maybe call it "elemwise_to_scalar"?

Anyway, wouldn't it be easier to create a different Numba function when dispatching? Don't you still have the same problem with mixed rank0 and non rank0 inputs?

aseyboldt (Member Author):
Yes, elemwise_to_scalar is better :-)
I could have a separate numba function to solve the immediate scalar issue in elemwise...
I'll write a bit more about my motivation for this rewrite below...

@aseyboldt (Member Author):
@ricardoV94 I expanded the description a bit to add some of the motivation. I'm curious to hear what you think :-)

@ricardoV94 (Member):
I totally agree with specializing for scalar graphs. I would do it at a later rewrite phase like you suggested. That way we can keep coverage from our rewrites without duplicated work. If JAX and C show speedups we could also include that phase in those backends.

We really need to start a benchmark suite to guide performance related changes!

Commit message: "There is no need for an Elemwise Op if all inputs have rank 0. And we don't need to use scalar constants as inputs of the Elemwise, they can be inputs for the scalar_op."

@aseyboldt force-pushed the elemwise-rewrites branch 3 times, most recently from d651753 to 011d3f6 on December 13, 2022 at 05:15
Comment on lines +1572 to +1581
@register_scalarize
@node_rewriter([Sum])
def local_sum_of_makevector(fgraph, node):
    (array,) = node.inputs
    if not array.owner or not isinstance(array.owner.op, MakeVector):
        return False

    values = array.owner.inputs
    summed = aes.add(*values)
    return [as_tensor_variable(summed)]
@ricardoV94 (Member) commented Dec 13, 2022:
This seems to touch on #59.

Can we abstract the scalarize part from the "lift reduction operations towards the inputs", which is useful regardless of the backend? Even the scalarize seems useful in both backends. What was the problem with the C backend again?

Comment on lines +472 to +473
@register_stabilize("cxx_only")
@register_canonicalize("cxx_only")
@ricardoV94 (Member) commented Dec 13, 2022:
Let me push against this approach.

There are three scenarios for this rewrite:

  1. It's not very useful and should be reconsidered, regardless of backend
  2. It's useful in the context of a larger chain of rewrites, regardless of backend
  3. It's only useful in one specific backend.

I don't see the reason for 3.

If you start making rewrites exclusive to the C backend you will forget about 2. But eventually you will want to make Numba the default backend and you will want the old tests to pass. You will have made your task much more challenging, because you diverged the C and Numba backends, and the latter's test suite is way more myopic.

It's actually a blessing that Theano/Aesara had very extensive test suites and it was difficult to break things unintentionally. But restricting rewrites to the old, well-tested backend that we eventually want to replace with the new, poorly-tested one is opting out of this safety net. In a sense you will just be kicking the can down the road. The decision about the rewrite will have to be made regardless, but by then the Numba rewrite passes may look so different (because they were developed with a much more forgiving test suite) that you cannot even reason about the two and make an informed choice.


In short, I think we should be very selective about which rewrites are backend specific. For instance, I think we should definitely investigate whether the scalarize changes also make sense for the C and JAX backends.

@ricardoV94 (Member):
I think #349 might be a good solution?
