Make operand data type and rank validation table-driven #657

Merged
merged 73 commits into from
Nov 15, 2024


inexorabletash
Member

@inexorabletash inexorabletash commented Apr 26, 2024

Improve readability of input operand data type and rank by introducing a table within each method definition that outlines the restrictions for positional arguments and options. Steps are updated to reference the table columns.


Preview | Diff

@inexorabletash inexorabletash changed the title Data type table Make operand data type and rank validation table-driven Apr 27, 2024
@zolkis
Collaborator

zolkis commented Apr 29, 2024

I'd prefer this style over the one in #646

@inexorabletash
Member Author

inexorabletash commented Apr 30, 2024

A few questions to consider:

dataType

  • For most "secondary" operands, the type is constrained to be the same as the primary operand. The relevant steps could be written as either:
    • "If mean’s dataType is not one of its allowed data types, then throw a TypeError." - more consistent.
    • "If mean’s dataType is not the same as input's dataType, then throw a TypeError." - clearer.

rank

  • In most cases there's only one allowed rank. So do we prefer:
    • "If mean’s rank is not its allowed rank, then throw a TypeError." - more consistent.
    • "If mean’s rank is not 1, then throw a TypeError." - clearer.
  • If we prefer the former (reference the table), what about Simplify, correct, and add validation for GRU/LSTM and friends #659 which proposes simplifying some validation e.g. "If mean’s shape is not equal to « input’s shape[options.axis] »..." — this implicitly validates the rank. If we had that, would we still want a step reading "If mean’s rank is not ..." ?
  • If any rank is allowed, do we still want to include the step, for consistency?

General

  • Should we add a "validate the dataType and rank of an operand" helper algorithm? The steps for each operand could then be collapsed to just: "If validating the dataType and rank of mean returns false, then throw a TypeError." - less text, but more is hidden behind the algorithm.
  • When there are multiple input operands, do we prefer separate steps or combined steps? At an extreme we could even write: "If validating the dataType and rank of any of mean, variance, options.scale (if it exists) or options.bias (if it exists) returns false, then throw a TypeError."
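
One way to picture the helper discussed above is the following minimal Python sketch. It is not the specification's algorithm: the names `Operand`, `Constraint`, `TABLE`, `validate_operand` and `validate_all` are hypothetical, and the table values (float types, rank 1 for the secondary operands) are illustrative assumptions rather than values copied from the spec.

```python
from dataclasses import dataclass
from typing import Dict, FrozenSet, Optional, Tuple


@dataclass(frozen=True)
class Operand:
    data_type: str
    shape: Tuple[int, ...]

    @property
    def rank(self) -> int:
        return len(self.shape)


@dataclass(frozen=True)
class Constraint:
    # None stands for "any" data type / "N" (any rank) in the table.
    allowed_data_types: Optional[FrozenSet[str]]
    allowed_rank: Optional[int]


# Hypothetical table for a batchNormalization-like op; values are
# illustrative assumptions, not taken from the specification.
FLOATS = frozenset({"float32", "float16"})
TABLE: Dict[str, Constraint] = {
    "input": Constraint(FLOATS, None),
    "mean": Constraint(FLOATS, 1),
    "variance": Constraint(FLOATS, 1),
    "scale": Constraint(FLOATS, 1),
    "bias": Constraint(FLOATS, 1),
}


def validate_operand(table: Dict[str, Constraint], name: str, operand: Operand) -> bool:
    """Per-operand helper: look up the row for `name` and check both columns."""
    row = table[name]
    if row.allowed_data_types is not None and operand.data_type not in row.allowed_data_types:
        return False
    if row.allowed_rank is not None and operand.rank != row.allowed_rank:
        return False
    return True


def validate_all(table: Dict[str, Constraint], operands: Dict[str, Optional[Operand]]) -> None:
    """Combined step: raise TypeError if any operand that exists fails validation."""
    for name, operand in operands.items():
        if operand is None:  # optional operand that does not exist
            continue
        if not validate_operand(table, name, operand):
            raise TypeError(f"{name} has a disallowed dataType or rank")
```

The separate-steps style corresponds to calling `validate_operand` once per operand, while the combined style corresponds to a single `validate_all` call. The trade-off is the one raised in the question: the combined form is shorter, but the reader has to follow the helper to see what is checked.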

@zolkis
Collaborator

zolkis commented May 2, 2024

IMHO the "consistent" option is clear enough in this case. :)

- gemm(): Fix ranks in table, align phrasing.
- gru(): Align phrasing.
- lstm(): Don't inline rank of 3, reference table.
- matmul(): Align phrasing.
- prelu(): Fix punctuation.
- triangular(): Add table, use for rank validation.
- where(): Align phrasing.
@inexorabletash
Member Author

Other notes:

  • For the op groups (element-wise binary/unary/logical, pooling, reduction) a table is not given and the inline steps are retained. Any better suggestions?
  • Do we keep a rank validation step if there's a subsequent shape validation step? i.e. keep both of these lines:
    1. If mean’s rank is not its allowed rank, then throw a TypeError.
    2. If mean’s shape is not equal to « input’s shape[options.axis] », then throw a TypeError.
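
To make the relationship between these two steps concrete, here is a small Python sketch. The function names `rank_matches` and `shape_matches` are hypothetical, and shapes are modelled as plain tuples of integers, which is an assumption made only for illustration.

```python
from typing import Sequence


def rank_matches(shape: Sequence[int], allowed_rank: int) -> bool:
    """Step 1: the operand's rank must equal its allowed rank."""
    return len(shape) == allowed_rank


def shape_matches(mean_shape: Sequence[int], input_shape: Sequence[int], axis: int) -> bool:
    """Step 2: mean's shape must equal the one-element list [input_shape[axis]]."""
    return list(mean_shape) == [input_shape[axis]]
```

Because the right-hand side of step 2 is always a list with exactly one element, any shape that passes `shape_matches` necessarily has rank 1. Under these assumptions, step 1 never rejects anything that step 2 accepts; it would only change which check fails first.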

@zolkis
Collaborator

zolkis commented May 3, 2024

For the op groups (element-wise binary/unary/logical, pooling, reduction) a table is not given and the inline steps are retained.

Looking at the preview, I think that makes sense.

Do we keep a rank validation step if there's a subsequent shape validation step?

Shape should include rank... I'd leave rank out and let impl optionally handle that as a quick check, if it makes sense somewhere.

Contributor

@huningxin huningxin left a comment


Thanks @inexorabletash !

It seems some operators are not covered, such as argMin/Max. Do you plan to do them in a separate CL?

@inexorabletash
Member Author

It seems some operators are not covered, such as argMin/Max. Do you plan to do them in a separate CL?

My approach was to very mechanically migrate explicit rank/shape validation steps from the prose steps that contain the constraints to the table format where the prose steps reference the table. That has these implications:

  • If an op didn't have either explicit data type or rank validation steps, the table wasn't added since nothing would reference it. So e.g. argMin/Max didn't get a table.
  • If an op's data type and/or rank was provided as parameters to the algorithm steps, the steps were not touched, and the table wasn't added for that op.
  • If there wasn't an explicit step validating the data type (e.g. gather), the data type in the table is given as: any data type/any
  • If there wasn't an explicit step validating the rank (e.g. gemm), or the step validates the rank against something other than an explicit number (e.g. layerNormalization), the rank in the table is given as: any rank/N
    • Note that steps that validate the shape (e.g. where it must be broadcastable to something) are considered to not validate the rank.

In other words, only introduce the table where it will be normatively referenced.

Obviously we can change this!

  • We can include the table for all ops, which would have "any" / "N" for the allowed data types/ranks
    • argMin/argMax, cast, clamp, concat, element-wise binary, element-wise logical, expand, pad, reshape, slice, split, transpose
    • 🤔 Do we also add steps to the algorithm that reference the table, even though it's a no-op?
  • We can add in tables for the cases where data type/rank are passed in to the algorithm
    • element-wise unary, averagePool2d, l2Pool2d, maxPool2d, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProduct, reduceSum, reduceSumSquare
    • 🤔 This implies we'd have a table per op (e.g. one for abs, one for ceil, one for cos, etc), and modify the algorithms to pass/take both data types and ranks. That's... a lot of tables.
  • We can expand the definition of "allowed ranks" beyond "any" or an explicit number to include:
    • a reference to other parameters, e.g. "same as input"
    • a range
    • 🤔 Do we introduce the implied duplicate steps? e.g. for layerNormalization, do we have steps that validate rank against the table (i.e. scale is in [0, input rank]) and a specific value (i.e. scale is equal to options.axes's size) ?
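
As a rough illustration of what an expanded "allowed ranks" column could express, the following Python sketch models each kind of table entry as a small value type. All class and function names are hypothetical, and the `UpToRankOf` entry is only one possible reading of a constraint such as "scale is in [0, input rank]".

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass(frozen=True)
class AnyRank:
    """'N' in the table: every rank is allowed."""


@dataclass(frozen=True)
class ExactRank:
    value: int


@dataclass(frozen=True)
class RankRange:
    low: int
    high: int  # inclusive


@dataclass(frozen=True)
class SameRankAs:
    operand: str  # e.g. "same as input"


@dataclass(frozen=True)
class UpToRankOf:
    operand: str  # allowed ranks are 0 .. rank of the named operand


RankSpec = Union[AnyRank, ExactRank, RankRange, SameRankAs, UpToRankOf]


def rank_allowed(spec: RankSpec, rank: int, ranks: Dict[str, int]) -> bool:
    """Check `rank` against a table entry; `ranks` maps operand names to their ranks."""
    if isinstance(spec, AnyRank):
        return True
    if isinstance(spec, ExactRank):
        return rank == spec.value
    if isinstance(spec, RankRange):
        return spec.low <= rank <= spec.high
    if isinstance(spec, SameRankAs):
        return rank == ranks[spec.operand]
    if isinstance(spec, UpToRankOf):
        return 0 <= rank <= ranks[spec.operand]
    raise ValueError(f"unknown rank spec: {spec!r}")
```

With entries like these, a table check for layerNormalization's scale would only establish an upper bound; a later step comparing against the size of options.axes would still be needed to pin down the exact value, which is the duplication the question above is about.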

I am not sure whether it should point to the allowed data types table of batchNorm. It currently points to the allowed data types definition.

Yeah, I wasn't sure how much effort to put into explicitly linking things. Currently nothing actually links to the tables; it's implied that if a step mentions "allowed data types" or "allowed ranks" then the reader should find the nearby table and look up the appropriate value for the named input operand. It might be better to link those phrases in steps to the table cells.

@huningxin
Contributor

@inexorabletash

In other words, only introduce the table where it will be normatively referenced.

Makes sense to me! Thanks for the explanation!

We can include the table for all ops, which would have "any" / "N" for the allowed data types/ranks

  • argMin/argMax, cast, clamp, concat, element-wise binary, element-wise logical, expand, pad, reshape, slice, split, transpose

logicalNot may need steps to validate input data type is "uint8"? (and logicalAnd, logicalOr, logicalXor in @fdwr 's wave3 proposal)

Wave3 will also introduce the "int4"/"uint4" data types, so I assume "any" data type may not apply to some ops; we may need "anyDataTypesAtLeast8Bits" (kAllDataTypesAtLeast8bits in the Chromium prototype).

Of course we can handle the wave3 ops / data types after its PR lands.
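
A named set such as "anyDataTypesAtLeast8Bits" could be modelled as a plain set difference. The sketch below is in Python; the constant names are hypothetical, and the list of data types is an assumption about what the full set would look like once "int4"/"uint4" are added.

```python
# Assumed full list of operand data types, including the proposed 4-bit types.
ALL_DATA_TYPES = frozenset({
    "float32", "float16",
    "int32", "uint32",
    "int64", "uint64",
    "int8", "uint8",
    "int4", "uint4",
})

# The proposed sub-byte types that some ops may not accept.
SUB_BYTE_DATA_TYPES = frozenset({"int4", "uint4"})

# Hypothetical counterpart of "anyDataTypesAtLeast8Bits".
DATA_TYPES_AT_LEAST_8_BITS = ALL_DATA_TYPES - SUB_BYTE_DATA_TYPES


def is_allowed(data_type: str, allowed: frozenset) -> bool:
    """Membership test used by a table-driven data type check."""
    return data_type in allowed
```

An op whose table entry is the narrower set would accept "uint8" but reject "int4", while an op whose entry is the full set would accept both.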

  • 🤔 Do we also add steps to the algorithm that reference the table, even though it's a no-op?

I think it's unnecessary to add no-op steps.

We can add in tables for the cases where data type/rank are passed in to the algorithm

  • element-wise unary, averagePool2d, l2Pool2d, maxPool2d, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProduct, reduceSum, reduceSumSquare
  • 🤔 This implies we'd have a table per op (e.g. one for abs, one for ceil, one for cos, etc), and modify the algorithms to pass/take both data types and ranks. That's... a lot of tables.

Yes, it's a lot. Can we group some tables? For example, a table for floating-point element-wise unary ops, including ceil, floor, cos, sin, erf, exp, log, reciprocal, tan, etc. However, it may require linking the "allowed data types" to a particular table, which is another discussion.

We can expand the definition of "allowed ranks" beyond "any" or an explicit number to include:

  • a reference to other parameters, e.g. "same as input"
  • a range
  • 🤔 Do we introduce the implied duplicate steps? e.g. for layerNormalization, do we have steps that validate rank against the table (i.e. scale is in [0, input rank]) and a specific value (i.e. scale is equal to options.axes's size) ?

Duplicating steps is not the intention. I just feel a user should not expect a validation failure if the supplied tensor is allowed according to "allowed ranks". We have "same as input" for "allowed data types". Could we have something like "maximum to input rank" for "allowed ranks"?

@inexorabletash
Member Author

It seems some operators are not covered, such as argMin/Max.

f61292d adds tables for all ops (or op categories)

@inexorabletash
Member Author

  • 🤔 This implies we'd have a table per op (e.g. one for abs, one for ceil, one for cos, etc), and modify the algorithms to pass/take both data types and ranks. That's... a lot of tables.

Yes, it's a lot. Can we group some tables? For example, a table for floating-point element-wise unary ops, including ceil, floor, cos, sin, erf, exp, log, reciprocal, tan, etc. However, it may require linking the "allowed data types" to a particular table, which is another discussion.

For now, in f61292d I left it at one table per "op category" and introduced the phrase "specified as part of operation steps" in the table.

@inexorabletash
Member Author

I just feel a user should not expect a validation failure if the supplied tensor is allowed according to "allowed ranks".

Makes sense. Added in 4bf14f6 for the cases you called out - I didn't audit the ops for more, though. How does that look?

@inexorabletash
Member Author

As always, having the table contain "specified as part of operation steps" is an intentionally minimal change. There are a few more things we could do:

  • Have separate tables; e.g. for element-wise logical we could have a generic table with "any" and a table specific to logicalNot; element-wise unary would need 3 (signed types, float types, and any for identity); pooling would need 2; reduction ops would need 3
  • Keep a single table, but expand the text with casual developer-focused commentary, e.g. for element-wise logical it could read: "specified as part of operation steps; most element-wise logical ops support any data type, but logicalNot() only supports uint8"
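
The "single table plus per-op exceptions" idea can be sketched as a category default with overrides. This Python sketch is only an illustration: the names `CATEGORY_DEFAULT`, `PER_OP_OVERRIDES`, `allowed_data_types` and `data_type_allowed` are hypothetical, and the only constraint taken from the discussion is that logicalNot() supports just "uint8".

```python
from typing import Dict, FrozenSet, Optional

# None stands for "any" data type in the table.
CATEGORY_DEFAULT: Optional[FrozenSet[str]] = None

# Per-op exceptions for the element-wise logical category.
PER_OP_OVERRIDES: Dict[str, FrozenSet[str]] = {
    "logicalNot": frozenset({"uint8"}),
}


def allowed_data_types(op_name: str) -> Optional[FrozenSet[str]]:
    """Return the op's own entry if it has one, else the category default."""
    return PER_OP_OVERRIDES.get(op_name, CATEGORY_DEFAULT)


def data_type_allowed(op_name: str, data_type: str) -> bool:
    """True when the data type passes the (possibly overridden) table entry."""
    allowed = allowed_data_types(op_name)
    return allowed is None or data_type in allowed
```

Here `data_type_allowed("equal", "float32")` passes through the category default, while `data_type_allowed("logicalNot", "float32")` is rejected by the override. Separate tables per op would make the same information visible without the lookup indirection, at the cost of more tables.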

Contributor

@huningxin huningxin left a comment


LGTM!

Contributor

@huningxin huningxin left a comment


LGTM!

Collaborator

@fdwr fdwr left a comment


Thanks Josh! Minor comments, else LGTM.

Fix copy/pasta of {{input}} for other params

Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
Collaborator

@fdwr fdwr left a comment


👍

@fdwr fdwr merged commit c237cf1 into webmachinelearning:main Nov 15, 2024
2 checks passed
github-actions bot added a commit that referenced this pull request Nov 15, 2024
SHA: c237cf1
Reason: push, by fdwr

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@inexorabletash inexorabletash deleted the data-type-table branch November 15, 2024 22:36
4 participants