173 | 173 | // Do we need to slice into the src tensor? |
174 | 174 | bool needSlice = false; |
175 | 175 | bool inputNeedSlice = false; |
| 176 | + bool needsCast = false; |
176 | 177 |
|
177 | 178 | for (const auto i : c10::irange(num_input_dims)) { |
178 | 179 | TORCH_CHECK(i == dim || [index_shape[i] intValue] <= [input_shape[i] intValue], "Index dim must not exceed input dim except at gathering axis") |
|
184 | 185 | } |
185 | 186 | TORCH_CHECK(reduce != "mean", "Scatter reduce mean mode not yet supported in MPS") |
186 | 187 |
|
187 | | - bool needsCast = isIntegralType(self.scalar_type(), true) && |
188 | | - (reduce != "set" || self.scalar_type() == ScalarType::Byte); |
| 188 | + MPSDataType src_type = getMPSDataType(src.scalar_type()); |
| 189 | + if (reduce != "set" || self.scalar_type() == ScalarType::Byte) { |
| 190 | + src_type = isFloatingType(src.scalar_type()) ? MPSDataTypeFloat32 : MPSDataTypeInt32; |
| 191 | + needsCast = true; |
| 192 | + } |
| 193 | + |
189 | 194 | string key = func_name + getTensorsStringKey({self, index, src, output}) + ":" + std::to_string(dim) + ":" + std::string(reduce); |
190 | 195 | CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key)); |
191 | 196 | if(!cachedGraph) { |
|
196 | 201 | MPSGraph* mpsGraph = make_mps_graph(); |
197 | 202 | newCachedGraph = new CachedGraph(mpsGraph); |
198 | 203 |
|
199 | | - MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self.scalar_type()), input_shape); |
200 | | - MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(index.scalar_type()), index_shape); |
201 | | - MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(src.scalar_type()), src_shape); |
| 204 | + MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self); |
| 205 | + MPSGraphTensor* indexTensor = mpsGraphRankedPlaceHolder(mpsGraph, index); |
| 206 | + MPSGraphTensor* srcTensor = mpsGraphRankedPlaceHolder(mpsGraph, src); |
202 | 207 |
|
203 | | - MPSGraphTensor* getSrc = srcTensor; |
204 | | - MPSGraphTensor* getInput = inputTensor; |
| 208 | + MPSGraphTensor* outputTensor = nil; |
| 209 | + MPSGraphTensor* castSrcTensor = srcTensor; |
| 210 | + MPSGraphTensor* castInputTensor = inputTensor; |
| 211 | + |
| 212 | + if (needsCast) { |
| 213 | + castSrcTensor = [mpsGraph castTensor:srcTensor toType:src_type name:@"cast"]; |
| 214 | + castInputTensor = [mpsGraph castTensor:inputTensor toType:src_type name:@"cast"]; |
| 215 | + } |
| 216 | + MPSGraphTensor* castIndexTensor = [mpsGraph castTensor:indexTensor toType:MPSDataTypeInt32 name:@"cast"]; |
| 217 | + |
| 218 | + MPSGraphTensor* slicedSrc = castSrcTensor; |
| 219 | + MPSGraphTensor* slicedInput = castInputTensor; |
205 | 220 |
|
206 | 221 | // Use in case input needs to be smaller to get scatter |
207 | | - NSMutableArray<NSNumber*>* scatterInputShape = [NSMutableArray arrayWithArray:input_shape];; |
| 222 | + NSMutableArray<NSNumber*>* scatterInputShape = [NSMutableArray arrayWithArray:input_shape]; |
208 | 223 |
|
209 | | - // Slice into the src tensor IF NEEDED |
| 224 | + // Slice into the src or input tensors IF NEEDED |
210 | 225 | if (needSlice || inputNeedSlice) { |
211 | 226 | NSMutableArray<NSNumber*> *starts = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims]; |
212 | 227 | NSMutableArray<NSNumber*> *strides = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims]; |
213 | 228 | NSMutableArray<NSNumber*> *ends_src = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims]; |
214 | 229 |
|
215 | 230 | for (const auto i : c10::irange(num_input_dims)) { |
216 | | - // All strides are 1 |
217 | 231 | strides[i] = @1; |
218 | | - // All starts are 0 |
219 | 232 | starts[i] = @0; |
220 | 233 | ends_src[i] = index_shape[i]; |
221 | 234 | scatterInputShape[i] = (i != dim) ? index_shape[i] : input_shape[i]; |
222 | 235 | } |
223 | 236 | if (needSlice) { |
224 | | - getSrc = [mpsGraph sliceTensor:srcTensor |
| 237 | + slicedSrc = [mpsGraph sliceTensor:castSrcTensor |
225 | 238 | starts:starts |
226 | 239 | ends:ends_src |
227 | 240 | strides:strides |
228 | 241 | name:nil]; |
229 | 242 | } |
230 | 243 | if (inputNeedSlice) { |
231 | | - getInput = [mpsGraph sliceTensor:inputTensor |
| 244 | + slicedInput = [mpsGraph sliceTensor:castInputTensor |
232 | 245 | starts:starts |
233 | 246 | ends:scatterInputShape |
234 | 247 | strides:strides |
235 | 248 | name:nil]; |
236 | 249 | } |
237 | 250 | } |
238 | | - MPSGraphTensor* outputTensor = nil; |
239 | | - MPSGraphTensor* castSrcTensor = getSrc; |
240 | | - MPSGraphTensor* castInputTensor = getInput; |
241 | | - |
242 | | - if (needsCast) { |
243 | | - castSrcTensor = castMPSTensor(mpsGraph, getSrc, ScalarType::Int); |
244 | | - castInputTensor = castMPSTensor(mpsGraph, getInput, ScalarType::Int); |
245 | | - } |
246 | | - MPSGraphTensor* castIndexTensor = castMPSTensor(mpsGraph, indexTensor, ScalarType::Int); |
247 | | - |
248 | 251 | MPSGraphScatterMode scatter_mode = MPSGraphScatterModeSet; |
249 | 252 |
|
250 | 253 | if(reduce == "sum" || reduce == "add") |
|
258 | 261 |
|
259 | 262 | // Scatter this into the input with set mode |
260 | 263 | MPSGraphTensor* scatterTensor = [mpsGraph scatterAlongAxis: (NSInteger) dim |
261 | | - withDataTensor: castInputTensor |
262 | | - updatesTensor: castSrcTensor |
| 264 | + withDataTensor: slicedInput |
| 265 | + updatesTensor: slicedSrc |
263 | 266 | indicesTensor: castIndexTensor |
264 | 267 | mode: scatter_mode |
265 | 268 | name: nil]; |
|
301 | 304 | withShape:@[@-1] |
302 | 305 | name:nil]; |
303 | 306 |
|
304 | | - outputTensor = [mpsGraph scatterNDWithDataTensor:inputTensor |
| 307 | + outputTensor = [mpsGraph scatterNDWithDataTensor:castInputTensor |
305 | 308 | updatesTensor:flatValuesTensor |
306 | 309 | indicesTensor:scatter_fullIndexTensor |
307 | 310 | batchDimensions:0 |
|
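For reference, below is a minimal standalone MPSGraph sketch of the cast-then-scatter-then-cast-back pattern this change applies to integer inputs. It is not the PyTorch wrapper above: the file name, the helper name scatterAddUInt8, and the placeholder shapes are illustrative assumptions; the only MPSGraph calls used are castTensor:toType:name: and scatterAlongAxis:withDataTensor:updatesTensor:indicesTensor:mode:name: (both visible in the diff) plus placeholderWithShape:dataType:name:.

// scatter_cast_example.mm -- hypothetical, illustrative sketch only
// Build (macOS 12+): clang++ -fobjc-arc -framework Foundation -framework MetalPerformanceShadersGraph scatter_cast_example.mm
#import <Foundation/Foundation.h>
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>

// Mirrors the needsCast path above: the scatter reduction is performed in Int32,
// then the result is cast back to the original (uint8) element type.
static MPSGraphTensor* scatterAddUInt8(MPSGraph* graph,
                                       MPSGraphTensor* data,      // uint8 data tensor
                                       MPSGraphTensor* updates,   // uint8 updates tensor
                                       MPSGraphTensor* indices,   // int64 indices tensor
                                       NSInteger axis) {
  MPSGraphTensor* castData    = [graph castTensor:data    toType:MPSDataTypeInt32 name:@"castData"];
  MPSGraphTensor* castUpdates = [graph castTensor:updates toType:MPSDataTypeInt32 name:@"castUpdates"];
  MPSGraphTensor* castIndices = [graph castTensor:indices toType:MPSDataTypeInt32 name:@"castIndices"];

  MPSGraphTensor* scattered = [graph scatterAlongAxis:axis
                                       withDataTensor:castData
                                        updatesTensor:castUpdates
                                        indicesTensor:castIndices
                                                 mode:MPSGraphScatterModeAdd
                                                 name:nil];

  // Cast the scattered result back to the caller's original dtype.
  return [graph castTensor:scattered toType:MPSDataTypeUInt8 name:@"castBack"];
}

int main(void) {
  @autoreleasepool {
    MPSGraph* graph = [MPSGraph new];
    // Illustrative shapes: scatter a [2, 3] update block into a [4, 3] tensor along axis 0.
    MPSGraphTensor* data    = [graph placeholderWithShape:@[@4, @3] dataType:MPSDataTypeUInt8 name:@"data"];
    MPSGraphTensor* updates = [graph placeholderWithShape:@[@2, @3] dataType:MPSDataTypeUInt8 name:@"updates"];
    MPSGraphTensor* indices = [graph placeholderWithShape:@[@2, @3] dataType:MPSDataTypeInt64 name:@"indices"];
    MPSGraphTensor* output  = scatterAddUInt8(graph, data, updates, indices, 0);
    NSLog(@"scatter graph built; output shape = %@", output.shape);
  }
  return 0;
}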