@@ -317,64 +317,4 @@ function diagonal_indices_zero_indexed(m::Integer, n::Integer, k::Integer=0)
317317 return indices
318318end
319319
320- # # This is quite handy to have but is not generalized enough to be put into Ops? Or maybe
321- # # we can document it and place it there under a different name. It takes a list of values
322- # # and a list of indices and constructs a matrix with the values at the indices.
323- function simple_scatter_op (
324- shape, scatter_indices:: TracedRArray{Int64,2} , updates:: TracedRArray{T,1}
325- ) where {T}
326- @assert length (updates) == size (scatter_indices, 1 )
327- @assert size (scatter_indices, 2 ) == 2
328-
329- update_computation = MLIR. IR. Region ()
330- block = MLIR. IR. Block (
331- [mlir_type (TracedRNumber{T}), mlir_type (TracedRNumber{T})],
332- [MLIR. IR. Location (), MLIR. IR. Location ()],
333- )
334- return_op = MLIR. Dialects. stablehlo. return_ ([MLIR. IR. argument (block, 2 )])
335- MLIR. IR. rmfromparent! (return_op)
336- push! (block, return_op)
337- pushfirst! (update_computation, block)
338-
339- init_array = Ops. constant (fill (zero (T), shape)). mlir_data
340-
341- # ! format: off
342- scatter_dimension_numbers = MLIR. API. stablehloScatterDimensionNumbersGet (
343- MLIR. IR. context (),
344- 0 , Int64[],
345- 2 , Int64[0 , 1 ],
346- 0 , Int64[],
347- 0 , Int64[],
348- 2 , Int64[0 , 1 ],
349- 1
350- )
351- # ! format: on
352-
353- res = MLIR. IR. result (
354- MLIR. Dialects. stablehlo. scatter (
355- [init_array],
356- scatter_indices. mlir_data,
357- [updates. mlir_data];
358- result_0= [mlir_type (TracedRArray{T,2 }, shape)],
359- update_computation,
360- scatter_dimension_numbers,
361- ),
362- 1 ,
363- )
364-
365- return TracedRArray {T,2} ((), res, shape)
366- end
367-
368- # # The cartesian version doesn't exist in julia 1.10
369- function diagonal_indices_zero_indexed (m:: Integer , n:: Integer , k:: Integer = 0 )
370- idx1, idx2 = 1 + max (0 , - k), 1 + max (0 , k)
371- L = max (0 , k ≤ 0 ? min (m + k, n) : min (m, n - k))
372- indices = Matrix {Int} (undef, (L, 2 ))
373- for i in axes (indices, 1 )
374- indices[i, 1 ] = idx1 + i - 2
375- indices[i, 2 ] = idx2 + i - 2
376- end
377- return indices
378- end
379-
380320end
0 commit comments