diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp new file mode 100644 index 0000000000000..c41cffb061356 --- /dev/null +++ b/c10/core/SymBool.cpp @@ -0,0 +1,72 @@ +#include +#include +#include +#include + +namespace c10 { + +SymNode SymBool::toSymNodeImpl() const { + TORCH_CHECK(is_symbolic()); + return SymNode::reclaim_copy(toSymNodeImplUnowned()); +} + +static std::array normalize_symbools( + const SymBool& a_, + const SymBool& b_) { + SymNode a, b; + if (a_.is_symbolic()) + a = a_.toSymNodeImpl(); + if (b_.is_symbolic()) + b = b_.toSymNodeImpl(); + + SymNodeImpl* common = a ? a.get() : b.get(); + if (!a) { + a = common->wrap_bool(a_.as_bool_unchecked()); + } + if (!b) { + b = common->wrap_bool(b_.as_bool_unchecked()); + } + return {std::move(a), std::move(b)}; +} + +SymBool SymBool::sym_and(const SymBool& sci) const { + if (!is_symbolic() && !sci.is_symbolic()) { + return SymBool(data_ && sci.data_); + } + auto res = normalize_symbools(*this, sci); + return SymBool(res[0]->sym_and(res[1])); +} + +SymBool SymBool::sym_or(const SymBool& sci) const { + if (!is_symbolic() && !sci.is_symbolic()) { + return SymBool(data_ || sci.data_); + } + auto res = normalize_symbools(*this, sci); + return SymBool(res[0]->sym_or(res[1])); +} + +SymBool SymBool::sym_not() const { + if (!is_symbolic()) { + return SymBool(!data_); + } + return SymBool(toSymNodeImpl()->sym_not()); +} + +std::ostream& operator<<(std::ostream& os, const SymBool& s) { + if (s.is_symbolic()) { + os << s.toSymNodeImpl()->str(); + } else { + os << s.as_bool_unchecked(); + } + return os; +} + +bool SymBool::guard_bool(const char* file, int64_t line) const { + if (!is_symbolic()) { + return data_; + } + SymNode a = toSymNodeImpl(); + return a->guard_bool(file, line); +} + +} // namespace c10 diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h new file mode 100644 index 0000000000000..de2d7c2f28250 --- /dev/null +++ b/c10/core/SymBool.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace c10 { + +class C10_API SymBool { + public: + /*implicit*/ SymBool(bool b) : data_(b){}; + SymBool(SymNode ptr) : data_(false), ptr_(std::move(ptr)) { + TORCH_CHECK(ptr_->is_bool()); + }; + SymBool() : data_(false) {} + + SymNodeImpl* toSymNodeImplUnowned() const { + return ptr_.get(); + } + + SymNodeImpl* release() && { + return std::move(ptr_).release(); + } + + SymNode toSymNodeImpl() const; + + bool expect_bool() const { + TORCH_CHECK(!is_symbolic()); + return data_; + } + + SymBool sym_and(const SymBool&) const; + SymBool sym_or(const SymBool&) const; + SymBool sym_not() const; + + SymBool operator&(const SymBool& other) const { + return sym_and(other); + } + SymBool operator|(const SymBool& other) const { + return sym_or(other); + } + SymBool operator~() const { + return sym_not(); + } + + // Insert a guard for the bool to be its concrete value, and then return + // that value. Note that C++ comparison operations default to returning + // bool, so it's not so common to have to call this + bool guard_bool(const char* file, int64_t line) const; + + C10_ALWAYS_INLINE bool is_symbolic() const { + return ptr_; + } + + bool as_bool_unchecked() const { + return data_; + } + + private: + // TODO: optimize to union + bool data_; + SymNode ptr_; +}; + +C10_API std::ostream& operator<<(std::ostream& os, const SymBool& s); +} // namespace c10 diff --git a/c10/core/SymInt.cpp b/c10/core/SymInt.cpp index 21b83b122bf59..faa0d650b038a 100644 --- a/c10/core/SymInt.cpp +++ b/c10/core/SymInt.cpp @@ -94,48 +94,52 @@ SymInt SymInt::operator%(const SymInt& sci) const { return SymInt(res[0]->mod(res[1])); } -bool SymInt::operator==(const SymInt& sci) const { +SymBool SymInt::sym_eq(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return data_ == sci.data_; } auto res = normalize_symints(*this, sci); - return res[0]->eq(res[1])->bool_(); + return res[0]->eq(res[1]); } -bool SymInt::operator!=(const SymInt& sci) const { - return !(*this == sci); +SymBool SymInt::sym_ne(const SymInt& sci) const { + if (!is_symbolic() && !sci.is_symbolic()) { + return data_ != sci.data_; + } + auto res = normalize_symints(*this, sci); + return res[0]->ne(res[1]); } -bool SymInt::operator<(const SymInt& sci) const { +SymBool SymInt::sym_lt(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return data_ < sci.data_; } auto res = normalize_symints(*this, sci); - return res[0]->lt(res[1])->bool_(); + return res[0]->lt(res[1]); } -bool SymInt::operator<=(const SymInt& sci) const { +SymBool SymInt::sym_le(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return data_ <= sci.data_; } auto res = normalize_symints(*this, sci); - return res[0]->le(res[1])->bool_(); + return res[0]->le(res[1]); } -bool SymInt::operator>(const SymInt& sci) const { +SymBool SymInt::sym_gt(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return data_ > sci.data_; } auto res = normalize_symints(*this, sci); - return res[0]->gt(res[1])->bool_(); + return res[0]->gt(res[1]); } -bool SymInt::operator>=(const SymInt& sci) const { +SymBool SymInt::sym_ge(const SymInt& sci) const { if (!is_symbolic() && !sci.is_symbolic()) { return data_ >= sci.data_; } auto res = normalize_symints(*this, sci); - return res[0]->ge(res[1])->bool_(); + return res[0]->ge(res[1]); } SymInt SymInt::min(const SymInt& sci) const { diff --git a/c10/core/SymInt.h b/c10/core/SymInt.h index b2a46b751611c..ca3e718f8c02d 100644 --- a/c10/core/SymInt.h +++ b/c10/core/SymInt.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -157,16 +158,36 @@ class C10_API SymInt { SymInt operator*(const SymInt& sci) const; SymInt operator/(const SymInt& sci) const; SymInt operator%(const SymInt& sci) const; - bool operator==(const SymInt& sci) const; - bool operator!=(const SymInt& p2) const; - bool operator<(const SymInt& sci) const; - bool operator<=(const SymInt& sci) const; - bool operator>(const SymInt& sci) const; - bool operator>=(const SymInt& sci) const; void operator*=(const SymInt& sci); void operator+=(const SymInt& sci); void operator/=(const SymInt& sci); + SymBool sym_eq(const SymInt&) const; + SymBool sym_ne(const SymInt&) const; + SymBool sym_lt(const SymInt&) const; + SymBool sym_le(const SymInt&) const; + SymBool sym_gt(const SymInt&) const; + SymBool sym_ge(const SymInt&) const; + + bool operator==(const SymInt& o) const { + return sym_eq(o).guard_bool(__FILE__, __LINE__); + } + bool operator!=(const SymInt& o) const { + return sym_ne(o).guard_bool(__FILE__, __LINE__); + } + bool operator<(const SymInt& o) const { + return sym_lt(o).guard_bool(__FILE__, __LINE__); + } + bool operator<=(const SymInt& o) const { + return sym_le(o).guard_bool(__FILE__, __LINE__); + } + bool operator>(const SymInt& o) const { + return sym_gt(o).guard_bool(__FILE__, __LINE__); + } + bool operator>=(const SymInt& o) const { + return sym_ge(o).guard_bool(__FILE__, __LINE__); + } + SymInt min(const SymInt& sci) const; SymInt max(const SymInt& sci) const; diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index d5e62eca9fb5f..c87ed6c75a7fb 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -25,6 +26,9 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual bool is_int() { TORCH_CHECK(false, "NYI"); }; + virtual bool is_bool() { + TORCH_CHECK(false, "NYI"); + }; virtual bool is_float() { TORCH_CHECK(false, "NYI"); }; @@ -82,6 +86,21 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode sym_max(const SymNode& other) { TORCH_CHECK(false, "NYI"); }; + virtual SymNode sym_or(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_and(const SymNode& other) { + TORCH_CHECK(false, "NYI"); + }; + virtual SymNode sym_not() { + TORCH_CHECK(false, "NYI"); + }; + // NB: self is ignored here, only the arguments are used + virtual SymNode is_non_overlapping_and_dense( + ArrayRef sizes, + ArrayRef strides) { + TORCH_CHECK(false, "NYI"); + }; virtual SymNode clone() { TORCH_CHECK(false, "NYI"); }; @@ -94,9 +113,15 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode wrap_float(double num) { TORCH_CHECK(false, "NYI"); }; + virtual SymNode wrap_bool(bool num) { + TORCH_CHECK(false, "NYI"); + }; virtual int64_t guard_int(const char* file, int64_t line) { TORCH_CHECK(false, "NYI"); }; + virtual bool guard_bool(const char* file, int64_t line) { + TORCH_CHECK(false, "NYI"); + }; virtual double guard_float(const char* file, int64_t line) { TORCH_CHECK(false, "NYI"); }; diff --git a/docs/source/conf.py b/docs/source/conf.py index f4d1d8b68eb92..90f1659d30e51 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -336,8 +336,6 @@ "Quantize", # torch.utils.backcompat "Warning", - "SymInt", - "SymFloat", ] # The suffix(es) of source filenames. diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 9e7be9072b67c..95dfad77f108d 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -621,6 +621,15 @@ Utilities Symbolic Numbers ---------------- +.. autoclass:: SymInt + :members: + +.. autoclass:: SymFloat + :members: + +.. autoclass:: SymBool + :members: + .. autosummary:: :toctree: generated :nosignatures: @@ -629,6 +638,7 @@ Symbolic Numbers sym_int sym_max sym_min + sym_not Optimizations ------------- diff --git a/test/cpp/jit/test_misc.cpp b/test/cpp/jit/test_misc.cpp index 22c344955345a..450e3eff332ac 100644 --- a/test/cpp/jit/test_misc.cpp +++ b/test/cpp/jit/test_misc.cpp @@ -1456,79 +1456,6 @@ TEST(TestSymInt, AddSymbolicInt) { ASSERT_TRUE((a + b).expect_int() == 8); } -#ifndef C10_MOBILE -class TestSymNodeImpl : public c10::SymNodeImpl { - public: - explicit TestSymNodeImpl(int64_t i) : i_(i) {} - - bool is_int() override { - return true; - }; - - bool is_float() override { - return false; - }; - - bool bool_() override { - return static_cast(i_); - }; - -#define OPDEF3(NAME, OP, RET) \ - RET NAME(const c10::SymNode& other) override { \ - return make_intrusive( \ - this->i_ OP dynamic_cast(other.get())->i_); \ - } - -#define OPDEF2(NAME, OP) OPDEF3(NAME, OP, c10::SymNode) - OPDEF2(add, +) - OPDEF2(sub, -) - OPDEF2(mul, *) - OPDEF2(floordiv, /) - OPDEF2(mod, %) - - OPDEF2(eq, ==) - OPDEF2(ne, !=) - OPDEF2(lt, <) - OPDEF2(le, <=) - OPDEF2(gt, >) - OPDEF2(ge, >=) -#undef OPDEF2 -#undef OPDEF3 - - int64_t i_; -}; - -TEST(TestSymInt, TestSymIntToSymNodeDispatch) { - auto get = [](c10::SymInt si) { - auto node = si.toSymNodeImpl(); - return dynamic_cast(node.get())->i_; - }; - - std::vector inputs{0, 1, -1, 4, -4, 777, -777}; - for (auto i : inputs) { - for (auto j : inputs) { - auto a = c10::SymInt( - static_cast(c10::make_intrusive(i))); - auto b = c10::SymInt( - static_cast(c10::make_intrusive(j))); - ASSERT_EQ(get(a + b), i + j); - ASSERT_EQ(get(a - b), i - j); - ASSERT_EQ(get(a * b), i * j); - if (j != 0) { - ASSERT_EQ(get(a / b), i / j); - ASSERT_EQ(get(a % b), i % j); - } - ASSERT_EQ(a == b, i == j); - ASSERT_EQ(a != b, i != j); - ASSERT_EQ(a < b, i < j); - ASSERT_EQ(a <= b, i <= j); - ASSERT_EQ(a > b, i > j); - ASSERT_EQ(a >= b, i >= j); - } - } -} -#endif - TEST(FallbackGraphsTest, Basic) { auto x = at::randn({1}, at::kCPU); auto y = at::randn({1}, at::kCPU); diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index bb16611f3c5c1..e1545708e10b1 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -14,7 +14,6 @@ import contextlib import math import atexit -import io import os from torch.utils._pytree import tree_map from torch.fx.experimental import symbolic_shapes @@ -389,6 +388,13 @@ def test_int_conversion(self): a0 = create_symint(shape_env, 2) self.assertRaisesRegex(RuntimeError, "Trying to extract", lambda: int(a0)) + @skipIfNoSympy + def test_non_overlapping_and_dense(self): + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 5) + r = torch.empty_strided((a0, 7), (1, a0), device='meta') + self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r)) + @skipIfNoSympy def test_symint_as_scalar(self): shape_env = ShapeEnv() @@ -414,8 +420,7 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): self.assertTrue(sym_int_encountered) @skipIfNoSympy - @unittest.mock.patch('sys.stdout', new_callable=io.StringIO) - def test_print_readable_with_symints(self, mock_stdout): + def test_print_readable_with_symints(self): def f(a, b): dim0 = a.shape[0] + b.shape[0] dim1 = a.shape[1] + b.shape[1] @@ -424,9 +429,9 @@ def f(a, b): return d fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3)) - fx_g.print_readable() + out = fx_g.print_readable(print_output=False) - self.assertExpectedInline(mock_stdout.getvalue().strip(), """\ + self.assertExpectedInline(out.strip(), """\ class f(torch.nn.Module): def forward(self, a_1: f32[s0, s1], b_1: f32[s2, s1]): # No stacktrace found for following nodes @@ -484,9 +489,13 @@ class TestSymNumberMagicMethods(TestCase): def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): # Helper function seed_node = (create_symint(shape_env, 1) / 1.).node + bool_seed_node = (create_symint(shape_env, 1) == 1).node def get_sym_inp(inp): - if isinstance(inp, int): + # NB: this must come before int + if isinstance(inp, bool): + return torch.SymBool(to_node(bool_seed_node, inp)) + elif isinstance(inp, int): return torch.SymInt(to_node(seed_node, inp)) else: return torch.SymFloat(to_node(seed_node, inp)) @@ -511,6 +520,8 @@ def context(): lambda_apply = getattr(math, fn) elif fn in symbolic_shapes.magic_methods_on_submodule: lambda_apply = getattr(symbolic_shapes, fn) + elif fn in symbolic_shapes.magic_methods_on_operator_with_trailing_underscore: + lambda_apply = getattr(operator, f"{fn}_") else: lambda_apply = getattr(operator, fn) @@ -518,16 +529,15 @@ def context(): tp = "float" elif fn in symbolic_shapes.always_int_magic_methods: tp = "int" + elif fn in symbolic_shapes.always_bool_magic_methods: + tp = "bool" elif is_unary_fn: tp = "float" if isinstance(inp1, float) else "int" else: tp = "float" if any(isinstance(i, float) for i in [inp1, inp2]) else "int" def guard_fn(v): - if fn in symbolic_shapes.always_bool_magic_methods: - return bool(v) - else: - return getattr(v.node, f"guard_{tp}")("", 0) + return getattr(v.node, f"guard_{tp}")("", 0) # Get reference result with maybe_xfail(inp1, inp2): @@ -560,11 +570,22 @@ def guard_fn(v): self.assertEqual(guard_fn(out), ref_out) + @parametrize("fn", list(symbolic_shapes.magic_methods.keys())) + def test_bool_method(self, fn): + if fn not in symbolic_shapes.bool_magic_methods: + self.skipTest(f"{fn} is non-bool") + + is_unary_fn = fn in symbolic_shapes.unary_magic_methods + shape_env = ShapeEnv() + self._do_test(fn, True, False, shape_env, is_unary_fn) + + @parametrize("fn", list(symbolic_shapes.magic_methods.keys())) @parametrize("first_type", ["int", "float"]) @parametrize("second_type", ["int", "float"]) def test_method(self, fn, first_type, second_type): if first_type == "float": + # TODO: Hmm, this looks like we skip all floats self.skipTest(f"{fn} is not a float magic method") is_unary_fn = fn in symbolic_shapes.unary_magic_methods @@ -572,6 +593,9 @@ def test_method(self, fn, first_type, second_type): if is_unary_fn and second_type == "float": self.skipTest(f"{fn} is unary and already tested") + if fn in symbolic_shapes.bool_magic_methods: + self.skipTest(f"{fn} is bool") + # We could pass int/float directly for types but then the # mangled test name is bad inp1 = random.random() * 2.5 diff --git a/torch/__init__.py b/torch/__init__.py index 4bd03af83fa18..57fa666dfca9d 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -50,6 +50,7 @@ 'set_deterministic_debug_mode', 'get_deterministic_debug_mode', 'set_float32_matmul_precision', 'get_float32_matmul_precision', 'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat', + 'SymBool', 'sym_not', 'sym_int', 'sym_float', 'sym_max', 'sym_min', 'compile', 'vmap' ] @@ -315,6 +316,66 @@ def __sym_min__(self, other): def __repr__(self): return self.node.str() +class SymBool: + """ + Like an bool (including magic methods), but redirects all operations on the + wrapped node. This is used in particular to symbolically record operations + in the symbolic shape workflow. + + Unlike regular bools, regular boolean operators will force extra guards instead + of symbolically evaluate. Use the bitwise operators instead to handle this. + """ + + def __init__(self, node): + from torch.fx.experimental.symbolic_shapes import SymNode + assert isinstance(node, SymNode) + # This field MUST be named node; C++ binding code assumes that this + # class has a field named node that stores SymNode + self.node = node + + def __bool__(self): + return self.node.bool_() + + # Magic methods installed by torch.fx.experimental.symbolic_shapes + def __and__(self, other) -> "SymBool": + raise AssertionError("type stub not overridden") + + def __or__(self, other) -> "SymBool": + raise AssertionError("type stub not overridden") + + # We very carefully define __sym_not__, and not a number of other + # plausible alternatives: + # + # - We do not override __not__ because this is not a real magic + # method; you cannot override the meaning of the not builtin in + # Python. We use the name 'sym_not' to clarify that in user code you + # cannot use the builtin not or operator.not_ or operator.__not__ and + # hit this magic method; you must use our custom sym_not operator. + # + # - We do not override the __invert__ method because SymBool is + # meant to be usable in situations where bool is expected. However, + # bitwise negation ~a does the wrong thing with booleans (because + # bool is a subclass of int, so ~1 = -2 which is not falseish.) + # This would be a giant footgun, so we get around it by defining + # our own operator. Note that bitwise and/or do the right thing, + # so we reuse the conventional operators there for readability. + # + def __sym_not__(self) -> "SymBool": + raise AssertionError("type stub not overridden") + + def __repr__(self): + return self.node.str() + +def sym_not(a): + r""" SymInt-aware utility for logical negation. + + Args: + a (SymBool or bool): Object to negate + """ + if hasattr(a, '__sym_not__'): + return a.__sym_not__() + return not a + def sym_float(a): r""" SymInt-aware utility for float casting. diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 515ebd9f0c56e..fed555c8cd7e8 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1159,6 +1159,7 @@ void initJITBindings(PyObject* module) { SYMNODE_UNARY(clone) SYMNODE_UNARY(is_int) SYMNODE_UNARY(is_float) + SYMNODE_UNARY(is_bool) SYMNODE_UNARY(bool_) SYMNODE_UNARY(int_) SYMNODE_UNARY(sym_float) @@ -1170,22 +1171,35 @@ void initJITBindings(PyObject* module) { SYMNODE_BINARY(floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(eq) + SYMNODE_BINARY(ne) SYMNODE_BINARY(gt) SYMNODE_BINARY(lt) SYMNODE_BINARY(le) SYMNODE_BINARY(ge) SYMNODE_BINARY(sym_min) SYMNODE_BINARY(sym_max) + SYMNODE_BINARY(sym_and) + SYMNODE_BINARY(sym_or) + SYMNODE_UNARY(sym_not) SYMNODE_UNARY(ceil) SYMNODE_UNARY(floor) SYMNODE_UNARY(neg) // Intentionally don't set file line, as the // Python backtrace matters more here + .def("is_non_overlapping_and_dense", + [](c10::SymNode a, c10::ArrayRef sizes, c10::ArrayRef strides) { + return a->is_non_overlapping_and_dense(sizes, strides); + }) .def( "guard_int", [](c10::SymNode a) { return a->guard_int(nullptr, 0); }) + .def( + "guard_bool", + [](c10::SymNode a) { + return a->guard_bool(nullptr, 0); + }) .def( "guard_float", [](c10::SymNode a) { @@ -1201,6 +1215,11 @@ void initJITBindings(PyObject* module) { [](c10::SymNode a, double b) { return a->wrap_float(b); }) + .def( + "wrap_bool", + [](c10::SymNode a, bool b) { + return a->wrap_bool(b); + }) .def( "__str__", [](c10::SymNode a) { return a->str(); }) diff --git a/torch/csrc/utils.cpp b/torch/csrc/utils.cpp index 5fc91d68dd180..b42e389723b5b 100644 --- a/torch/csrc/utils.cpp +++ b/torch/csrc/utils.cpp @@ -390,5 +390,27 @@ handle type_caster::cast( return t.release(); } +bool type_caster>::load(handle src, bool) { + TORCH_INTERNAL_ASSERT(0, "NYI"); +} +handle type_caster>::cast( + at::ArrayRef src, + return_value_policy /* policy */, + handle /* parent */) { + py::list t(src.size()); + for (const auto i : c10::irange(src.size())) { + // TODO: this is terrible but I don't know how to override when + // the SymNode is also explicitly cast by py::cast + auto* py_node = dynamic_cast(src[i].get()); + if (py_node) { + // Return the Python directly (unwrap) + t[i] = py_node->getPyObj(); + } else { + t[i] = py::cast(src[i]); + } + } + return t.release(); +} + } // namespace detail } // namespace pybind11 diff --git a/torch/csrc/utils/pybind.h b/torch/csrc/utils/pybind.h index e79a55370aec6..2e91207a4c8e6 100644 --- a/torch/csrc/utils/pybind.h +++ b/torch/csrc/utils/pybind.h @@ -125,6 +125,22 @@ struct TORCH_PYTHON_API type_caster { std::vector v_value; }; +template <> +struct TORCH_PYTHON_API type_caster> { + public: + // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + PYBIND11_TYPE_CASTER(at::ArrayRef, _("List[SymNode]")); + + bool load(handle src, bool); + static handle cast( + at::ArrayRef src, + return_value_policy /* policy */, + handle /* parent */); + + private: + std::vector v_value; +}; + template <> struct type_caster { public: diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index a1fad0047fb39..ec0c49c64dc58 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -35,13 +35,27 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { c10::SymNode wrap_int(int64_t num) override { py::gil_scoped_acquire acquire; auto r = getPyObj().attr("wrap_int")(num); - return c10::make_intrusive(r); + return c10::make_intrusive(std::move(r)); } c10::SymNode wrap_float(double num) override { py::gil_scoped_acquire acquire; auto r = getPyObj().attr("wrap_float")(num); - return c10::make_intrusive(r); + return c10::make_intrusive(std::move(r)); + } + + c10::SymNode wrap_bool(bool num) override { + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr("wrap_bool")(num); + return c10::make_intrusive(std::move(r)); + } + + c10::SymNode is_non_overlapping_and_dense( + c10::ArrayRef sizes, + c10::ArrayRef strides) override { + py::gil_scoped_acquire acquire; + auto r = getPyObj().attr("is_non_overlapping_and_dense")(sizes, strides); + return c10::make_intrusive(std::move(r)); } bool bool_() override { @@ -59,6 +73,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return getPyObj().attr("is_float")().is(py::handle(Py_True)); } + bool is_bool() override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("is_bool")().is(py::handle(Py_True)); + } + int64_t guard_int(const char* file, int64_t line) override { py::gil_scoped_acquire acquire; return getPyObj().attr("guard_int")(file, line).cast(); @@ -69,6 +88,11 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return getPyObj().attr("guard_float")(file, line).cast(); } + bool guard_bool(const char* file, int64_t line) override { + py::gil_scoped_acquire acquire; + return getPyObj().attr("guard_bool")(file, line).cast(); + } + int64_t int_() override { py::gil_scoped_acquire acquire; return getPyObj().attr("int_")().cast(); @@ -125,6 +149,10 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__func__, other); } + c10::SymNode ne(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode gt(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } @@ -148,6 +176,18 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__func__, other); } + c10::SymNode sym_and(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode sym_or(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode sym_not() override { + return dispatch_common_(__func__); + } + c10::SymNode ceil() override { return dispatch_common_(__func__); } diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 36cfdd85b7006..ca2a5f58a6218 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -23,7 +23,7 @@ from torch._subclasses import FakeTensor from .symbolic_shapes import ShapeEnv, SymDispatchMode, SymNode from torch.fx import Proxy -from torch import SymInt, SymFloat +from torch import SymInt, SymFloat, SymBool from torch.utils.weak import WeakTensorKeyDictionary __all__ = ["PythonKeyTracer", "dispatch_trace", "make_fx", "DecompositionInterpreter", "py_sym_types", "get_innermost_proxy_mode"] @@ -57,7 +57,7 @@ def decompose(decomposition_table): proxy_slot = object() no_default = object() -py_sym_types = (SymInt, SymFloat) +py_sym_types = (SymInt, SymFloat, SymBool) def is_sym_node(node): assert hasattr(node, 'meta'), "All nodes traced with proxy_tensor should have meta" return "val" in node.meta and isinstance(node.meta['val'], py_sym_types) @@ -104,7 +104,7 @@ def snapshot_fake(val): def unwrap_proxy(proxy_mode, e): if isinstance(e, torch.Tensor): return get_proxy_slot(e, proxy_mode.tracer, e, lambda e: e.proxy) - elif isinstance(e, (torch.SymInt, torch.SymFloat)): + elif isinstance(e, (torch.SymInt, torch.SymFloat, torch.SymBool)): return get_proxy_slot(e.node, proxy_mode.tracer, e, lambda e: e()) else: return e @@ -263,7 +263,7 @@ def can_handle_tensor(x): pytree.tree_all_only(_ProxyTensor, lambda t: t.constant is not None, (f_args, f_kwargs)) # TODO: maybe constant SymInts should also be allowed? Not sure if # this can happen - and pytree.tree_all_only((SymInt, SymFloat), lambda _: False, (args, kwargs)) + and pytree.tree_all_only((SymInt, SymFloat, SymBool), lambda _: False, (args, kwargs)) ) if torch.Tag.data_dependent_output in func.tags: # type: ignore[attr-defined] @@ -282,7 +282,7 @@ def can_handle_tensor(x): "It's likely that this is caused by data-dependent control flow or similar." ) proxy_args, proxy_kwargs = pytree.tree_map_only( - (SymInt, SymFloat), + (SymInt, SymFloat, SymBool), fetch_sym_proxy(proxy_mode.tracer), pytree.tree_map_only(_ProxyTensor, lambda e: e.proxy, (f_args, f_kwargs)) ) @@ -425,7 +425,7 @@ def create_arg(self, a: Any): setattr(self.root, qualname, a) return self.create_node('get_attr', qualname, (), {}) - elif isinstance(a, (SymInt, SymFloat)): + elif isinstance(a, (SymInt, SymFloat, SymBool)): assert a.node.constant is not None return a.node.constant return super().create_arg(a) @@ -459,7 +459,7 @@ def wrapped(*proxies): out ) out = pytree.tree_map_only( - (SymInt, SymFloat), + (SymInt, SymFloat, SymBool), lambda t: get_proxy_slot(t.node, tracer)(), out ) @@ -526,7 +526,7 @@ def enable(self, b): finally: self.enable_tracing = old - def _compute_proxy(self, func, args, out: Union[SymInt, SymFloat]): + def _compute_proxy(self, func, args, out: Union[SymInt, SymFloat, SymBool]): n_args = tuple( get_proxy_slot(a.node, self.tracer)().node if isinstance(a, py_sym_types) else a for a in args diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 2f179db4aa675..49306ef053de6 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -1,5 +1,5 @@ import torch -from typing import Set, Dict, List, Type, Optional, cast +from typing import Set, Dict, List, Type, Optional, cast, Union import sys import itertools import operator @@ -14,9 +14,11 @@ import logging # NB: The sym_* functions are used via getattr() and must be imported here. -from torch import SymInt, SymFloat, sym_float, sym_int, sym_max, sym_min # noqa: F401 +from torch import SymInt, SymFloat, SymBool, sym_not, sym_float, sym_int, sym_max, sym_min # noqa: F401 from torch._guards import ShapeGuard, Source +SymTypes = (SymInt, SymFloat, SymBool) + log = logging.getLogger(__name__) try: @@ -100,7 +102,7 @@ def _handle_sym_dispatch(func, args, kwargs): def guard_int(a): if isinstance(a, SymInt): return a.node.guard_int("", 0) # NB: uses Python backtrace - assert isinstance(a, int) + assert type(a) is int return a # Drop in replacement for math.sqrt @@ -110,11 +112,13 @@ def sym_sqrt(a): return math.sqrt(a) def to_node(self, num): - if isinstance(num, (SymInt, SymFloat)): + if isinstance(num, SymTypes): return num.node - elif isinstance(num, int): + elif type(num) is bool: + return self.wrap_bool(num) + elif type(num) is int: return self.wrap_int(num) - elif isinstance(num, float): + elif type(num) is float: return self.wrap_float(num) else: # NotImplemented is important so that Python tries the @@ -162,14 +166,21 @@ def is_int(self): def is_float(self): return self.pytype is float + def is_bool(self): + return self.pytype is bool + def wrap_int(self, num): - assert isinstance(num, int) + assert type(num) is int return SymNode(sympy.Integer(num), self.shape_env, int, constant=num) def wrap_float(self, num): - assert isinstance(num, float) + assert type(num) is float return SymNode(sympy.Float(num), self.shape_env, float, constant=num) + def wrap_bool(self, num): + assert type(num) is bool + return SymNode(sympy.true if num else sympy.false, self.shape_env, bool, constant=num) + def clone(self): return self @@ -189,6 +200,19 @@ def sym_int(self) -> "SymNode": # noqa: F811 def sym_float(self) -> "SymNode": # noqa: F811 raise AssertionError("should have been overridden") + def or_(self, other) -> "SymNode": # noqa: F811 + raise AssertionError("should have been overridden") + + def and_(self, other) -> "SymNode": # noqa: F811 + raise AssertionError("should have been overridden") + + # Make C++ happy + def sym_or(self, other): + return self.or_(other) + + def sym_and(self, other): + return self.and_(other) + # Today we error on calling int on a symbolic shape, as this is a very accessible footgun. def int_(self): raise RuntimeError("Trying to extract a concrete int out of a symbolic int") @@ -204,7 +228,14 @@ def guard_float(self, file, line): # guard occurred return float(self.shape_env.evaluate_expr(self.expr)) + def guard_bool(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + # TODO: why is the replace needed here? + return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr))) + def bool_(self): + # TODO: why is the replace needed here? return bool(self.shape_env.evaluate_expr(self.shape_env.replace(self.expr))) @@ -242,6 +273,28 @@ def eval(cls, base, divisor): sympy.simplify(base / gcd), sympy.simplify(divisor / gcd) ) + class IsNonOverlappingAndDenseIndicator(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, *args): + assert len(args) % 2 == 0 + if all(isinstance(a, sympy.Integer) for a in args): + dim = len(args) // 2 + sizes = args[0:dim] + strides = args[dim:] + return int(eval_is_non_overlapping_and_dense( + [int(s) for s in sizes], + [int(s) for s in strides] + )) + return None + +def safe_expand(r): + if hasattr(r, 'expand'): + return sympy.expand(r) + else: + return r + # Methods that have a `__foo__` as well as `__rfoo__` reflectable_magic_methods = { 'add': lambda a, b: a + b, @@ -249,13 +302,17 @@ def eval(cls, base, divisor): 'mul': lambda a, b: a * b, 'mod': lambda a, b: a % b, 'pow': lambda a, b: a ** b, + 'and': lambda a, b: a & b, + 'or': lambda a, b: a | b, 'truediv': lambda a, b: a / b, 'floordiv': lambda a, b: FloorDiv(a, b), } magic_methods = { **reflectable_magic_methods, + 'sym_not': lambda a: ~a, 'eq': lambda a, b: sympy.Eq(a, b), + 'ne': lambda a, b: sympy.Ne(a, b), 'gt': lambda a, b: sympy.Gt(a, b), 'lt': lambda a, b: sympy.Lt(a, b), 'le': lambda a, b: sympy.Le(a, b), @@ -269,20 +326,72 @@ def eval(cls, base, divisor): 'sym_sqrt': lambda a: sympy.sqrt(a), } +sizes_strides_methods = { + 'is_non_overlapping_and_dense': lambda *args: IsNonOverlappingAndDenseIndicator(*args), +} + +# TODO: Deduplicate this with torch/_prims_common/__init__.py +def eval_is_non_overlapping_and_dense(sizes, strides): + dim = len(sizes) + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + # or it is a 0/1 element tensor + if dim == 1: + return strides[0] == 1 or sizes[0] < 2 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + lengths_and_strides = sorted( + tuple(zip(sizes, strides)), key=operator.itemgetter(1) + ) + + # Unlike the C++ code, we don't move the 0/1 size dimensions to the + # end. So we have to keep going for this code. + expected_stride = 1 + for length, stride in lengths_and_strides: + + if length == 1: + continue + + if stride != expected_stride: + return False + + expected_stride *= length + + return True + +def is_non_overlapping_and_dense(sizes, strides): + base = None + for s in itertools.chain(sizes, strides): + if isinstance(s, SymInt): + base = s + break + + assert base is not None + return wrap_node(base.node.is_non_overlapping_and_dense( + [to_node(base.node, s) for s in sizes], + [to_node(base.node, s) for s in strides], + )) + unary_magic_methods = { 'sym_float', 'ceil', 'floor', 'neg', 'sym_sqrt', + 'sym_not', } +bool_magic_methods = {"and", "or", "sym_not"} + magic_methods_on_math = {"ceil", "floor"} -magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max"} +magic_methods_on_submodule = {"sym_float", "sym_sqrt", "sym_min", "sym_max", "sym_not"} +magic_methods_on_operator_with_trailing_underscore = {"and", "or"} always_float_magic_methods = {"truediv", "sym_float", "sym_sqrt"} always_int_magic_methods = {"ceil", "floor"} -always_bool_magic_methods = {"eq", "gt", "lt", "le", "ge"} +always_bool_magic_methods = {"eq", "ne", "gt", "lt", "le", "ge", "and", "or", "sym_not", "is_non_overlapping_and_dense"} def wrap_node(x): # TODO: let C++ also take advantage of this @@ -292,21 +401,28 @@ def wrap_node(x): return SymInt(x) elif x.is_float(): return SymFloat(x) + elif x.is_bool(): + return SymBool(x) else: raise AssertionError(f"unrecognized return type {x}") def _make_node_magic(method, func): func = lru_cache(256)(func) + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"{method}_" + else: + method_attr = method + def binary_magic_impl(self, other): if method in magic_methods_on_submodule: - op = getattr(sys.modules[__name__], method) + op = getattr(sys.modules[__name__], method_attr) else: assert method not in magic_methods_on_math - op = getattr(operator, method) + op = getattr(operator, method_attr) if SYM_FUNCTION_MODE: r = _handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) - assert isinstance(r, (SymInt, SymFloat)), type(r) + assert isinstance(r, SymTypes), type(r) return r.node assert isinstance(other, SymNode) other_expr = other.expr @@ -318,27 +434,27 @@ def binary_magic_impl(self, other): except Exception: log.warning(f"failed to eval {method}({expr}, {other_expr})") raise - out = sympy.expand(out) + out = safe_expand(out) pytype: Type if method in always_float_magic_methods: pytype = float + elif method in always_bool_magic_methods: + pytype = bool else: pytype = self.pytype - # TODO: relational operators actually technically return a - # PySymBool, this is a type error return SymNode(out, self.shape_env, pytype) def unary_magic_impl(self): if SYM_FUNCTION_MODE: if method in magic_methods_on_math: - op = getattr(math, method) + op = getattr(math, method_attr) elif method in magic_methods_on_submodule: - op = getattr(sys.modules[__name__], method) + op = getattr(sys.modules[__name__], method_attr) else: - op = getattr(operator, method) + op = getattr(operator, method_attr) r = _handle_sym_dispatch(op, (wrap_node(self),), {}) - assert isinstance(r, (SymInt, SymFloat)), type(r) + assert isinstance(r, SymTypes), type(r) return r.node # TODO: consider constant prop here expr = self.shape_env.replace(self.expr) @@ -347,7 +463,7 @@ def unary_magic_impl(self): except Exception: log.warning(f"failed to eval {method}({expr})") raise - out = sympy.expand(out) + out = safe_expand(out) pytype: Type if method in always_int_magic_methods: pytype = int @@ -359,31 +475,60 @@ def unary_magic_impl(self): return SymNode(out, self.shape_env, pytype) if method in unary_magic_methods: - setattr(SymNode, method, unary_magic_impl) + setattr(SymNode, method_attr, unary_magic_impl) else: - setattr(SymNode, method, binary_magic_impl) + setattr(SymNode, method_attr, binary_magic_impl) + +def _make_node_sizes_strides(method, func): + # NB: don't LRU cache, lots of arguments + + def sizes_strides_impl(self, sizes, strides): + op = getattr(sys.modules[__name__], method) + if SYM_FUNCTION_MODE: + r = _handle_sym_dispatch(op, ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), {}) + assert isinstance(r, SymBool), type(r) + return r.node + size_exprs = [s.expr for s in sizes] + stride_exprs = [s.expr for s in strides] + try: + out = func(*size_exprs, *stride_exprs) + except Exception: + log.warning(f"failed to eval {method}(*{size_exprs}, *{stride_exprs})") + raise + # bool is never expandable + return SymNode(sympy.Eq(out, 1), self.shape_env, bool) + + setattr(SymNode, method, sizes_strides_impl) for method, func in magic_methods.items(): _make_node_magic(method, func) +for method, func in sizes_strides_methods.items(): + _make_node_sizes_strides(method, func) + def _make_user_magic(method, user_type): # User magic takes care of wrapping the other operand into a node, # so that our internal logic can assume everything is nodes + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"{method}_" + else: + method_attr = method + def unary_magic_impl(self): - return wrap_node(getattr(self.node, method)()) + return wrap_node(getattr(self.node, method_attr)()) def binary_magic_impl(self, other): other_node = to_node(self.node, other) if other_node is NotImplemented: return NotImplemented - return wrap_node(getattr(self.node, method)(other_node)) + return wrap_node(getattr(self.node, method_attr)(other_node)) def rbinary_magic_impl(self, other): other_node = to_node(self.node, other) if other_node is NotImplemented: return NotImplemented - return wrap_node(getattr(other_node, method)(self.node)) + return wrap_node(getattr(other_node, method_attr)(self.node)) if method in unary_magic_methods: setattr(user_type, f"__{method}__", unary_magic_impl) @@ -393,8 +538,11 @@ def rbinary_magic_impl(self, other): setattr(user_type, f"__r{method}__", rbinary_magic_impl) for method, func in magic_methods.items(): - _make_user_magic(method, SymInt) - _make_user_magic(method, SymFloat) + if method in bool_magic_methods: + _make_user_magic(method, SymBool) + else: + _make_user_magic(method, SymInt) + _make_user_magic(method, SymFloat) del method del func @@ -854,7 +1002,7 @@ def _maybe_evaluate_static(self, expr: "sympy.Expr") -> "Optional[sympy.Expr]": floor_div_replace = {} for atom in new_expr.atoms(FloorDiv): floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) - new_expr = sympy.expand(new_expr.xreplace(floor_div_replace)) + new_expr = safe_expand(new_expr.xreplace(floor_div_replace)) if len(list(new_expr.free_symbols)) == 0: return new_expr return None @@ -862,7 +1010,7 @@ def _maybe_evaluate_static(self, expr: "sympy.Expr") -> "Optional[sympy.Expr]": @_lru_cache def replace(self, expr: "sympy.Expr") -> "sympy.Expr": replacements = {s: self._find(cast(sympy.Symbol, s)) for s in expr.free_symbols} - return sympy.expand(expr.xreplace(replacements)) + return safe_expand(expr.xreplace(replacements)) @_lru_cache def _update_divisible(self): @@ -885,7 +1033,7 @@ def simplify(self, expr: "sympy.Expr") -> "sympy.Expr": if self.replace(base % divisor) in self.divisible: div_replacements[atom] = base / divisor expr = expr.xreplace(div_replacements) - expr = sympy.expand(expr) + expr = safe_expand(expr) return expr @lru_cache(256) @@ -895,7 +1043,7 @@ def size_hint(self, expr: "sympy.Expr"): Does not introduce a guard, so only use this when you can guarantee that your code is still valid for arbitrary shapes (such as optimization decisions) """ - result_expr = sympy.expand(expr).xreplace(self.var_to_val) + result_expr = safe_expand(expr).xreplace(self.var_to_val) if len(result_expr.free_symbols) != 0: raise self._make_data_dependent_error(result_expr) return result_expr @@ -934,14 +1082,22 @@ def _find(self, a: "sympy.Symbol") -> "sympy.Expr": return self.replacements[a] @lru_cache(256) - def _maybe_guard_eq(self, expr: "sympy.Eq") -> None: + def _maybe_guard_eq(self, expr: Union["sympy.Eq", "sympy.Ne"]) -> None: """ Evaluates the result of an eq call. If true, uses information to simplify shapes (i.e. a == b or a % 5 == 0) """ concrete_bool = bool(self.size_hint(expr)) - if not concrete_bool: - return + if isinstance(expr, sympy.Eq): + if not concrete_bool: + return + # NB: Apparently this is load bearing; to see what test fails if + # you comment it out run: + # python test/functorch/test_aotdispatch.py -k + # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32 + elif isinstance(expr, sympy.Ne): + if concrete_bool: + return free = list(expr.free_symbols) assert len(free) > 0, "The expression should not be static by this point" @@ -984,7 +1140,7 @@ def evaluate_expr(self, expr: "sympy.Expr"): if static_expr is not None: return static_expr - if isinstance(expr, sympy.Eq): + if isinstance(expr, (sympy.Eq, sympy.Ne)): self._maybe_guard_eq(expr) # TODO: If we successfully eliminate a symbol via equality, it # is not actually necessary to save a guard for the equality, diff --git a/torch/overrides.py b/torch/overrides.py index 6adafdaf95985..3c064a1b9747a 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -183,6 +183,7 @@ def get_ignored_functions() -> Set[Callable]: torch.sym_int, torch.sym_max, torch.sym_min, + torch.sym_not, torch.tril_indices, torch.triu_indices, torch.vander, diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index d89f1bcdfc3de..66dfb9a8a7e45 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -6,6 +6,7 @@ T = TypeVar('T') S = TypeVar('S') +U = TypeVar('U') R = TypeVar('R') """ @@ -195,8 +196,10 @@ def tree_map(fn: Any, pytree: PyTree) -> PyTree: return tree_unflatten([fn(i) for i in flat_args], spec) Type2 = Tuple[Type[T], Type[S]] +Type3 = Tuple[Type[T], Type[S], Type[U]] TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] +Fn3 = Callable[[Union[T, S, U]], R] Fn2 = Callable[[Union[T, S]], R] Fn = Callable[[T], R] FnAny = Callable[[Any], R] @@ -255,6 +258,10 @@ def tree_map_only(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree: def tree_map_only(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree: ... +@overload +def tree_map_only(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree: + ... + def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree: return tree_map(map_only(ty)(fn), pytree) @@ -274,6 +281,10 @@ def tree_all_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool: def tree_all_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool: ... +@overload +def tree_all_only(ty: Type3[T, S, U], pred: Fn3[T, S, U, bool], pytree: PyTree) -> bool: + ... + def tree_all_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool: flat_args, _ = tree_flatten(pytree) return all(pred(x) for x in flat_args if isinstance(x, ty))