Merged
4 changes: 2 additions & 2 deletions aten/src/ATen/native/mps/OperationUtils.h
@@ -42,12 +42,12 @@ void runMPSGraph(
MPSDataType getMPSDataType(ScalarType scalar_type);
MPSDataType getMPSScalarType(ScalarType scalar_type);
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
std::string getMPSTypeString(ScalarType scalar_type);
std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value = false);
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = false);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
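Note on the header change above: the use_scalar_value flag on getTensorsStringKey is replaced by a short_dtype flag, and getMPSTypeString gains a short_name flag, so that graph-cache keys stay compact for ops with many inputs such as cat_out(). A minimal standalone sketch of the short-name idea (illustrative names only, not the real ATen helpers):

// Standalone illustration (not the ATen implementation): map a full dtype label
// to the short form used in cache keys, mirroring the new short_name flag.
#include <iostream>
#include <string>
#include <unordered_map>

std::string shortDtypeName(const std::string& full) {
    // Hypothetical mapping table; the real mapping lives in getMPSTypeString().
    static const std::unordered_map<std::string, std::string> table = {
        {"Float32", "f32"}, {"Float16", "f16"}, {"Int32", "i32"}, {"Int64", "i64"},
        {"Int16", "i16"},   {"Int8", "i8"},     {"UInt8", "u8"},  {"Bool", "b8"},
    };
    auto it = table.find(full);
    return it != table.end() ? it->second : "Undefined";
}

int main() {
    // For an op like cat_out with many inputs, each per-tensor dtype tag shrinks
    // from e.g. "Float32" to "f32", keeping the cached-graph key short.
    std::cout << shortDtypeName("Float32") << " " << shortDtypeName("Bool") << "\n";  // prints "f32 b8"
}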
25 changes: 13 additions & 12 deletions aten/src/ATen/native/mps/OperationUtils.mm
@@ -63,25 +63,26 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
}
}

std::string getMPSTypeString(ScalarType scalar_type) {
// use short_name to avoid getting extra long cached graph keys with ops such as cat_out(), etc.
std::string getMPSTypeString(ScalarType scalar_type, bool short_name) {
switch (scalar_type) {
case ScalarType::Double:
case ScalarType::Float:
return "Float32";
return short_name ? "f32" : "Float32";
case ScalarType::Half:
return "Float16";
return short_name ? "f16" : "Float16";
case ScalarType::Int:
return "Int32";
return short_name ? "i32" : "Int32";
case ScalarType::Long:
return "Int64";
return short_name ? "i64" : "Int64";
case ScalarType::Short:
return "Int16";
return short_name ? "i16" : "Int16";
case ScalarType::Char:
return "Int8";
return short_name ? "i8" : "Int8";
case ScalarType::Byte:
return "UInt8";
return short_name ? "u8" : "UInt8";
case ScalarType::Bool:
return "Bool";
return short_name ? "b8" : "Bool";
default:
return "Undefined";
}
@@ -150,16 +151,16 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
return ss.str();
}

std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value) {
std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype) {
std::string str;
// The key format per tensor would look like ":Float32[1,1,1,10]:"
for (const Tensor& tensor: tensors) {
str += ":";
if (tensor.defined()) {
str += getMPSTypeString(tensor.scalar_type()) + "[";
str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "[";
// if tensor is a scalar
if (tensor.dim() == 0) {
str += (use_scalar_value ? std::to_string(tensor.item().to<double>()) : "Scalar");
str += "Scalar";
} else {
const NSString* ns_shape_key = [[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","];
str += std::string(ns_shape_key.UTF8String);
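For reference, the per-tensor key format described by the comment above (":Float32[1,1,1,10]", or ":f32[1,1,1,10]" with short dtype names) can be mimicked by a small standalone sketch; FakeTensor and tensorsStringKey below are hypothetical stand-ins, not ATen types:

// Standalone sketch of the tensors-string-key format; tensor metadata is faked.
#include <iostream>
#include <string>
#include <vector>

struct FakeTensor {            // stand-in for at::Tensor metadata
    std::string dtype;         // already in short form, e.g. "f32"
    std::vector<long> sizes;   // empty means a 0-dim (scalar) tensor
};

std::string tensorsStringKey(const std::vector<FakeTensor>& tensors) {
    std::string key;
    for (const auto& t : tensors) {
        key += ":" + t.dtype + "[";
        if (t.sizes.empty()) {
            key += "Scalar";   // scalar values themselves are no longer embedded in the key
        } else {
            for (size_t i = 0; i < t.sizes.size(); ++i) {
                key += std::to_string(t.sizes[i]) + (i + 1 < t.sizes.size() ? "," : "");
            }
        }
        key += "]";
    }
    return key;
}

int main() {
    std::cout << tensorsStringKey({{"f32", {1, 1, 1, 10}}, {"i64", {}}}) << "\n";
    // prints ":f32[1,1,1,10]:i64[Scalar]"
}

Note that 0-dim tensors now always contribute the literal "Scalar" rather than their value, so cached graphs are keyed by dtype and shape only.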
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -71,7 +71,7 @@ void binaryOpTensor(const Tensor& self, const Tensor& other, const Scalar& alpha

MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = op_name + getTensorsStringKey({self, other, output_}, /*use_scalar_value*/ false);
string key = op_name + getTensorsStringKey({self, other, output_});
BinaryOpCachedGraph* cachedGraph = static_cast<BinaryOpCachedGraph *>(cache_->LookUp(key));

if(!cachedGraph) {
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/operations/PointwiseOps.mm
@@ -35,7 +35,7 @@
MPSGraphCache* cache_ = MPSGraphCache::getInstance();

@autoreleasepool {
string key = op_name + getTensorsStringKey({self, tensor1, tensor2}, false);
string key = op_name + getTensorsStringKey({self, tensor1, tensor2});

CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);

146 changes: 39 additions & 107 deletions aten/src/ATen/native/mps/operations/Shape.mm
@@ -182,25 +182,6 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
}
}

inline c10::MemoryFormat compute_output_memory_format(const TensorList &inputs) {
c10::optional<c10::MemoryFormat> format = c10::nullopt;
for (auto &t : inputs) {
auto f = t.suggest_memory_format();
if (!format.has_value()) {
format = f;
continue;
}
if (format.value() == f) {
continue;
}
bool contiguous = (format.value() == c10::MemoryFormat::Contiguous || f == c10::MemoryFormat::Contiguous || format.value() != f);
if (contiguous) {
return c10::MemoryFormat::Contiguous;
}
}
return format.value();
}

TORCH_IMPL_FUNC(cat_out_mps)
(const ITensorListRef& inputs,
int64_t dimension,
@@ -214,17 +195,25 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
if (out.numel() == 0) {
return;
}

auto materialized_inputs = inputs.materialize();
auto out_dtype = at::native::result_type(inputs);

int idx = 0;
for(const Tensor& t : materialized_inputs) {
TORCH_CHECK(t.dim() > 0,
"zero-dimensional tensor (at position ", idx, ") cannot be concatenated");
TORCH_CHECK(t.dim() > 0, "zero-dimensional tensor (at position ", idx, ") cannot be concatenated");
auto lap = at::get_overlap_status(out, t);
TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full,
"torch.cat(): unsupported operation: the input tensors cannot refer to any "
"of the output memory locations. Found overlap in input tensor ", idx);
idx++;
}
// Check for type promotion
TORCH_CHECK(canCast(out_dtype, out.scalar_type()),
"torch.cat(): input types can't be cast to the desired output type ", out.scalar_type());
TORCH_CHECK(inputs.size() > 0,"torch.cat(): invalid number of inputs ", inputs.size());

dimension = legacy_cat_wrap_dim(dimension, materialized_inputs);
TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);

// previously, size [0] tensors were the only possible empty tensors; thus, it
// wasn't possible to cat empty tensors unless all the other tensors were
@@ -235,28 +224,6 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
auto should_skip = [](const Tensor& t) {
return t.dim() == 1 && at::native::size(t, 0) == 0;
};


// Check for type promotion
TORCH_CHECK(
canCast(result_type(inputs), out.scalar_type()),
"torch.cat(): input types ",
" can't be cast to the desired output type ",
out.scalar_type());

// Inputs cannot alias the output tensor
idx = 0;
for(const Tensor& t : materialized_inputs) {
auto lap = at::get_overlap_status(out, t);
TORCH_CHECK(
lap != at::MemOverlapStatus::Partial &&
lap != at::MemOverlapStatus::Full,
"torch.cat(): unsupported operation: the input tensors cannot refer to any "
"of the output memory locations. Found overlap in input "
"tensor ",
idx);
idx++;
}
at::assert_no_internal_overlap(out);

Tensor notSkippedTensor;
@@ -276,38 +243,22 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
notSkippedTensor = t;
tensor_idx++;
}

// If all inputs are empty tensors, return an empty tensor
if (!notSkippedTensor.defined()) {
return;
}

TORCH_CHECK(
inputs.size() > 0,
"torch.cat(): invalid number of inputs ",
inputs.size());
TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);

for (const Tensor& t : inputs) {
TORCH_CHECK(
t.device() == notSkippedTensor.device(),
"torch.cat(): all input tensors must be on the same device. Received ",
t.device(),
" and ",
notSkippedTensor.device());
TORCH_CHECK(t.device() == notSkippedTensor.device(),
"torch.cat(): all input tensors must be on the same device. Received ",
t.device(), " and ", notSkippedTensor.device());
}
TORCH_CHECK(out.device() == notSkippedTensor.device(),
"torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
notSkippedTensor.device(), " and out is on ", out.device());

TORCH_CHECK(
out.device() == notSkippedTensor.device(),
"torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
notSkippedTensor.device(),
" and out is on ",
out.device());

// TODO: memory_format is now an argument?
// // TODO: Factor out `compute_output_memory_format`
// c10::MemoryFormat memory_format = compute_output_memory_format(inputs);

if (out.suggest_memory_format() == MemoryFormat::ChannelsLast) {
out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
}
std::vector<int64_t> size(notSkippedTensor.sizes().vec());

// Compute size of the result in the cat dimension
@@ -322,48 +273,30 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
cat_dim_size += at::native::size(tensor, dimension);
idx++;
}

// Compute the size of the result
size[dimension] = cat_dim_size;

// skip resizing if size of result is same as expected
if (out.sizes() != size) {
out.resize_(size, memory_format);
}

if (out.numel() == 0) {
return;
}

// Get stream
MPSStream* stream = getCurrentMPSStream();

struct CachedGraph : public MPSCachedGraph
{
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
std::vector<MPSGraphTensor*> inputTensors_;
MPSGraphTensor* outputTensor_ = nil;
};

MPSGraphCache *cache_ = MPSGraphCache::getInstance();

// Make string out of skipped tensor indices
string skipped_indices_string = "";
for(int idx : skipped_tensor_indices)
skipped_indices_string += (std::to_string(idx)+",");
string input_types = "";
for(const Tensor& tensor : materialized_inputs)
input_types += (getMPSTypeString(tensor.scalar_type())+",");

@autoreleasepool {
string key = "cat_out_mps:" + getMPSTypeString(result_type(inputs))
+ ":" + to_string(inputs.size())
+ ":" + skipped_indices_string
+ ":" + input_types
+ ":" + to_string(dimension);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/true) + ":" +
(memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");

CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;

@autoreleasepool {
@@ -375,15 +308,15 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
newCachedGraph->inputTensors_.reserve(len_tensor_array);

for (const auto idx : c10::irange(len_tensor_array)) {
auto scalar_type = getMPSScalarType(input_tensors[idx].scalar_type());
if (input_tensors[idx].scalar_type() == kBool) {
const Tensor& tensor = input_tensors[idx];
auto scalar_type = getMPSScalarType(tensor.scalar_type());
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}

newCachedGraph->inputTensors_[idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, scalar_type);
if (input_tensors[idx].scalar_type() != result_type(inputs)) {
newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, memory_format));
if (tensor.scalar_type() != out_dtype) {
castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
toType:getMPSDataType(result_type(inputs))
toType:getMPSDataType(out_dtype)
name:@"castInput"];
} else {
castInputTensors[idx] = newCachedGraph->inputTensors_[idx];
@@ -395,16 +328,16 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
dimension:dimension // Maybe convert this from int64_t -> int32
name:nil];
if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool) {
if(getMPSDataType(out_dtype) == MPSDataTypeBool) {
outputTensor = [mpsGraph castTensor:outputTensor
toType:MPSDataTypeBool
name:@"outputTensor"];
}
newCachedGraph->outputTensor_ = outputTensor;
newCachedGraph->outputTensor_ = memory_format == MemoryFormat::ChannelsLast ?
convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
}
return newCachedGraph;
});
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}

std::vector<Placeholder> inputPlaceholders;
@@ -416,9 +349,9 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
if (tensor.scalar_type() == kBool) {
scalar_type = MPSDataTypeInt8;
}
Placeholder currentInputPlaceholder = Placeholder(
cachedGraph->inputTensors_[t_idx], tensor, /*mpsShape=*/nil, /*gatherTensorData=*/true, scalar_type);
inputPlaceholders.push_back(currentInputPlaceholder);
inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor,
getMPSShape(tensor, memory_format),
memory_format != MemoryFormat::ChannelsLast, scalar_type);
t_idx++;
}
i++;
@@ -439,9 +372,8 @@ void check_shape_except_dim(const Tensor &first, const Tensor &second,
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};

mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}

}

} // namespace native
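The rewritten cat_out_mps above also consolidates the shape bookkeeping: legacy size-[0] placeholders are skipped, every remaining input must match outside the cat dimension, and the output size along that dimension is the sum of the inputs'. A self-contained sketch of that computation, assuming plain std::vector shapes rather than at::Tensor:

// Standalone sketch of the cat output-size computation described above.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;

bool shouldSkip(const Shape& s) {               // 1-D, size-0 "empty" placeholder
    return s.size() == 1 && s[0] == 0;
}

Shape catOutputShape(const std::vector<Shape>& inputs, int64_t dim) {
    Shape out;
    int64_t cat_dim_size = 0;
    for (const auto& s : inputs) {
        if (shouldSkip(s)) continue;
        if (out.empty()) out = s;               // first non-skipped tensor fixes the rank
        assert(s.size() == out.size());         // shapes must match except along dim
        cat_dim_size += s[dim];
    }
    if (!out.empty()) out[dim] = cat_dim_size;  // all inputs skipped -> empty result
    return out;
}

int main() {
    auto s = catOutputShape({{2, 3, 4}, {0}, {2, 5, 4}}, 1);
    for (auto d : s) std::cout << d << " ";     // prints "2 8 4"
}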
2 changes: 1 addition & 1 deletion aten/src/ATen/native/mps/operations/UnaryOps.mm
@@ -29,7 +29,7 @@ void unary_op(const Tensor& self, const Tensor& output, std::string op_name, Una
}
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
string key = op_name + getTensorsStringKey({self, output}, /*use_scalar_value*/ false);
string key = op_name + getTensorsStringKey({self, output});
auto cachedGraph = cache_->LookUpAs<MPSUnaryCachedGraph>(key);

if(!cachedGraph) {
2 changes: 1 addition & 1 deletion test/test_mps.py
@@ -8873,7 +8873,7 @@ def test_numpy_ref_mps(self, device, dtype, op):
# does not support float64 Tensors.
# A few ops are currently broken on their reference inputs, but not their sample inputs. These should
# get patched up and this workaround removed.
broken_on_ref_inputs = op.name in ['cat', 'clamp', 'where']
broken_on_ref_inputs = op.name in ['clamp', 'where']
inputs = op.reference_inputs(device, dtype) if not broken_on_ref_inputs else op.sample_inputs(device, dtype)
for sample_input in inputs:
self.compare_with_reference(op, op.ref, sample_input)