31 changes: 30 additions & 1 deletion aten/src/ATen/mps/MPSDevice.mm
@@ -11,6 +11,32 @@
static std::unique_ptr<MPSDevice> mps_device;
static c10::once_flag mpsdev_init;

static inline MTLLanguageVersion getMetalLanguageVersion(const id<MTLDevice>& device) {
#if defined(__MAC_10_13) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_13
#else
#error "Metal is not available on the current platform."
#endif

// MPS Advanced Indexing needs at least Metal 2.0 (support for argument buffers and function constants)
MTLLanguageVersion languageVersion;
if (@available(macOS 13.0, *)) {
languageVersion = MTLLanguageVersion3_0;
} else if (@available(macOS 12.0, *)) {
languageVersion = MTLLanguageVersion2_4;
} else if (@available(macOS 11.0, *)) {
languageVersion = MTLLanguageVersion2_3;
} else if (@available(macOS 10.15, *)) {
languageVersion = MTLLanguageVersion2_2;
} else if (@available(macOS 10.14, *)) {
languageVersion = MTLLanguageVersion2_1;
} else if (@available(macOS 10.13, *)) {
languageVersion = MTLLanguageVersion2_0;
}

TORCH_CHECK([device supportsFamily:MTLGPUFamilyMac2], "Missing Metal support for MTLGPUFamilyMac2");
return languageVersion;
}

MPSDevice* MPSDevice::getInstance() {
c10::call_once(mpsdev_init, [] {
mps_device = std::unique_ptr<MPSDevice>(new MPSDevice());
@@ -22,8 +48,11 @@
assert(_mtl_device);
NSError* error = nil;
if (!_mtl_indexing_library) {
MTLCompileOptions *options = [MTLCompileOptions new];
[options setLanguageVersion: getMetalLanguageVersion(_mtl_device)];
[options setFastMathEnabled: YES];
_mtl_indexing_library = [_mtl_device newLibraryWithSource: [NSString stringWithCString: mps::indexing_metal_shaders encoding:NSASCIIStringEncoding]
options: nil
options: options
error: &error];
TORCH_CHECK(_mtl_indexing_library, "Failed to create indexing library, error: ", [[error description] UTF8String]);
}
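For context, the @available cascade above selects the newest Metal Shading Language version the running OS supports before compiling the indexing shaders. A minimal Python sketch of the same lookup, purely illustrative (none of these names exist in the PR):

# macOS release -> newest Metal Shading Language version, mirroring
# the @available cascade in getMetalLanguageVersion.
METAL_VERSION_BY_MACOS = [
    ((13, 0), "3.0"),
    ((12, 0), "2.4"),
    ((11, 0), "2.3"),
    ((10, 15), "2.2"),
    ((10, 14), "2.1"),
    ((10, 13), "2.0"),  # Metal 2.0: argument buffers, function constants
]

def pick_metal_language_version(macos):
    # Tuples compare lexicographically, so (12, 3) >= (12, 0) selects 2.4.
    for min_os, metal in METAL_VERSION_BY_MACOS:
        if tuple(macos) >= min_os:
            return metal
    raise RuntimeError("Metal is not available before macOS 10.13")

print(pick_metal_language_version((12, 3)))  # 2.4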
4 changes: 2 additions & 2 deletions aten/src/ATen/native/mps/OperationUtils.mm
@@ -192,7 +192,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
{
NSInteger sz_i = (i < sz) ? t.size(i) : 1;

NSNumber* number = [NSNumber numberWithInt:sz_i];
NSNumber* number = [NSNumber numberWithInteger:sz_i];
numbers[i] = number;
}
return [NSArray arrayWithObjects:numbers count:sz_];
@@ -213,7 +213,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) {
{
NSInteger sz_i = (i < sz) ? sizes[i] : 1;

NSNumber* number = [NSNumber numberWithInt:sz_i];
NSNumber* number = [NSNumber numberWithInteger:sz_i];
numbers[i] = number;
}
return [NSArray arrayWithObjects:numbers count:sz_];
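Both hunks above swap numberWithInt: for numberWithInteger: because sz_i is an NSInteger, which is 64-bit on modern macOS; numberWithInt: takes a 32-bit int and would silently truncate any dimension above 2^31 - 1. A quick Python illustration of the same truncation (illustrative only, not part of the change):

import ctypes

dim = 2**33 + 5                    # a tensor dimension beyond the 32-bit range
print(ctypes.c_int32(dim).value)   # 5 -- the high bits are silently dropped
print(ctypes.c_int64(dim).value)   # 8589934597 -- preserved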
234 changes: 234 additions & 0 deletions aten/src/ATen/native/mps/operations/Distributions.mm
Original file line number Diff line number Diff line change
@@ -10,6 +10,10 @@
#include <ATen/native/TensorIterator.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/NativeFunctions.h>
#include <ATen/AccumulateType.h>
#include <torch/library.h>

namespace at {
@@ -574,5 +578,235 @@ static void check_from_to_in_range(int64_t from, int64_t to_inc, ScalarType scalar_type) {

}

Tensor& multinomial_with_replacement_mps_kernel(
const Tensor& self,
const int64_t n_sample,
c10::optional<Generator> generator,
Tensor& result) {

using namespace mps;

  int inputSize = self.dim();
  int numDist = inputSize == 1 ? 1 : self.size(0);
  int numCategories = inputSize == 1 ? self.size(0) : self.size(1);

// Restructure data for 2d
auto self_v = inputSize == 1 ? self.view({numDist, numCategories}) : self;
auto result_v = inputSize == 1 ? result.view({numDist, n_sample}) : result;

MPSStream* stream = getCurrentMPSStream();
uint64_t seed_ = c10::detail::getNonDeterministicRandom(true);

@autoreleasepool {
MPSShape* prob_shape = getMPSShape(self_v);
MPSGraph* mpsGraph = make_mps_graph();

auto prob_dtype = getMPSDataType(self_v.scalar_type());
auto result_dtype = getMPSDataType(result.scalar_type());

// These are the probability weights
MPSGraphTensor *probTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self_v.scalar_type()), prob_shape);

MPSGraphTensor *sumProbs = [mpsGraph reductionSumWithTensor:probTensor
axis:-1
name:nil];

MPSGraphTensor *normalizedProbs = [mpsGraph divisionWithPrimaryTensor:probTensor
secondaryTensor:sumProbs
name:nil];

auto ns_numCategories = [NSNumber numberWithInt:numCategories];
auto ns_numDist = [NSNumber numberWithInt:numDist];
auto ns_n_sample = [NSNumber numberWithInt:n_sample];

MPSGraphTensor *ones = [mpsGraph constantWithScalar:1.0f
shape:@[ns_numCategories, ns_numCategories]
dataType:prob_dtype];
MPSGraphTensor *upperTriangle = [mpsGraph bandPartWithTensor:ones
numLower:0
numUpper:-1
name:nil];
MPSGraphTensor *upperProbRange = [mpsGraph matrixMultiplicationWithPrimaryTensor:normalizedProbs
secondaryTensor:upperTriangle
name:nil];

MPSGraphTensor *lowerProbRange = [mpsGraph subtractionWithPrimaryTensor:upperProbRange
secondaryTensor:normalizedProbs
name:nil];

upperProbRange = [mpsGraph reshapeTensor:upperProbRange
withShape:@[ns_numDist, @1, ns_numCategories]
name:nil];
lowerProbRange = [mpsGraph reshapeTensor:lowerProbRange
withShape:@[ns_numDist, @1, ns_numCategories]
name:nil];

MPSGraphTensor *stateTensor = [mpsGraph randomPhiloxStateTensorWithSeed:seed_
name:nil];
MPSGraphRandomOpDescriptor *descriptor = [MPSGraphRandomOpDescriptor descriptorWithDistribution:MPSGraphRandomDistributionUniform
dataType:prob_dtype];
NSArray<MPSGraphTensor*> *generatorTensors = [mpsGraph randomTensorWithShape:@[ns_numDist, ns_n_sample, @1]
descriptor:descriptor
stateTensor:stateTensor
name:nil];
MPSGraphTensor *randomTensor = generatorTensors[0];

auto broadcastShape = @[ns_numDist, ns_n_sample, ns_numCategories];
int broadcastShapeVals[3] = {numDist, n_sample, numCategories};
MPSGraphTensor *broadcastShapeTensor = [mpsGraph constantWithData:[NSData dataWithBytes:broadcastShapeVals length:sizeof(int) * broadcastShape.count]
shape:@[[NSNumber numberWithUnsignedInteger:broadcastShape.count]]
dataType:MPSDataTypeUInt32];

MPSGraphTensor *samplesTensor = [mpsGraph broadcastTensor:randomTensor
toShape:broadcastShape
name:nil];
MPSGraphTensor *sampleAbove = [mpsGraph greaterThanWithPrimaryTensor:samplesTensor
secondaryTensor:lowerProbRange
name:nil];
MPSGraphTensor *sampleBelow = [mpsGraph lessThanWithPrimaryTensor:samplesTensor
secondaryTensor:upperProbRange
name:nil];
MPSGraphTensor *sampleWithin = [mpsGraph logicalANDWithPrimaryTensor:sampleAbove
secondaryTensor:sampleBelow
name:nil];
MPSGraphTensor *sampleMask = [mpsGraph castTensor:sampleWithin
toType:MPSDataTypeInt32
name:@"sampleMask"];
MPSGraphTensor *categoriesTensor = [mpsGraph coordinateAlongAxis:-1
withShapeTensor:broadcastShapeTensor
name:nil];
MPSGraphTensor *binnedSamplesTensor = [mpsGraph multiplicationWithPrimaryTensor:categoriesTensor
secondaryTensor:sampleMask
name:nil];
MPSGraphTensor *reducedTensor = [mpsGraph reductionSumWithTensor:binnedSamplesTensor
axis:-1
name:nil];
MPSGraphTensor *reshapeTensor = [mpsGraph reshapeTensor:reducedTensor
withShape:@[ns_numDist, ns_n_sample]
name:nil];
MPSGraphTensor *resultTensor = [mpsGraph castTensor:reshapeTensor
toType:getMPSDataType(result.scalar_type())
name:@"resultTensor"];

auto probPlaceholder = Placeholder(probTensor, self_v);
auto outputPlaceholder = Placeholder(resultTensor, result_v);
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
probPlaceholder.getMPSGraphTensor() : probPlaceholder.getMPSGraphTensorData()
};
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};

runMPSGraph(stream, mpsGraph, feeds, results);
}

return result;

}
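
In outline, the graph built above samples with replacement by inverse-CDF binning: normalize the weights, turn them into per-category [lower, upper) intervals via a cumulative sum (computed as a matmul against the upper-triangular band matrix), draw uniforms, and recover each sample's bin by summing index-times-mask. A NumPy sketch of the same computation, for intuition only (the real work runs on-device through MPSGraph):

import numpy as np

def multinomial_with_replacement_sketch(probs, n_sample, rng=None):
    rng = rng or np.random.default_rng()
    probs = np.atleast_2d(np.asarray(probs, dtype=np.float64))
    norm = probs / probs.sum(axis=-1, keepdims=True)
    upper = np.cumsum(norm, axis=-1)   # upperProbRange: cumsum via band matmul
    lower = upper - norm               # lowerProbRange
    u = rng.random((probs.shape[0], n_sample, 1))
    # Exactly one category per draw satisfies lower < u < upper
    # (boundary ties have probability zero, as in the graph).
    mask = (u > lower[:, None, :]) & (u < upper[:, None, :])
    cats = np.arange(probs.shape[1])
    return (mask * cats).sum(axis=-1)  # (numDist, n_sample) category indices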

/* The largest consecutive integer representable in float32 (2^24) */
constexpr int64_t FLOAT32_MAX_CONSECUTIVE_INT = 1 << (FLT_MANT_DIG);

Tensor& multinomial_out_mps(const Tensor& self,
int64_t n_sample,
bool with_replacement,
c10::optional<Generator> gen,
Tensor& result) {

TORCH_CHECK(
result.device() == self.device(),
"multinomial arguments must have the same device");
TORCH_CHECK(
self.dim() > 0 && self.dim() <= 2, "prob_dist must be 1 or 2 dim");
TORCH_CHECK(
at::isFloatingType(self.scalar_type()),
"multinomial only supports floating-point dtypes for input, got: ",
self.scalar_type());
TORCH_CHECK(result.scalar_type() == ScalarType::Long,
"multinomial expects Long tensor out, got: ", result.scalar_type());
TORCH_CHECK(n_sample > 0, "cannot sample n_sample <= 0 samples");
int64_t n_categories = self.size(-1);
TORCH_CHECK(with_replacement || (n_sample <= n_categories),
"cannot sample n_sample > prob_dist.size(-1) samples without replacement");
// Since the index tensor is float, numCategories cannot exceed max
// float integer precision
TORCH_CHECK(
n_categories <= FLOAT32_MAX_CONSECUTIVE_INT,
"number of categories cannot exceed 2^24");

if (self.dim() == 1) {
result.resize_({n_sample});
} else {
const int64_t n_dist = self.size(0);
result.resize_({n_dist, n_sample});
}
if (result.numel() == 0) {
return result;
}

// Fast-path for no replacement.
// Reference:
// https://github.com/pytorch/pytorch/issues/11931#issuecomment-625882503
// Half is not supported on CPU.
TORCH_CHECK(
!(self.device().is_cpu() && self.scalar_type() == ScalarType::Half),
"multinomial is not implemented for half on CPU");
if (!with_replacement) {
// Sanity checks on `self`.
auto is_valid = ((self.max() < INFINITY) & (self.min() >= 0)).item();
TORCH_CHECK(
is_valid.to<bool>(),
"probability tensor contains either `inf`, `nan` or element < 0");
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool zero_prob_condition;
if (self.dim() == 1){
zero_prob_condition = (self.sum() == 0).item().to<bool>();
} else {
zero_prob_condition = (self.sum(1) == 0).sum().item().to<bool>();
}
TORCH_CHECK(
!zero_prob_condition,
"invalid multinomial distribution (sum of probabilities <= 0)");

// The algorithm is from gumbel softmax.
// s = argmax( logp - log(-log(eps)) ) where eps ~ U(0, 1)
// Here we can apply exp to the formula which will not affect result of
// argmax or topk. Then we have
// s = argmax( p / (-log(eps)) ) where eps ~ U(0, 1).
// We can also simplify the formula above by
// s = argmax( p / q ) where q ~ Exp(1)
Tensor q = at::empty_like(self).exponential_(1, gen);
// In theory the probability to generate 0 from exponential distribution is
// 0. However, on CUDA side there is a protection to avoid 0s, but on CPU
// side, there is a very low probability to generate 0 from
// exponential<double>. The probability is about 2^(-DBL_MANT_DIG). We just
// ignore it here, but there may be some risk to get invalid output on CPU.
at::div_out(q, self, q);
if (n_sample == 1) {
at::argmax_out(result, q, /*dim=*/-1, /*keepdim=*/true);
} else {
Tensor vals = at::empty(result.sizes(), self.options());
at::topk_out(vals, result, q, n_sample);
}
return result;
}

result = multinomial_with_replacement_mps_kernel(const_cast<Tensor&>(self), n_sample, gen, result);

return result;
}
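
The no-replacement branch above is the Gumbel-max trick spelled out in its comment: with q ~ Exp(1) drawn per category, argmax(p / q) is one multinomial draw, and the top-k of p / q gives k distinct draws with the correct marginals. A NumPy sketch of the same idea (illustrative only):

import numpy as np

def multinomial_without_replacement_sketch(probs, n_sample, rng=None):
    rng = rng or np.random.default_rng()
    probs = np.asarray(probs, dtype=np.float64)
    q = rng.exponential(scale=1.0, size=probs.shape)  # q ~ Exp(1)
    keys = probs / q           # argmax over keys == a single multinomial draw
    # Top-k keys == k draws without replacement
    # (zero-probability categories rank last).
    return np.argsort(-keys, axis=-1)[..., :n_sample]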

Tensor multinomial_mps(
const Tensor& self,
int64_t n_sample,
bool with_replacement,
c10::optional<Generator> gen) {
Tensor result = at::empty({0}, self.options().dtype(kLong));
multinomial_out_mps(self, n_sample, with_replacement, gen, result);
return result;
}

} // namespace native
} // namespace at
4 changes: 4 additions & 0 deletions aten/src/ATen/native/mps/operations/Indexing.mm
@@ -59,6 +59,10 @@ bool dispatchIndexSelectKernel(TensorIteratorBase& iter, IntArrayRef index_size,
using namespace mps;

@autoreleasepool {
if (iter.numel() == 0) {
return true;
}

const Tensor& inputTensor = iter.tensor(1);
Tensor outputTensor = iter.tensor(0);

2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -7725,11 +7725,13 @@
- func: multinomial.out(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: multinomial_out
MPS: multinomial_out_mps

- func: multinomial(Tensor self, int num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor
variants: method, function
dispatch:
CPU, CUDA: multinomial
MPS: multinomial_mps

- func: lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
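With these two dispatch entries registered, multinomial on an MPS tensor routes to the new kernels. A hypothetical smoke test (assumes an MPS-enabled build on a Metal-capable Mac):

import torch

probs = torch.tensor([0.1, 0.2, 0.7], device="mps")
samples = torch.multinomial(probs, num_samples=5, replacement=True)
print(samples.cpu())  # e.g. tensor([2, 2, 1, 2, 0])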
20 changes: 20 additions & 0 deletions test/test_mps.py
@@ -4617,6 +4617,26 @@ def helper(shape):
helper(10000)
helper((10000, 40))

def test_multinomial(self):
# Test with num_dist = 1
def helper(probs, compare_mean, compare_var, num_samples=5, replacement=True):
cpu_prob_tensor = torch.tensor(probs, device='cpu', dtype=torch.float, requires_grad=False)
prob_tensor = cpu_prob_tensor.detach().clone().to('mps')

mps_out = torch.multinomial(prob_tensor, num_samples, replacement=replacement)
if not replacement:
print(mps_out.to('cpu'))
else:
# Compare "real" with theoretical values
print(mps_out.to('cpu').float().mean(), compare_mean)
print(mps_out.to('cpu').float().std() ** 2, compare_var)

# TODO: Add tests for data types
helper(np.array([[0., 0., 0., 0.5, 0.5]]), (3 + 4)/2, (12.5 - 3.5 ** 2), 100000)
helper(np.array([[.2, .2, .2, .2, .2]]), (0 + 1 + 2 + 3 + 4)/5, (6 - 2 * 2), 10000)
helper(np.array([[1, 1, 1, 1, 1]]), (0 + 1 + 2 + 3 + 4)/5, (6 - 2 * 2), 10000)
helper(np.array([1, 1, 1, 1, 1]), (0 + 1 + 2 + 3 + 4)/5, (6 - 2 * 2), 10000)
helper(np.array([[1, 1, 1, 1, 1, 1, 1]]), 0, 0, 7, False)
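
The expected moments handed to helper above follow directly from each discrete distribution: mean = sum_i p_i * i and var = sum_i p_i * i^2 - mean^2, so [0., 0., 0., .5, .5] gives 3.5 and 12.5 - 3.5 ** 2. A quick sanity check, separate from the test itself:

import numpy as np

p = np.array([0., 0., 0., .5, .5])
x = np.arange(len(p))
mean = (p * x).sum()                 # 3.5
var = (p * x ** 2).sum() - mean**2   # 12.5 - 3.5**2 = 0.25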

class TestNNMPS(NNTestCase):
