diff --git a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h index 767c785303291..16e3845f27ebb 100644 --- a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h +++ b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h @@ -33,4 +33,60 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) - (MPSGraphTensor * _Nonnull)inverseOfTensor:(MPSGraphTensor * _Nonnull) inputTensor name:(NSString * _Nullable)name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithTensor:(MPSGraphTensor * _Nonnull) imagesTensor + sizeTensor:(MPSGraphTensor * _Nonnull) size + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeNearestWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + nearestRoundingMode:(MPSGraphResizeNearestRoundingMode) nearestRoundingMode + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + centerResult:(BOOL) centerResult + alignCorners:(BOOL) alignCorners + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; + +- (MPSGraphTensor * _Nonnull) resizeBilinearWithGradientTensor:(MPSGraphTensor * _Nonnull) gradient + input:(MPSGraphTensor * _Nonnull) input + scaleOffsetTensor:(MPSGraphTensor * _Nonnull) scaleOffset + layout:(MPSGraphTensorNamedDataLayout) layout + name:(NSString * _Nullable) name; @end diff --git a/aten/src/ATen/native/mps/operations/Distributions.mm b/aten/src/ATen/native/mps/operations/Distributions.mm index ecb528301f2c0..dcaf8baf6c31c 100644 --- a/aten/src/ATen/native/mps/operations/Distributions.mm +++ b/aten/src/ATen/native/mps/operations/Distributions.mm @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/aten/src/ATen/native/mps/operations/Inverse.mm b/aten/src/ATen/native/mps/operations/Inverse.mm index ca028a1a864b8..2975fd9875949 100644 --- a/aten/src/ATen/native/mps/operations/Inverse.mm +++ b/aten/src/ATen/native/mps/operations/Inverse.mm @@ -1,5 +1,6 @@ #include #include +#include #include #include