Skip to content

Commit b6e0f95

Browse files
committed
Simplify non-skipped tensor logic
1 parent 7b977f2 commit b6e0f95

File tree

1 file changed

+17
-18
lines changed
  • aten/src/ATen/native/mps/operations

1 file changed

+17
-18
lines changed

aten/src/ATen/native/mps/operations/Shape.mm

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -545,13 +545,16 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
545545

546546
// Indices of tensors to be skipped because they're empty
547547
std::vector<int64_t> skipped_tensor_indices;
548+
// Tensors to be read
549+
std::vector<const Tensor*> input_tensors;
548550
int tensor_idx = 0;
549551
for(const Tensor& t : materialized_inputs) {
550552
if(t.numel() == 0 || should_skip(t)) {
551553
skipped_tensor_indices.push_back(tensor_idx);
552554
tensor_idx++;
553555
continue;
554556
}
557+
input_tensors.push_back(&t);
555558
nDims = t.dim();
556559
// TODO: Is this OK?
557560
notSkippedTensor = &t;
@@ -659,26 +662,22 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
659662
MPSGraphTensor* castInputMPSGraphTensors[len_tensor_array];
660663

661664
int graph_tensor_idx = 0;
662-
int i = 0;
663-
for(const Tensor& tensor : materialized_inputs) {
664-
if(std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
665-
inputMPSGraphTensors[graph_tensor_idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(tensor.scalar_type()) );
666-
if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool) {
665+
for(const Tensor* tensor : input_tensors) {
666+
inputMPSGraphTensors[graph_tensor_idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(tensor->scalar_type()) );
667+
if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool) {
668+
castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx]
669+
toType:MPSDataTypeFloat32
670+
name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]];
671+
}
672+
else {
673+
if(tensor->scalar_type() != result_type(inputs))
667674
castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx]
668-
toType:MPSDataTypeFloat32
669-
name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]];
670-
}
671-
else {
672-
if(tensor.scalar_type() != result_type(inputs))
673-
castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx]
674-
toType:getMPSDataType(result_type(inputs))
675-
name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]];
676-
else
677-
castInputMPSGraphTensors[graph_tensor_idx] = inputMPSGraphTensors[graph_tensor_idx];
678-
}
679-
graph_tensor_idx++;
675+
toType:getMPSDataType(result_type(inputs))
676+
name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]];
677+
else
678+
castInputMPSGraphTensors[graph_tensor_idx] = inputMPSGraphTensors[graph_tensor_idx];
680679
}
681-
i++;
680+
graph_tensor_idx++;
682681
}
683682

684683
auto inputTensorsArray = [NSArray arrayWithObjects:castInputMPSGraphTensors

0 commit comments

Comments
 (0)