diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 343a335217352..16ea8791a5319 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -42,12 +42,12 @@ void runMPSGraph(
 MPSDataType getMPSDataType(ScalarType scalar_type);
 MPSDataType getMPSScalarType(ScalarType scalar_type);
 MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
-std::string getMPSTypeString(ScalarType scalar_type);
+std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
 std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
 NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
 NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim);
 std::string getMPSShapeString(MPSShape* shape);
-std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value = false);
+std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = false);
 std::string getArrayRefString(const IntArrayRef s);
 // use has_storage() on the returned tensor to determine if src actually is a view
 Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index ffb5ddf490267..22dca3250596e 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -63,25 +63,26 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
   }
 }
 
-std::string getMPSTypeString(ScalarType scalar_type) {
+// use short_name to avoid getting extra long cached graph keys with ops such as cat_out(), etc.
+std::string getMPSTypeString(ScalarType scalar_type, bool short_name) {
   switch (scalar_type) {
     case ScalarType::Double:
    case ScalarType::Float:
-      return "Float32";
+      return short_name ? "f32" : "Float32";
    case ScalarType::Half:
-      return "Float16";
+      return short_name ? "f16" : "Float16";
    case ScalarType::Int:
-      return "Int32";
+      return short_name ? "i32" : "Int32";
    case ScalarType::Long:
-      return "Int64";
+      return short_name ? "i64" : "Int64";
    case ScalarType::Short:
-      return "Int16";
+      return short_name ? "i16" : "Int16";
    case ScalarType::Char:
-      return "Int8";
+      return short_name ? "i8" : "Int8";
    case ScalarType::Byte:
-      return "UInt8";
+      return short_name ? "u8" : "UInt8";
    case ScalarType::Bool:
-      return "Bool";
+      return short_name ? "b8" : "Bool";
    default:
      return "Undefined";
   }
@@ -150,16 +151,16 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
   return ss.str();
 }
 
-std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value) {
+std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype) {
     std::string str;
     // The key format per tensor would look like ":Float32[1,1,1,10]:"
     for (const Tensor& tensor: tensors) {
       str += ":";
       if (tensor.defined()) {
-        str += getMPSTypeString(tensor.scalar_type()) + "[";
+        str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "[";
         // if tensor is a scalar
         if (tensor.dim() == 0) {
-          str += (use_scalar_value ? std::to_string(tensor.item().to<double>()) : "Scalar");
+          str += "Scalar";
         } else {
           const NSString* ns_shape_key = [[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","];
           str += std::string(ns_shape_key.UTF8String);
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 7efb646d72175..ed3f69f22ad44 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -71,7 +71,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 
   @autoreleasepool {
-    string key = op_name + getTensorsStringKey({self, other, output_}, /*use_scalar_value*/ false);
+    string key = op_name + getTensorsStringKey({self, other, output_});
     BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph*>(cache_->LookUp(key));
 
     if(!cachedGraph) {
diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm
index 4c3e7d9e50cc1..9ed6298368716 100644
--- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm
+++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm
@@ -35,7 +35,7 @@
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 
   @autoreleasepool {
-    string key = op_name + getTensorsStringKey({self, tensor1, tensor2}, false);
+    string key = op_name + getTensorsStringKey({self, tensor1, tensor2});
 
     CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm
index 86e82bef93cbf..3f460437c1c00 100644
--- a/aten/src/ATen/native/mps/operations/Shape.mm
+++ b/aten/src/ATen/native/mps/operations/Shape.mm
@@ -182,25 +182,6 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
   }
 }
 
-inline c10::MemoryFormat compute_output_memory_format(const TensorList &inputs) {
-  c10::optional<c10::MemoryFormat> format = c10::nullopt;
-  for (auto &t : inputs) {
-    auto f = t.suggest_memory_format();
-    if (!format.has_value()) {
-      format = f;
-      continue;
-    }
-    if (format.value() == f) {
-      continue;
-    }
-    bool contiguous = (format.value() == c10::MemoryFormat::Contiguous || f == c10::MemoryFormat::Contiguous || format.value() != f);
-    if (contiguous) {
-      return c10::MemoryFormat::Contiguous;
-    }
-  }
-  return format.value();
-}
-
 TORCH_IMPL_FUNC(cat_out_mps)
 (const ITensorListRef& inputs,
  int64_t dimension,
@@ -214,17 +195,25 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
   if (out.numel() == 0) {
     return;
   }
-
   auto materialized_inputs = inputs.materialize();
+  auto out_dtype = at::native::result_type(inputs);
 
   int idx = 0;
   for(const Tensor& t : materialized_inputs) {
-    TORCH_CHECK(t.dim() > 0,
-             "zero-dimensional tensor (at position ", idx, ") cannot be concatenated");
+    TORCH_CHECK(t.dim() > 0, "zero-dimensional tensor (at position ", idx, ") cannot be concatenated");
+    auto lap = at::get_overlap_status(out, t);
+    TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full,
+                "torch.cat(): unsupported operation: the input tensors cannot refer to any "
+                "of the output memory locations. Found overlap in input tensor ", idx);
     idx++;
   }
+  // Check for type promotion
+  TORCH_CHECK(canCast(out_dtype, out.scalar_type()),
+              "torch.cat(): input types can't be cast to the desired output type ", out.scalar_type());
+  TORCH_CHECK(inputs.size() > 0, "torch.cat(): invalid number of inputs ", inputs.size());
 
   dimension = legacy_cat_wrap_dim(dimension, materialized_inputs);
+  TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);
 
   // previously, size [0] tensors were the only possible empty tensors; thus, it
   // wasn't possible to cat empty tensors unless all the other tensors were
@@ -235,28 +224,6 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
   auto should_skip = [](const Tensor& t) {
     return t.dim() == 1 && at::native::size(t, 0) == 0;
   };
-
-
-  // Check for type promotion
-  TORCH_CHECK(
-      canCast(result_type(inputs), out.scalar_type()),
-      "torch.cat(): input types ",
-      " can't be cast to the desired output type ",
-      out.scalar_type());
-
-  // Inputs cannot alias the output tensor
-  idx = 0;
-  for(const Tensor& t : materialized_inputs) {
-    auto lap = at::get_overlap_status(out, t);
-    TORCH_CHECK(
-        lap != at::MemOverlapStatus::Partial &&
-            lap != at::MemOverlapStatus::Full,
-        "torch.cat(): unsupported operation: the input tensors cannot refer to any "
-        "of the output memory locations. Found overlap in input "
-        "tensor ",
-        idx);
-    idx++;
-  }
   at::assert_no_internal_overlap(out);
 
   Tensor notSkippedTensor;
@@ -276,38 +243,22 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
       notSkippedTensor = t;
       tensor_idx++;
   }
-
   // If all inputs are empty tensors, return an empty tensor
   if (!notSkippedTensor.defined()) {
     return;
   }
-
-  TORCH_CHECK(
-      inputs.size() > 0,
-      "torch.cat(): invalid number of inputs ",
-      inputs.size());
-  TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);
-
   for (const Tensor& t : inputs) {
-    TORCH_CHECK(
-        t.device() == notSkippedTensor.device(),
-        "torch.cat(): all input tensors must be on the same device. Received ",
-        t.device(),
-        " and ",
-        notSkippedTensor.device());
+    TORCH_CHECK(t.device() == notSkippedTensor.device(),
+                "torch.cat(): all input tensors must be on the same device. Received ",
+                t.device(), " and ", notSkippedTensor.device());
   }
+  TORCH_CHECK(out.device() == notSkippedTensor.device(),
+              "torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
+              notSkippedTensor.device(), " and out is on ", out.device());
 
-  TORCH_CHECK(
-      out.device() == notSkippedTensor.device(),
-      "torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
-      notSkippedTensor.device(),
-      " and out is on ",
-      out.device());
-
-  // TODO: memory_format is now an argument?
-  // // TODO: Factor out `compute_output_memory_format`
-  // c10::MemoryFormat memory_format = compute_output_memory_format(inputs);
-
+  if (out.suggest_memory_format() == MemoryFormat::ChannelsLast) {
+    out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
+  }
   std::vector<int64_t> size(notSkippedTensor.sizes().vec());
 
   // Compute size of the result in the cat dimension
@@ -322,48 +273,30 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
     cat_dim_size += at::native::size(tensor, dimension);
     idx++;
   }
-
   // Compute the size of the result
   size[dimension] = cat_dim_size;
-
   // skip resizing if size of result is same as expected
   if (out.sizes() != size) {
     out.resize_(size, memory_format);
   }
-
   if (out.numel() == 0) {
     return;
   }
-
-  // Get stream
-  MPSStream* stream = getCurrentMPSStream();
-
-  struct CachedGraph : public MPSCachedGraph
-  {
+  struct CachedGraph : public MPSCachedGraph {
     CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
     std::vector<MPSGraphTensor*> inputTensors_;
     MPSGraphTensor* outputTensor_ = nil;
   };
-
   MPSGraphCache *cache_ = MPSGraphCache::getInstance();
 
-  // Make string out of skipped tensor indices
-  string skipped_indices_string = "";
-  for(int idx : skipped_tensor_indices)
-    skipped_indices_string += (std::to_string(idx)+",");
-  string input_types = "";
-  for(const Tensor& tensor : materialized_inputs)
-    input_types += (getMPSTypeString(tensor.scalar_type())+",");
-
   @autoreleasepool {
-    string key = "cat_out_mps:" + getMPSTypeString(result_type(inputs))
-                                + ":" + to_string(inputs.size())
-                                + ":" + skipped_indices_string
-                                + ":" + input_types
-                                + ":" + to_string(dimension);
-    CachedGraph* cachedGraph = static_cast<CachedGraph*>(cache_->LookUp(key));
+    string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/true) + ":" +
"NHWC" : "NCHW"); + + CachedGraph* cachedGraph = cache_->LookUpAs(key); if(!cachedGraph) { - MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () { + cachedGraph = cache_->CreateCachedGraphAs(key, ^ MPSCachedGraph * () { CachedGraph *newCachedGraph = nil; @autoreleasepool { @@ -375,15 +308,15 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, newCachedGraph->inputTensors_.reserve(len_tensor_array); for (const auto idx : c10::irange(len_tensor_array)) { - auto scalar_type = getMPSScalarType(input_tensors[idx].scalar_type()); - if (input_tensors[idx].scalar_type() == kBool) { + const Tensor& tensor = input_tensors[idx]; + auto scalar_type = getMPSScalarType(tensor.scalar_type()); + if (tensor.scalar_type() == kBool) { scalar_type = MPSDataTypeInt8; } - - newCachedGraph->inputTensors_[idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, scalar_type); - if (input_tensors[idx].scalar_type() != result_type(inputs)) { + newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, memory_format)); + if (tensor.scalar_type() != out_dtype) { castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx] - toType:getMPSDataType(result_type(inputs)) + toType:getMPSDataType(out_dtype) name:@"castInput"]; } else { castInputTensors[idx] = newCachedGraph->inputTensors_[idx]; @@ -395,16 +328,16 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second, MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray dimension:dimension // Maybe convert this from int64_t -> int32 name:nil]; - if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool) { + if(getMPSDataType(out_dtype) == MPSDataTypeBool) { outputTensor = [mpsGraph castTensor:outputTensor toType:MPSDataTypeBool name:@"outputTensor"]; } - newCachedGraph->outputTensor_ = outputTensor; + newCachedGraph->outputTensor_ = memory_format == MemoryFormat::ChannelsLast ? 
+                                          convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
         }
         return newCachedGraph;
       });
-      cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
     }
 
     std::vector<Placeholder> inputPlaceholders;
@@ -416,9 +349,9 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
         if (tensor.scalar_type() == kBool) {
           scalar_type = MPSDataTypeInt8;
         }
-        Placeholder currentInputPlaceholder = Placeholder(
-            cachedGraph->inputTensors_[t_idx], tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, scalar_type);
-        inputPlaceholders.push_back(currentInputPlaceholder);
+        inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor,
+                                       getMPSShape(tensor, memory_format),
+                                       memory_format != MemoryFormat::ChannelsLast, scalar_type);
         t_idx++;
       }
       i++;
@@ -439,9 +372,8 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
       outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
     };
 
-    mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
   }
-
 }
 
 } // namespace native
diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm
index cb3e40299bceb..039e8e5f52c80 100644
--- a/aten/src/ATen/native/mps/operations/UnaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -29,7 +29,7 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
   }
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
   @autoreleasepool {
-    string key = op_name + getTensorsStringKey({self, output}, /*use_scalar_value*/ false);
+    string key = op_name + getTensorsStringKey({self, output});
     auto cachedGraph = cache_->LookUpAs<MPSUnaryCachedGraph>(key);
     if(!cachedGraph) {
diff --git a/test/test_mps.py b/test/test_mps.py
index 0f074fbe88fcd..7f73e66619dff 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8873,7 +8873,7 @@ def test_numpy_ref_mps(self, device, dtype, op):
         # does not support float64 Tensors.
         # A few ops are currently broken on their reference inputs, but not their sample inputs. These should
         # get patched up and this workaround removed.
-        broken_on_ref_inputs = op.name in ['cat', 'clamp', 'where']
+        broken_on_ref_inputs = op.name in ['clamp', 'where']
         inputs = op.reference_inputs(device, dtype) if not broken_on_ref_inputs else op.sample_inputs(device, dtype)
         for sample_input in inputs:
             self.compare_with_reference(op, op.ref, sample_input)