Skip to content

Commit 26b5129

Browse files
H-Huang and pytorchmergebot
authored and committed
[BE] Update ProcessGroupWrapper to add deserializer and improve logs
Pull Request resolved: pytorch#79724 Approved by: https://github.com/kumpera, https://github.com/rohan-varma
1 parent ccccd0e commit 26b5129

File tree

1 file changed

+79
-22
lines changed

1 file changed

+79
-22
lines changed

torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp

Lines changed: 79 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,25 +22,38 @@ namespace {
2222
struct CollectiveFingerPrint {
2323
// Current collective's operation type.
2424
OpType op_type_;
25-
// Ref to input tensors, if given, of the collective. If given, shapes will be
26-
// checked across processes to ensure valid input into the collective.
27-
const std::vector<at::Tensor>& input_tensors_;
25+
// Number of input tensors
26+
std::size_t num_tensors_;
2827
// input tensor data types
2928
std::vector<int8_t> tensor_dtypes_;
3029
// input tensor device types
3130
std::vector<int8_t> tensor_device_types_;
31+
// input tensor sizes
32+
std::vector<c10::IntArrayRef> tensor_sizes_;
33+
3234
explicit CollectiveFingerPrint(
3335
OpType op_type,
3436
const std::vector<at::Tensor>& input_tensors)
35-
: op_type_(op_type), input_tensors_(input_tensors) {
36-
tensor_dtypes_.reserve(input_tensors.size());
37-
tensor_device_types_.reserve(input_tensors.size());
38-
for (const at::Tensor& t : input_tensors_) {
37+
: op_type_(op_type), num_tensors_(input_tensors.size()) {
38+
tensor_dtypes_.reserve(num_tensors_);
39+
tensor_device_types_.reserve(num_tensors_);
40+
tensor_sizes_.reserve(num_tensors_);
41+
for (const at::Tensor& t : input_tensors) {
3942
tensor_dtypes_.push_back(static_cast<int8_t>(t.dtype().toScalarType()));
4043
tensor_device_types_.push_back(static_cast<int8_t>(t.device().type()));
44+
tensor_sizes_.push_back(t.sizes());
4145
}
4246
}
4347

48+
// Constructor for the data received from deserialized fingerprint
49+
CollectiveFingerPrint(
50+
OpType op_type,
51+
std::vector<int8_t> tensor_dtypes,
52+
std::vector<int8_t> tensor_device_types)
53+
: op_type_(op_type),
54+
tensor_dtypes_(tensor_dtypes),
55+
tensor_device_types_(tensor_device_types) {}
56+
4457
// Logs collective information in case of a failure.
4558
friend std::ostream& operator<<(
4659
std::ostream& output,
@@ -62,14 +75,51 @@ struct CollectiveFingerPrint {
6275
verify_tensors(inp, pg);
6376
}
6477

78+
// Takes a serialized fingerprint from
79+
// CollectiveFingerPrint::serialize_fingerprint and deserializes it back to a
80+
// CollectiveFingerPrint struct
81+
CollectiveFingerPrint deserialize_fingerprint(at::Tensor serialized_tensor) {
82+
// TODO: Need to add asserts to validate serialized_tensor.sizes() before
83+
// deserializing
84+
int index = 0;
85+
// 1. OpType
86+
OpType optype = OpType(serialized_tensor[index].item<int>());
87+
index++;
88+
89+
std::vector<int8_t> dtypes = std::vector<int8_t>();
90+
std::vector<int8_t> device_types = std::vector<int8_t>();
91+
if (index < serialized_tensor.size(0)) {
92+
// 2. Num tensors
93+
int num_tensors = serialized_tensor[index].item<int>();
94+
index++;
95+
96+
// 3. Tensor dtypes
97+
for (int i = 0; i < num_tensors; i++) {
98+
dtypes.push_back(serialized_tensor[index].item<int8_t>());
99+
index++;
100+
}
101+
// 4. Device types
102+
for (int i = 0; i < num_tensors; i++) {
103+
device_types.push_back(serialized_tensor[index].item<int8_t>());
104+
index++;
105+
}
106+
}
107+
return CollectiveFingerPrint(optype, dtypes, device_types);
108+
}
109+
65110
private:
66111
void verify_tensors(
67112
std::vector<at::Tensor>& tensors_to_verify,
68113
c10::intrusive_ptr<ProcessGroup>& pg) {
69114
// Create output tensor data structure to pass into allgather.
70115
std::vector<std::vector<at::Tensor>> output_tensors;
116+
// output tensors: [<tensor 0 outputs>, <tensor 1 outputs>, ..., <tensor n
117+
// outputs>]
71118
output_tensors.reserve(tensors_to_verify.size());
72119
for (const auto& tensor_shape : tensors_to_verify) {
120+
// Each rank has its own outputs shape, e.g.
121+
// <tensor 0 outputs>: [<rank 0 tensor>, <rank 1 tensor>, ..., <rank n
122+
// tensor>]
73123
std::vector<at::Tensor> outputs;
74124
outputs.reserve(pg->getSize());
75125
for (const auto i : c10::irange(pg->getSize())) {
@@ -84,38 +134,45 @@ struct CollectiveFingerPrint {
84134
for (const auto i : c10::irange(output_tensors.size())) {
85135
const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
86136
const at::Tensor reference_tensor = tensors_to_verify[i];
87-
for (const auto& rank_tensor : gathered_tensors) {
137+
for (int rank = 0; rank < gathered_tensors.size(); rank++) {
138+
const auto& rank_tensor = gathered_tensors[rank];
88139
if (!rank_tensor.equal(reference_tensor)) {
140+
CollectiveFingerPrint rank_fingerprint =
141+
deserialize_fingerprint(rank_tensor);
89142
std::stringstream ss;
90143
ss << "Detected mismatch between collectives on ranks. Rank "
91-
<< pg->getRank()
92-
<< " is running inconsistent collective: " << *this;
144+
<< pg->getRank() << " is running collective: " << *this
145+
<< ", but Rank " << rank << " is running collective: "
146+
<< opTypeToString(rank_fingerprint.op_type_) << ".";
93147
TORCH_CHECK(false, ss.str());
94148
}
95149
}
96150
}
97151
}
98152

153+
// Serializes the information (op type, input shapes, data types, device
154+
// types) about the collective fingerprint into a tensor
99155
at::Tensor serialize_fingerprint() {
100156
auto data = std::make_unique<std::vector<int64_t>>();
101157
// std::vector<int64_t> data;
102-
// OpType
158+
// 1. OpType
103159
data->push_back(static_cast<int64_t>(op_type_));
104-
// Shapes
105-
for (const auto& tensor : input_tensors_) {
106-
auto sizes = tensor.sizes().vec();
107-
for (const auto& s : sizes) {
108-
data->push_back(s);
109-
}
110-
}
111-
// tensor dtypes
160+
// 2. Num tensors
161+
data->push_back(static_cast<int64_t>(num_tensors_));
162+
// 3. Tensor dtypes
112163
for (const auto& type : tensor_dtypes_) {
113164
data->push_back(type);
114165
}
115-
// device types
166+
// 4. Device types
116167
for (const auto& d : tensor_device_types_) {
117168
data->push_back(d);
118169
}
170+
// 5. Shapes
171+
for (const auto& sizes : tensor_sizes_) {
172+
for (const auto& s : sizes) {
173+
data->push_back(s);
174+
}
175+
}
119176
// Serialize data into tensor
120177
int64_t data_size = data->size();
121178
// Need to release here and get the ptr due to C++ parameter evaluation
@@ -138,7 +195,7 @@ std::ostream& operator<<(
138195
std::ostream& output,
139196
const CollectiveFingerPrint& collective_fingerprint) {
140197
std::string collectiveInfo;
141-
if (!collective_fingerprint.input_tensors_.empty()) {
198+
if (collective_fingerprint.num_tensors_ != 0) {
142199
// Convert dtype and device type info to string.
143200
std::vector<std::string> dtype_strs;
144201
std::vector<std::string> device_type_strs;
@@ -157,7 +214,7 @@ std::ostream& operator<<(
157214
"OpType=",
158215
opTypeToString(collective_fingerprint.op_type_),
159216
", TensorShape=",
160-
(collective_fingerprint.input_tensors_)[0].sizes(),
217+
(collective_fingerprint.tensor_sizes_)[0],
161218
", TensorDtypes=",
162219
(dtype_strs),
163220
", TensorDeviceTypes=",

0 commit comments

Comments
 (0)