forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FunctionalStorageImpl.cpp
146 lines (134 loc) · 6.62 KB
/
FunctionalStorageImpl.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#include <ATen/FunctionalStorageImpl.h>
#include <ATen/EmptyTensor.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/util/Exception.h>
#include <vector>
namespace at::functionalization {
// Returns a ViewMeta describing the same view op but selecting output `out_idx`.
// The forward/reverse functions and all flags are carried over unchanged; when
// `out_idx` already matches this->out_index, a plain copy of *this is returned.
ViewMeta ViewMeta::to_out_idx(int64_t out_idx) {
  if (out_idx != this->out_index) {
    return ViewMeta(forward_fn, reverse_fn, has_symbolic_inputs, is_multi_output, is_as_strided, out_idx);
  }
  return *this;
}
// Note [Functionalization: Alias Removal Part 2]
// See Note [Functionalization: Alias Removal] for more details.
// This function applies a single update from one of the views to the StorageImpl.
// We start out with <original_base> and <mutated_view>, and our goal is to end up with <mutated_base>.
// Consider this program:
//
// base = ...
// a = base.view1()
// b = a.view2()
// c = b.view3()
// c.add_(3)
//
// Then the functionalization pass will queue an update as follows:
//
// update.new_val = c # the updated value of c
// update.view_metas = [view1_meta, view2_meta, view3_meta]
//
// Syncing any of a, b or c will eventually call apply_update() on the storage, and the following will run:
//
// tmp_values = [base, a, b] # NB: c is not necessary
// t = update.new_val
// t = view3_inverse(b, t, 0) # 0 is output index, these are all single output views so it's 0
// t = view2_inverse(a, t, 0)
// t = view1_inverse(base, t, 0) # t now represents the updated storage.
// storage.base_ = t
//
// Parameters:
//   update - the queued mutation: the new value of the mutated view plus the
//            chain of ViewMetas that leads from the base to that view.
//   base   - the current (non-functional) base tensor of the storage.
// Returns the updated base. It is returned as a plain (non-const) Tensor so the
// caller can move-assign it into the storage instead of copying it.
static Tensor apply_update(const FunctionalStorageImpl::Update& update, const Tensor& base) {
  at::Tensor t = update.new_val;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  // No view chain: the mutated value already is the new base.
  if (update.view_metas.empty()) {
    return t;
  }
  const size_t num_views = update.view_metas.size();

  // tmp_values[i] is the input to view i: the base followed by every
  // intermediate view. The output of the last view is never needed, so exactly
  // num_views entries are stored; reserve up front so nothing reallocates.
  std::vector<at::Tensor> tmp_values;
  tmp_values.reserve(num_views);
  tmp_values.push_back(base);
  for (size_t i = 0; i + 1 < num_views; ++i) {
    const auto& meta = update.view_metas[i];
    // NB: We only actually need tmp_values for ops like select/slice/diagonal/squeeze/as_strided
    // All of these ops require additional information to recover the sizes of the original tensor.
    // If we need to, we could probably apply this optimization and only bother computing tmp_values
    // for those necessary view ops.
    at::Tensor next_view = meta.forward_fn(tmp_values.back(), meta.out_index);
    tmp_values.push_back(std::move(next_view));
  }

  // Walk the chain backwards, inverting one view at a time until t holds the base.
  for (int64_t i = static_cast<int64_t>(num_views) - 1; i >= 0; --i) {
    const auto& meta = update.view_metas[i];
    // Each view inverse is implemented in ViewInverses.cpp.
    t = meta.reverse_fn(tmp_values[i], t, meta.out_index);
  }
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(t));
  return t;
}
// Computes the byte size reported by a FunctionalStorageImpl wrapping `value`.
//
// The functionalization story when wrapping tensors that don't have storage
// is a bit wonky, but fortunately for some models (e.g., dlrm) we never
// actually perform mutations on these tensors, so you never really get
// called out on it. For now, functionalization still creates "storages"
// for these tensors (which is wrong), so sparse / sparse-compressed tensors
// are given zero bytes. A more proper fix would be to have a
// SparseFunctionalTensorWrapper that models sparse correctly.
static c10::SymInt get_nbytes(const Tensor& value) {
  const bool is_sparse_layout =
      value.is_sparse() || at::sparse_csr::is_sparse_compressed(value);
  if (is_sparse_layout) {
    return 0;
  }

  if (!value.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {
    // Concrete sizes: derive nbytes from the geometry rather than asking the
    // storage, because XLA storage objects do not properly track nbytes.
    return at::detail::computeStorageNbytes(value.sizes(), value.strides(), value.dtype().itemsize(), value.storage_offset());
  }

  // Symbolic sizes. Today SymInt is implemented both in Python (proxy tensor)
  // and in lazy tensor (LTC/XLA). Only the Python implementation can answer
  // sym_nbytes() on the storage; LTC hasn't implemented SymInt support yet.
  // Once it does, this Python-key special case should be removed.
  if (value.key_set().has(c10::DispatchKey::Python)) {
    return value.storage().sym_nbytes();
  }
  return at::detail::computeStorageNbytes(value.sym_sizes(), value.sym_strides(), value.dtype().itemsize(), value.sym_storage_offset());
}
// Builds the shared storage backing every functional view of `base`.
// The underlying StorageImpl holds no data of its own: it is created with a
// null DataPtr on base's device and the meta allocator, is resizable, and
// reports the byte size computed by get_nbytes(base). `base` must not itself
// be a functional tensor (asserted below).
FunctionalStorageImpl::FunctionalStorageImpl(const Tensor& base)
    : c10::StorageImpl(
          c10::StorageImpl::use_byte_size_t(),
          get_nbytes(base),
          DataPtr{nullptr, base.device()},
          GetAllocator(kMeta),
          /*resizable=*/true),
      base_(base) {
  // SparseTensorImpl has no storage, so its nbytes cannot be queried, and the
  // same goes for XLA. (original_storage_size_ is only used for storage
  // resizing in fsdp anyway, which does not apply to sparse.) In those cases
  // the size is recorded as -1.
  const bool can_query_storage_size =
      base.unsafeGetTensorImpl()->has_storage() &&
      base.device().type() != c10::DeviceType::XLA;
  if (!can_query_storage_size) {
    original_storage_size_ = -1;
  } else {
    original_storage_size_ = base.unsafeGetTensorImpl()->unsafe_storage().unsafeGetStorageImpl()->sym_nbytes();
  }
  // The current size starts out equal to the original size.
  curr_storage_size_ = original_storage_size_;
  TORCH_INTERNAL_ASSERT(!at::functionalization::impl::isFunctionalTensor(base_));
}
// Queues a pending mutation on this storage: `updated_val` is the new value of
// the view reached from the base by applying `metas` in order. Nothing is
// written to the base here; apply_updates() replays queued updates later.
// Each call bumps generation_.
//
// Raises (TORCH_CHECK) if the storage is frozen, or if any view after the
// first one in a multi-view chain is an as_strided() call (except on XLA).
void FunctionalStorageImpl::add_update(const Tensor& updated_val, const std::vector<ViewMeta>& metas) {
  TORCH_CHECK(!frozen_, "cannot mutate tensors with frozen storage");
  // Only chains with more than one view are checked: index 0 is never tested,
  // because the issue (see the error message) is composing as_strided() with
  // earlier views.
  if (metas.size() > 1) {
    // Skipping this check for XLA. Would be good to add it back, but it is failing XLA CI.
    // The device does not depend on the loop index, so query it once.
    const bool is_xla = updated_val.device().type() == c10::DeviceType::XLA;
    for (size_t i = 1; i < metas.size(); ++i) {
      TORCH_CHECK(is_xla || !metas[i].is_as_strided,
          "During torch.compile, encountered a mutation on a view chain of length ", metas.size(), ", where view ", i,
          " was an as_strided() call. as_strided() is non-compositional, and therefore is not possible to functionalize properly today, "
          "so this behavior is banned in compile. As a workaround, you can either remove the mutation from the model code, or you "
          "can insert a graph break right before the mutation with torch._dynamo.graph_break(). If you would like this behavior to "
          "work properly, please comment on https://github.com/pytorch/pytorch/issues/104505.");
    }
  }
  updates_.push_back({updated_val, metas});
  generation_++;
}
// Replays every queued update onto base_ (in queue order) and empties the queue.
// Returns true if at least one update was applied, false if the queue was empty.
bool FunctionalStorageImpl::apply_updates() {
  // N.B: none of the tensors used here should be FunctionalTensorWrappers.
  // The TLS exclude guard is needed only because of functorch's DynamicLayer
  // stack: it adds the Functionalize key into TLS before redispatching to the
  // functionalization kernels, so it must be explicitly excluded here before
  // doing any work underneath the pass.
  at::AutoDispatchSkipFunctionalize guard;
  if (updates_.empty()) {
    return false;
  }
  for (const auto& pending : updates_) {
    base_ = apply_update(pending, base_);
  }
  updates_.clear();
  return true;
}
} // namespace at::functionalization