Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gather(): Address indices validation and other algorithm nits #642

Merged
merged 9 commits into from
Apr 16, 2024
17 changes: 9 additions & 8 deletions index.bs
Original file line number Diff line number Diff line change
Expand Up @@ -2715,40 +2715,41 @@ partial interface MLGraphBuilder {
<div>
**Arguments:**
- *input*: an {{MLOperand}}. The input N-D tensor from which the values are gathered.
- *indices*: an {{MLOperand}}. The indices N-D tensor of the input values to gather. The values must be of type {{MLOperandDataType/"uint32"}} or {{MLOperandDataType/"int64"}} in the range [0, N-1] where N is the size of the input dimension indexed by *options.axis*.
- *indices*: an {{MLOperand}}. The indices N-D tensor of the input values to gather. The values must be of type {{MLOperandDataType/"uint32"}} or {{MLOperandDataType/"int64"}}, and must be in the range -N (inclusive) to N (exclusive) where N is the size of the input dimension indexed by *options.axis*, and a negative index means indexing from the end of the dimension.
- *options*: an optional {{MLGatherOptions}}. The optional parameters of the operation.

**Returns:** an {{MLOperand}}. The output N-D tensor of [=MLOperand/rank=] equal to the [=MLOperand/rank=] of *input* + the [=MLOperand/rank=] of *indices* - 1.
</div>

<div class="note">
The {{MLGraphBuilder/gather(input, indices, options)/indices}} parameter to {{MLGraphBuilder/gather()}} can not be clamped to the allowed range when the graph is built because the inputs are not known until execution. Implementations can introduce {{MLGraphBuilder/clamp()}} in the compiled graph if the required clamping behavior is not provided by the underlying platform. Similarly, if the underlying platform does not support negative indices, the implementation can introduce operations in the compiled graph to transform a negative index from the end of the dimension into a positive index.
</div>

<details open algorithm>
<summary>
The <dfn method for=MLGraphBuilder>gather(|input|, |indices|, |options|)</dfn> method steps are:
</summary>
1. If [=MLGraphBuilder/validating operand=] with [=this=] and any of |input| abd |indices| returns false, then [=exception/throw=] a {{TypeError}}.
1. If [=MLGraphBuilder/validating operand=] with [=this=] and any of |input| and |indices| returns false, then [=exception/throw=] a {{TypeError}}.
1. If |indices|'s [=MLOperand/dataType=] is neither {{MLOperandDataType/"uint32"}} nor {{MLOperandDataType/"int64"}}, then [=exception/throw=] a {{TypeError}}.
1. Let |shapeInput| be |input|'s [=MLOperand/shape=] and |rankInput| be |shapeInput|'s [=MLOperand/rank=].
1. Let |shapeIndices| be |indices|'s [=MLOperand/shape=].
1. Let |axis| be |options|.{{MLGatherOptions/axis}}.
1. Let |axisSize| be |input|'s [=MLOperand/shape=][|axis|]
1. If |axis| is greater than or equal to |rankInput|, then [=exception/throw=] a {{TypeError}}.
1. [=map/For each=] |index| → |value| of |indices|:
1. If |index| is greater than or equal to |axisSize|, then [=exception/throw=] a {{TypeError}}.
1. Let |dimCount| be zero.
1. Let |rankOutput| be zero.
1. Let |shapeOutput| be an empty list.
1. [=map/For each=] |size| → |value| of |shapeInput|:
1. [=list/For each=] |size| of |shapeInput|:
1. If |dimCount| is equal to |axis| then [=iteration/break=].
1. Set |shapeOutput|[|dimCount|] to |size|.
1. Increment |dimCount| by one.
1. Set |rankOutput| to |dimCount|.
1. Let |dimCount| be zero.
1. [=map/For each=] |size| → |value| of |shapeIndices|:
1. [=list/For each=] |size| of |shapeIndices|:
1. Set |shapeOutput|[|rankOutput| + |dimCount|] to |size|.
1. Increment |dimCount| by one.
1. Set |rankOutput| to |rankOutput| + |dimCount|.
1. Let |dimCount| be zero.
1. [=map/For each=] |size| → |value| of |shapeInput|:
1. [=list/For each=] |size| of |shapeInput|:
1. If |dimCount| is less than or equal to |axis| then [=iteration/continue=].
1. Set |shapeOutput|[|rankOutput| + |dimCount| - |axis| - 1] to |size|.
1. Increment |dimCount| by one.
Expand Down