Skip to content

Commit 2d885ab

Browse files
swolchokfacebook-github-bot
authored andcommitted
[jit] Reduce refcounting of Types (pytorch#65345)
Summary: Pull Request resolved: pytorch#65345 FooType::get() can return a const reference. Inconveniently, converting shared_ptr<FooType> to shared_ptr<Type> requires a copy & refcount bump, so to properly take advantage of this in unshapedType() we need to take a const Type& in isSubtypeOf(), which is good practice anyway -- don't require a shared_ptr if you don't need to take ownership. ghstack-source-id: 140044165 Test Plan: CI perf says c10::unshapedType time decreased from 2.8% to 2.2% during static runtime startup, though I expect this to be generally beneficial. Reviewed By: hlu1 Differential Revision: D31027361 fbshipit-source-id: 676feb81db9f74ad7b8651d8774f4ecb4cfa6ab8
1 parent 1ae468a commit 2d885ab

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

69 files changed

+421
-405
lines changed

aten/src/ATen/BatchedFallback.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ static bool areAnyArgumentsTensorList(const FunctionSchema& schema) {
3636
return std::any_of(
3737
schema.arguments().begin(),
3838
schema.arguments().end(),
39-
[] (const Argument& arg) { return arg.type()->isSubtypeOf(ListType::ofTensors()); });
39+
[] (const Argument& arg) { return arg.type()->isSubtypeOf(*ListType::ofTensors()); });
4040
}
4141

4242
// Returns if an operator is in-place. An operator is inplace if:

aten/src/ATen/core/List_inl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ List<T> toTypedList(impl::GenericList list) {
6363
// as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
6464
// have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
6565
TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
66-
|| (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr<T>()))
66+
|| (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(*getTypePtr<T>()))
6767
, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
6868
return List<T>(std::move(list.impl_));
6969
}

aten/src/ATen/core/dispatch/DispatchKeyExtractor.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,13 +172,13 @@ struct TORCH_API DispatchKeyExtractor final {
172172
" arguments but this PyTorch build only supports ", c10::utils::bitset::NUM_BITS());
173173
c10::utils::bitset dispatch_arg_indices_reverse;
174174
for (size_t index = 0; index < schema.arguments().size(); ++index) {
175-
if (schema.arguments()[index].type()->isSubtypeOf(TensorType::get()) ||
175+
if (schema.arguments()[index].type()->isSubtypeOf(*TensorType::get()) ||
176176
schema.arguments()[index].type()->isSubtypeOf(
177-
ListType::ofTensors()) ||
177+
*ListType::ofTensors()) ||
178178
schema.arguments()[index].type()->isSubtypeOf(
179-
ListType::ofOptionalTensors()) ||
179+
*ListType::ofOptionalTensors()) ||
180180
schema.arguments()[index].type()->isSubtypeOf(
181-
OptionalType::ofTensor())) {
181+
*OptionalType::ofTensor())) {
182182
dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
183183
}
184184
}

aten/src/ATen/core/function_schema_inl.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ inline bool Argument::isBackwardCompatibleWith(
7676
if (lhs->kwarg_only() && !rhs->kwarg_only()) {
7777
return false;
7878
}
79-
if (!rhs->type()->isSubtypeOfExt(lhs->type(), why_not)) {
79+
if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) {
8080
return false;
8181
}
8282
if (rhs->default_value().has_value() &&
@@ -179,7 +179,7 @@ inline void FunctionSchema::checkArg(
179179
// Fast-path for the common case
180180
return;
181181
}
182-
if (!value.type()->isSubtypeOf(argument.type())) {
182+
if (!value.type()->isSubtypeOf(*argument.type())) {
183183
TORCH_CHECK(
184184
false,
185185
formatTypeMismatchMsg(
@@ -304,7 +304,7 @@ inline bool isSubtypeOfList(
304304
if (c.name() != p.name()) {
305305
return false;
306306
}
307-
if (!c.type()->isSubtypeOfExt(p.type(), why_not)) {
307+
if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) {
308308
return false;
309309
}
310310
}

0 commit comments

Comments
 (0)