@@ -545,13 +545,16 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
545545
546546 // Indices of tensors to be skipped because they're empty
547547 std::vector<int64_t > skipped_tensor_indices;
548+ // Tensors to be read
549+ std::vector<const Tensor*> input_tensors;
548550 int tensor_idx = 0 ;
549551 for (const Tensor& t : materialized_inputs) {
550552 if (t.numel () == 0 || should_skip (t)) {
551553 skipped_tensor_indices.push_back (tensor_idx);
552554 tensor_idx++;
553555 continue ;
554556 }
557+ input_tensors.push_back (&t);
555558 nDims = t.dim ();
556559 // TODO: Is this OK?
557560 notSkippedTensor = &t;
@@ -659,26 +662,22 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
659662 MPSGraphTensor* castInputMPSGraphTensors[len_tensor_array];
660663
661664 int graph_tensor_idx = 0 ;
662- int i = 0 ;
663- for (const Tensor& tensor : materialized_inputs) {
664- if (std::find (skipped_tensor_indices.begin (), skipped_tensor_indices.end (), i) == skipped_tensor_indices.end ()) {
665- inputMPSGraphTensors[graph_tensor_idx] = mpsGraphUnrankedPlaceHolder (mpsGraph, getMPSDataType (tensor.scalar_type ()) );
666- if (getMPSDataType (result_type (inputs)) == MPSDataTypeBool) {
665+ for (const Tensor* tensor : input_tensors) {
666+ inputMPSGraphTensors[graph_tensor_idx] = mpsGraphUnrankedPlaceHolder (mpsGraph, getMPSDataType (tensor->scalar_type ()) );
667+ if (getMPSDataType (result_type (inputs)) == MPSDataTypeBool) {
668+ castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor: inputMPSGraphTensors[graph_tensor_idx]
669+ toType: MPSDataTypeFloat32
670+ name: [NSString stringWithFormat: @" castInput%@ " , [NSNumber numberWithInt: graph_tensor_idx]]];
671+ }
672+ else {
673+ if (tensor->scalar_type () != result_type (inputs))
667674 castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor: inputMPSGraphTensors[graph_tensor_idx]
668- toType: MPSDataTypeFloat32
669- name: [NSString stringWithFormat: @" castInput%@ " , [NSNumber numberWithInt: graph_tensor_idx]]];
670- }
671- else {
672- if (tensor.scalar_type () != result_type (inputs))
673- castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor: inputMPSGraphTensors[graph_tensor_idx]
674- toType: getMPSDataType (result_type (inputs))
675- name: [NSString stringWithFormat: @" castInput%@ " , [NSNumber numberWithInt: graph_tensor_idx]]];
676- else
677- castInputMPSGraphTensors[graph_tensor_idx] = inputMPSGraphTensors[graph_tensor_idx];
678- }
679- graph_tensor_idx++;
675+ toType: getMPSDataType (result_type (inputs))
676+ name: [NSString stringWithFormat: @" castInput%@ " , [NSNumber numberWithInt: graph_tensor_idx]]];
677+ else
678+ castInputMPSGraphTensors[graph_tensor_idx] = inputMPSGraphTensors[graph_tensor_idx];
680679 }
681- i ++;
680+ graph_tensor_idx ++;
682681 }
683682
684683 auto inputTensorsArray = [NSArray arrayWithObjects: castInputMPSGraphTensors
0 commit comments