@@ -22,25 +22,38 @@ namespace {
2222struct CollectiveFingerPrint {
2323 // Current collective's operation type.
2424 OpType op_type_;
25- // Ref to input tensors, if given, of the collective. If given, shapes will be
26- // checked across processes to ensure valid input into the collective.
27- const std::vector<at::Tensor>& input_tensors_;
25+ // Number of input tensors
26+ std::size_t num_tensors_;
2827 // input tensor data types
2928 std::vector<int8_t > tensor_dtypes_;
3029 // input tensor device types
3130 std::vector<int8_t > tensor_device_types_;
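31+ // input tensor sizes (non-owning c10::IntArrayRef views taken from each tensor's sizes(); the input tensors must outlive this fingerprint)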
32+ std::vector<c10::IntArrayRef> tensor_sizes_;
33+
// Builds the fingerprint of the local collective call from its op type and
// input tensors. For every input tensor, in order, it records the scalar type
// and the device type (each narrowed to int8_t) and the value returned by
// t.sizes(). The removed ("-") lines below kept a reference to the caller's
// tensor vector instead; the added ("+") lines store the tensor count and the
// per-tensor sizes in the fingerprint itself.
3234 explicit CollectiveFingerPrint (
3335 OpType op_type,
3436 const std::vector<at::Tensor>& input_tensors)
35- : op_type_(op_type), input_tensors_(input_tensors) {
36- tensor_dtypes_.reserve (input_tensors.size ());
37- tensor_device_types_.reserve (input_tensors.size ());
38- for (const at::Tensor& t : input_tensors_) {
37+ : op_type_(op_type), num_tensors_(input_tensors.size()) {
38+ tensor_dtypes_.reserve (num_tensors_);
39+ tensor_device_types_.reserve (num_tensors_);
40+ tensor_sizes_.reserve (num_tensors_);
41+ for (const at::Tensor& t : input_tensors) {
3942 tensor_dtypes_.push_back (static_cast <int8_t >(t.dtype ().toScalarType ()));
4043 tensor_device_types_.push_back (static_cast <int8_t >(t.device ().type ()));
44+ tensor_sizes_.push_back (t.sizes ());
4145 }
4246 }
4347
48+ // Constructor for the data received from deserialized fingerprint
49+ CollectiveFingerPrint (
50+ OpType op_type,
51+ std::vector<int8_t > tensor_dtypes,
52+ std::vector<int8_t > tensor_device_types)
53+ : op_type_(op_type),
54+ tensor_dtypes_ (tensor_dtypes),
55+ tensor_device_types_(tensor_device_types) {}
56+
4457 // Logs collective information in case of a failure.
4558 friend std::ostream& operator <<(
4659 std::ostream& output,
@@ -62,14 +75,51 @@ struct CollectiveFingerPrint {
6275 verify_tensors (inp, pg);
6376 }
6477
78+ // Takes a serialized fingerprint from
79+ // CollectiveFingerPrint::serialize_fingerprint and deserializes it back to a
80+ // CollectiveFingerPrint struct
81+ CollectiveFingerPrint deserialize_fingerprint (at::Tensor serialized_tensor) {
82+ // TODO: Need to add asserts to validate serialized_tensor.sizes() before
83+ // deserializing
84+ int index = 0 ;
85+ // 1. OpType
86+ OpType optype = OpType (serialized_tensor[index].item <int >());
87+ index++;
88+
89+ std::vector<int8_t > dtypes = std::vector<int8_t >();
90+ std::vector<int8_t > device_types = std::vector<int8_t >();
91+ if (index < serialized_tensor.size (0 )) {
92+ // 2. Num tensors
93+ int num_tensors = serialized_tensor[index].item <int >();
94+ index++;
95+
96+ // 3. Tensor dtypes
97+ for (int i = 0 ; i < num_tensors; i++) {
98+ dtypes.push_back (serialized_tensor[index].item <int8_t >());
99+ index++;
100+ }
101+ // 4. Device types
102+ for (int i = 0 ; i < num_tensors; i++) {
103+ device_types.push_back (serialized_tensor[index].item <int8_t >());
104+ index++;
105+ }
106+ }
107+ return CollectiveFingerPrint (optype, dtypes, device_types);
108+ }
109+
65110 private:
66111 void verify_tensors (
67112 std::vector<at::Tensor>& tensors_to_verify,
68113 c10::intrusive_ptr<ProcessGroup>& pg) {
69114 // Create output tensor data structure to pass into allgather.
70115 std::vector<std::vector<at::Tensor>> output_tensors;
116+ // output tensors: [<tensor 0 outputs>, <tensor 1 outputs>, ..., <tensor n
117+ // outputs>]
71118 output_tensors.reserve (tensors_to_verify.size ());
72119 for (const auto & tensor_shape : tensors_to_verify) {
120+ // Each rank has its own outputs shape, e.g.
121+ // <tensor 0 outputs>: [<rank 0 tensor>, <rank 1 tensor>, ..., <rank n
122+ // tensor>]
73123 std::vector<at::Tensor> outputs;
74124 outputs.reserve (pg->getSize ());
75125 for (const auto i : c10::irange (pg->getSize ())) {
@@ -84,38 +134,45 @@ struct CollectiveFingerPrint {
84134 for (const auto i : c10::irange (output_tensors.size ())) {
85135 const std::vector<at::Tensor> gathered_tensors = output_tensors[i];
86136 const at::Tensor reference_tensor = tensors_to_verify[i];
87- for (const auto & rank_tensor : gathered_tensors) {
137+ for (int rank = 0 ; rank < gathered_tensors.size (); rank++) {
138+ const auto & rank_tensor = gathered_tensors[rank];
88139 if (!rank_tensor.equal (reference_tensor)) {
140+ CollectiveFingerPrint rank_fingerprint =
141+ deserialize_fingerprint (rank_tensor);
89142 std::stringstream ss;
90143 ss << " Detected mismatch between collectives on ranks. Rank "
91- << pg->getRank ()
92- << " is running inconsistent collective: " << *this ;
144+ << pg->getRank () << " is running collective: " << *this
145+ << " , but Rank " << rank << " is running collective: "
146+ << opTypeToString (rank_fingerprint.op_type_ ) << " ." ;
93147 TORCH_CHECK (false , ss.str ());
94148 }
95149 }
96150 }
97151 }
98152
153+ // Serializes the information (op type, input shapes, data types, device
154+ // types) about the collective fingerprint into a tensor
99155 at::Tensor serialize_fingerprint () {
100156 auto data = std::make_unique<std::vector<int64_t >>();
101157 // std::vector<int64_t> data;
102- // OpType
158+ // 1. OpType
103159 data->push_back (static_cast <int64_t >(op_type_));
104- // Shapes
105- for (const auto & tensor : input_tensors_) {
106- auto sizes = tensor.sizes ().vec ();
107- for (const auto & s : sizes) {
108- data->push_back (s);
109- }
110- }
111- // tensor dtypes
160+ // 2. Num tensors
161+ data->push_back (static_cast <int64_t >(num_tensors_));
162+ // 3. Tensor dtypes
112163 for (const auto & type : tensor_dtypes_) {
113164 data->push_back (type);
114165 }
115- // device types
166+ // 4. Device types
116167 for (const auto & d : tensor_device_types_) {
117168 data->push_back (d);
118169 }
170+ // 5. Shapes
171+ for (const auto & sizes : tensor_sizes_) {
172+ for (const auto & s : sizes) {
173+ data->push_back (s);
174+ }
175+ }
119176 // Serialize data into tensor
120177 int64_t data_size = data->size ();
121178 // Need to release here and get the ptr due to C++ parameter evaluation
@@ -138,7 +195,7 @@ std::ostream& operator<<(
138195 std::ostream& output,
139196 const CollectiveFingerPrint& collective_fingerprint) {
140197 std::string collectiveInfo;
141- if (! collective_fingerprint.input_tensors_ . empty () ) {
198+ if (collective_fingerprint.num_tensors_ != 0 ) {
142199 // Convert dtype and device type info to string.
143200 std::vector<std::string> dtype_strs;
144201 std::vector<std::string> device_type_strs;
@@ -157,7 +214,7 @@ std::ostream& operator<<(
157214 " OpType=" ,
158215 opTypeToString (collective_fingerprint.op_type_ ),
159216 " , TensorShape=" ,
160- (collective_fingerprint.input_tensors_ )[0 ]. sizes () ,
217+ (collective_fingerprint.tensor_sizes_ )[0 ],
161218 " , TensorDtypes=" ,
162219 (dtype_strs),
163220 " , TensorDeviceTypes=" ,
0 commit comments