Skip to content

Commit

Permalink
Implement SymBool (pytorch#92149)
Browse files Browse the repository at this point in the history
We have known for a while that we should in principle support SymBool as a separate concept from SymInt and SymFloat ( in particular, every distinct numeric type should get its own API). However, recent work with unbacked SymInts in, e.g., pytorch#90985 have made this a priority to implement. The essential problem is that our logic for computing the contiguity of tensors performs branches on the passed in input sizes, and this causes us to require guards when constructing tensors from unbacked SymInts. Morally, this should not be a big deal because, we only really care about the regular (non-channels-last) contiguity of the tensor, which should be guaranteed since most people aren't calling `empty_strided` on the tensor, however, because we store a bool (not a SymBool, prior to this PR it doesn't exist) on TensorImpl, we are forced to *immediately* compute these values, even if the value ends up not being used at all. In particular, even when a user allocates a contiguous tensor, we still must compute channels-last contiguity (as some contiguous tensors are also channels-last contiguous, but others are not.)

This PR implements SymBool, and makes TensorImpl use SymBool to store the contiguity information in ExtraMeta. There are a number of knock on effects, which I now discuss below.

* I introduce a new C++ type SymBool, analogous to SymInt and SymFloat. This type supports logical and, logical or and logical negation. I support the bitwise operations on this class (but not the conventional logic operators) to make it clear that logical operations on SymBool are NOT short-circuiting. I also, for now, do NOT support implicit conversion of SymBool to bool (creating a guard in this case). This does matter too much in practice, as in this PR I did not modify the equality operations (e.g., `==` on SymInt) to return SymBool, so all preexisting implicit guards did not need to be changed. I also introduced symbolic comparison functions `sym_eq`, etc. on SymInt to make it possible to create SymBool. The current implementation of comparison functions makes it unfortunately easy to accidentally introduce guards when you do not mean to (as both `s0 == s1` and `s0.sym_eq(s1)` are valid spellings of equality operation); in the short term, I intend to prevent excess guarding in this situation by unit testing; in the long term making the equality operators return SymBool is probably the correct fix.
* ~~I modify TensorImpl to store SymBool for the `is_contiguous` fields and friends on `ExtraMeta`. In practice, this essentially meant reverting most of the changes from pytorch#85936 . In particular, the fields on ExtraMeta are no longer strongly typed; at the time I was particularly concerned about the giant lambda I was using as the setter getting a desynchronized argument order, but now that I have individual setters for each field the only "big list" of boolean arguments is in the constructor of ExtraMeta, which seems like an acceptable risk. The semantics of TensorImpl are now that we guard only when you actually attempt to access the contiguity of the tensor via, e.g., `is_contiguous`. By in large, the contiguity calculation in the implementations now needs to be duplicated (as the boolean version can short circuit, but the SymBool version cannot); you should carefully review the duplicate new implementations. I typically use the `identity` template to disambiguate which version of the function I need, and rely on overloading to allow for implementation sharing. The changes to the `compute_` functions are particularly interesting; for most of the functions, I preserved their original non-symbolic implementation, and then introduce a new symbolic implementation that is branch-less (making use of our new SymBool operations). However, `compute_non_overlapping_and_dense` is special, see next bullet.~~ This appears to cause performance problems, so I am leaving this to an update PR.
* (Update: the Python side pieces for this are still in this PR, but they are not wired up until later PRs.) While the contiguity calculations are relatively easy to write in a branch-free way, `compute_non_overlapping_and_dense` is not: it involves a sort on the strides. While in principle we can still make it go through by using a data oblivious sorting network, this seems like too much complication for a field that is likely never used (because typically, it will be obvious that a tensor is non overlapping and dense, because the tensor is contiguous.) So we take a different approach: instead of trying to trace through the logic computation of non-overlapping and dense, we instead introduce a new opaque operator IsNonOverlappingAndDenseIndicator which represents all of the compute that would have been done here. This function returns an integer 0 if `is_non_overlapping_and_dense` would have returned `False`, and an integer 1 otherwise, for technical reasons (Sympy does not easily allow defining custom functions that return booleans). The function itself only knows how to evaluate itself if all of its arguments are integers; otherwise it is left unevaluated. This means we can always guard on it (as `size_hint` will always be able to evaluate through it), but otherwise its insides are left a black box. We typically do NOT expect this custom function to show up in actual boolean expressions, because we will typically shortcut it due to the tensor being contiguous. It's possible we should apply this treatment to all of the other `compute_` operations, more investigation necessary. As a technical note, because this operator takes a pair of a list of SymInts, we need to support converting `ArrayRef<SymNode>` to Python, and I also unpack the pair of lists into a single list because I don't know if Sympy operations can actually validly take lists of Sympy expressions as inputs. See for example `_make_node_sizes_strides`
* On the Python side, we also introduce a SymBool class, and update SymNode to track bool as a valid pytype. There is some subtlety here: bool is a subclass of int, so one has to be careful about `isinstance` checks (in fact, in most cases I replaced `isinstance(x, int)` with `type(x) is int` for expressly this reason.) Additionally, unlike, C++, I do NOT define bitwise inverse on SymBool, because it does not do the correct thing when run on booleans, e.g., `~True` is `-2`. (For that matter, they don't do the right thing in C++ either, but at least in principle the compiler can warn you about it with `-Wbool-operation`, and so the rule is simple in C++; only use logical operations if the types are statically known to be SymBool). Alas, logical negation is not overrideable, so we have to introduce `sym_not` which must be used in place of `not` whenever a SymBool can turn up. To avoid confusion with `__not__` which may imply that `operators.__not__` might be acceptable to use (it isn't), our magic method is called `__sym_not__`. The other bitwise operators `&` and `|` do the right thing with booleans and are acceptable to use.
* There is some annoyance working with booleans in Sympy. Unlike int and float, booleans live in their own algebra and they support less operations than regular numbers. In particular, `sympy.expand` does not work on them. To get around this, I introduce `safe_expand` which only calls expand on operations which are known to be expandable.

TODO: this PR appears to greatly regress performance of symbolic reasoning. In particular, `python test/functorch/test_aotdispatch.py -k max_pool2d` performs really poorly with these changes. Need to investigate.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: pytorch#92149
Approved by: https://github.com/albanD, https://github.com/Skylion007
  • Loading branch information
ezyang authored and pytorchmergebot committed Jan 21, 2023
1 parent 34e8eb2 commit 5c6f543
Show file tree
Hide file tree
Showing 18 changed files with 626 additions and 149 deletions.
72 changes: 72 additions & 0 deletions c10/core/SymBool.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include <c10/core/SymBool.h>
#include <c10/core/SymNodeImpl.h>
#include <array>
#include <utility>

namespace c10 {

SymNode SymBool::toSymNodeImpl() const {
TORCH_CHECK(is_symbolic());
return SymNode::reclaim_copy(toSymNodeImplUnowned());
}

static std::array<SymNode, 2> normalize_symbools(
const SymBool& a_,
const SymBool& b_) {
SymNode a, b;
if (a_.is_symbolic())
a = a_.toSymNodeImpl();
if (b_.is_symbolic())
b = b_.toSymNodeImpl();

SymNodeImpl* common = a ? a.get() : b.get();
if (!a) {
a = common->wrap_bool(a_.as_bool_unchecked());
}
if (!b) {
b = common->wrap_bool(b_.as_bool_unchecked());
}
return {std::move(a), std::move(b)};
}

SymBool SymBool::sym_and(const SymBool& sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymBool(data_ && sci.data_);
}
auto res = normalize_symbools(*this, sci);
return SymBool(res[0]->sym_and(res[1]));
}

SymBool SymBool::sym_or(const SymBool& sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymBool(data_ || sci.data_);
}
auto res = normalize_symbools(*this, sci);
return SymBool(res[0]->sym_or(res[1]));
}

SymBool SymBool::sym_not() const {
if (!is_symbolic()) {
return SymBool(!data_);
}
return SymBool(toSymNodeImpl()->sym_not());
}

std::ostream& operator<<(std::ostream& os, const SymBool& s) {
if (s.is_symbolic()) {
os << s.toSymNodeImpl()->str();
} else {
os << s.as_bool_unchecked();
}
return os;
}

bool SymBool::guard_bool(const char* file, int64_t line) const {
if (!is_symbolic()) {
return data_;
}
SymNode a = toSymNodeImpl();
return a->guard_bool(file, line);
}

} // namespace c10
70 changes: 70 additions & 0 deletions c10/core/SymBool.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#pragma once

#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>

#include <limits>
#include <memory>

namespace c10 {

class C10_API SymBool {
public:
/*implicit*/ SymBool(bool b) : data_(b){};
SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) {
TORCH_CHECK(ptr_->is_bool());
};
SymBool() : data_(false) {}

SymNodeImpl* toSymNodeImplUnowned() const {
return ptr_.get();
}

SymNodeImpl* release() && {
return std::move(ptr_).release();
}

SymNode toSymNodeImpl() const;

bool expect_bool() const {
TORCH_CHECK(!is_symbolic());
return data_;
}

SymBool sym_and(const SymBool&) const;
SymBool sym_or(const SymBool&) const;
SymBool sym_not() const;

SymBool operator&(const SymBool& other) const {
return sym_and(other);
}
SymBool operator|(const SymBool& other) const {
return sym_or(other);
}
SymBool operator~() const {
return sym_not();
}

// Insert a guard for the bool to be its concrete value, and then return
// that value. Note that C++ comparison operations default to returning
// bool, so it's not so common to have to call this
bool guard_bool(const char* file, int64_t line) const;

C10_ALWAYS_INLINE bool is_symbolic() const {
return ptr_;
}

bool as_bool_unchecked() const {
return data_;
}

private:
// TODO: optimize to union
bool data_;
SymNode ptr_;
};

C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s);
} // namespace c10
28 changes: 16 additions & 12 deletions c10/core/SymInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,48 +94,52 @@ SymInt SymInt::operator%(const SymInt& sci) const {
return SymInt(res[0]->mod(res[1]));
}

bool SymInt::operator==(const SymInt& sci) const {
SymBool SymInt::sym_eq(const SymInt& sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return data_ == sci.data_;
}
auto res = normalize_symints(*this, sci);
return res[0]->eq(res[1])->bool_();
return res[0]->eq(res[1]);
}

bool SymInt::operator!=(const SymInt& sci) const {
return !(*this == sci);
SymBool SymInt::sym_ne(const SymInt& sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return data_ != sci.data_;
}
auto res = normalize_symints(*this, sci);
return res[0]->ne(res[1]);
}

bool SymInt::operator<(const SymInt& sci) const {
SymBool SymInt::sym_lt(const SymInt& sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return data_ < sci.data_;
}
auto res = normalize_symints(*this, sci);
return res[0]->lt(res[1])->bool_();
return res[0]->lt(res[1]);
}

bool SymInt::operator<=(const SymInt& sci) const {
SymBool SymInt::sym_le(const SymInt& sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return data_ <= sci.data_;
}
auto res = normalize_symints(*this, sci);
return res[0]->le(res[1])->bool_();
return res[0]->le(res[1]);
}

bool SymInt::operator>(const SymInt& sci) const {
SymBool SymInt::sym_gt(const SymInt& sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return data_ > sci.data_;
}
auto res = normalize_symints(*this, sci);
return res[0]->gt(res[1])->bool_();
return res[0]->gt(res[1]);
}

bool SymInt::operator>=(const SymInt& sci) const {
SymBool SymInt::sym_ge(const SymInt& sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return data_ >= sci.data_;
}
auto res = normalize_symints(*this, sci);
return res[0]->ge(res[1])->bool_();
return res[0]->ge(res[1]);
}

SymInt SymInt::min(const SymInt& sci) const {
Expand Down
33 changes: 27 additions & 6 deletions c10/core/SymInt.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <c10/core/SymBool.h>
#include <c10/core/SymNodeImpl.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
Expand Down Expand Up @@ -157,16 +158,36 @@ class C10_API SymInt {
SymInt operator*(const SymInt& sci) const;
SymInt operator/(const SymInt& sci) const;
SymInt operator%(const SymInt& sci) const;
bool operator==(const SymInt& sci) const;
bool operator!=(const SymInt& p2) const;
bool operator<(const SymInt& sci) const;
bool operator<=(const SymInt& sci) const;
bool operator>(const SymInt& sci) const;
bool operator>=(const SymInt& sci) const;
void operator*=(const SymInt& sci);
void operator+=(const SymInt& sci);
void operator/=(const SymInt& sci);

SymBool sym_eq(const SymInt&) const;
SymBool sym_ne(const SymInt&) const;
SymBool sym_lt(const SymInt&) const;
SymBool sym_le(const SymInt&) const;
SymBool sym_gt(const SymInt&) const;
SymBool sym_ge(const SymInt&) const;

bool operator==(const SymInt& o) const {
return sym_eq(o).guard_bool(__FILE__, __LINE__);
}
bool operator!=(const SymInt& o) const {
return sym_ne(o).guard_bool(__FILE__, __LINE__);
}
bool operator<(const SymInt& o) const {
return sym_lt(o).guard_bool(__FILE__, __LINE__);
}
bool operator<=(const SymInt& o) const {
return sym_le(o).guard_bool(__FILE__, __LINE__);
}
bool operator>(const SymInt& o) const {
return sym_gt(o).guard_bool(__FILE__, __LINE__);
}
bool operator>=(const SymInt& o) const {
return sym_ge(o).guard_bool(__FILE__, __LINE__);
}

SymInt min(const SymInt& sci) const;
SymInt max(const SymInt& sci) const;

Expand Down
25 changes: 25 additions & 0 deletions c10/core/SymNodeImpl.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/intrusive_ptr.h>
#include <memory>
Expand All @@ -25,6 +26,9 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
virtual bool is_int() {
TORCH_CHECK(false, "NYI");
};
virtual bool is_bool() {
TORCH_CHECK(false, "NYI");
};
virtual bool is_float() {
TORCH_CHECK(false, "NYI");
};
Expand Down Expand Up @@ -82,6 +86,21 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
virtual SymNode sym_max(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_or(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_and(const SymNode& other) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode sym_not() {
TORCH_CHECK(false, "NYI");
};
// NB: self is ignored here, only the arguments are used
virtual SymNode is_non_overlapping_and_dense(
ArrayRef<SymNode> sizes,
ArrayRef<SymNode> strides) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode clone() {
TORCH_CHECK(false, "NYI");
};
Expand All @@ -94,9 +113,15 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target {
virtual SymNode wrap_float(double num) {
TORCH_CHECK(false, "NYI");
};
virtual SymNode wrap_bool(bool num) {
TORCH_CHECK(false, "NYI");
};
virtual int64_t guard_int(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual bool guard_bool(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
virtual double guard_float(const char* file, int64_t line) {
TORCH_CHECK(false, "NYI");
};
Expand Down
2 changes: 0 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,6 @@
"Quantize",
# torch.utils.backcompat
"Warning",
"SymInt",
"SymFloat",
]

# The suffix(es) of source filenames.
Expand Down
10 changes: 10 additions & 0 deletions docs/source/torch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,15 @@ Utilities

Symbolic Numbers
----------------
.. autoclass:: SymInt
:members:

.. autoclass:: SymFloat
:members:

.. autoclass:: SymBool
:members:

.. autosummary::
:toctree: generated
:nosignatures:
Expand All @@ -629,6 +638,7 @@ Symbolic Numbers
sym_int
sym_max
sym_min
sym_not

Optimizations
-------------
Expand Down
Loading

0 comments on commit 5c6f543

Please sign in to comment.