Add 1xtfloat capability to pairwise_matrix distance computations #1493

Draft: wants to merge 4 commits into base branch-23.06

ahendriksen
Contributor

This PR adds the possibility to use 1xtfloat in the pairwise matrix computations of raft::distance.

When 1xtfloat is enabled, the throughput more than triples compared to using 3xtfloat.

Benchmarks below were taken on H100 (unlocked clocks, SXM). The distance computed was the squared L2 expanded distance. Therefore, one core_op corresponds to one fused multiply-add.

| Time | Iterations | 1xtfloat | BW | core_ops/s | k | m | n |
|---|---|---|---|---|---|---|---|
| 0.050 ms | 13697 | true | 249.475G/s | 21.2885T/s | 1024 | 1024 | 1024 |
| 0.087 ms | 8071 | true | 242.288G/s | 24.8103T/s | 2.048k | 1024 | 1024 |
| 0.160 ms | 4392 | true | 235.911G/s | 26.8414T/s | 4.096k | 1024 | 1024 |
| 0.297 ms | 2361 | true | 240.406G/s | 28.9619T/s | 8.192k | 1024 | 1024 |
| 0.265 ms | 2633 | true | 523.272G/s | 64.9491T/s | 1024 | 1024 | 16.384k |
| 0.490 ms | 1430 | true | 428.423G/s | 70.1928T/s | 2.048k | 1024 | 16.384k |
| 0.945 ms | 741 | true | 372.783G/s | 72.7105T/s | 4.096k | 1024 | 16.384k |
| 1.85 ms | 378 | true | 344.673G/s | 74.3042T/s | 8.192k | 1024 | 16.384k |
| 0.132 ms | 5304 | false | 95.2914G/s | 8.13154T/s | 1024 | 1024 | 1024 |
| 0.249 ms | 2808 | false | 84.1437G/s | 8.61631T/s | 2.048k | 1024 | 1024 |
| 0.484 ms | 1445 | false | 77.9194G/s | 8.8655T/s | 4.096k | 1024 | 1024 |
| 0.955 ms | 733 | false | 74.636G/s | 8.99144T/s | 8.192k | 1024 | 1024 |
| 0.816 ms | 857 | false | 169.614G/s | 21.0526T/s | 1024 | 1024 | 16.384k |
| 1.58 ms | 442 | false | 132.367G/s | 21.687T/s | 2.048k | 1024 | 16.384k |
| 3.13 ms | 224 | false | 112.65G/s | 21.9721T/s | 4.096k | 1024 | 16.384k |
| 6.21 ms | 113 | false | 102.705G/s | 22.141T/s | 8.192k | 1024 | 16.384k |
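
For readers who want to sanity-check the core_ops/s column, here is a minimal sketch. It assumes that the benchmark counts one core op per (m, n, k) triple, which agrees with the reported numbers up to rounding of the printed times. The function names are hypothetical and not part of the benchmark code.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Number of core ops for an m x n distance matrix over k features:
// one core op (here: one fused multiply-add) per (row, column, feature) triple.
constexpr double core_ops(std::int64_t m, std::int64_t n, std::int64_t k)
{
  return static_cast<double>(m) * static_cast<double>(n) * static_cast<double>(k);
}

// Throughput in core ops per second, given the kernel time in milliseconds.
constexpr double core_ops_per_second(std::int64_t m,
                                     std::int64_t n,
                                     std::int64_t k,
                                     double time_ms)
{
  return core_ops(m, n, k) / (time_ms * 1e-3);
}
```

For the largest 1xtfloat case (8192 x 1024 x 16384 in 1.85 ms) this gives roughly 74.3 T core ops/s, in line with the reported value.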

The instance with rbf_fin_op caused some headaches:

Because the handling of L2 expanded and unexpanded is unified in
``distance_impl_l2_with_options``, an instance of the CUTLASS distance
kernel for rbf_fin_op was instantiated. For some reason, CUTLASS did not
accept this as a valid argument and produced a very large error message.
I could not get rbf_fin_op into a state acceptable to CUTLASS: I
included a default constructor and put const on every method, but to no
avail.

The current solution is to avoid CUTLASS when a final op other than
raft::identity_op is used.
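
The workaround can be sketched roughly as follows. This is a simplified illustration rather than the code in this PR: `identity_op`, `scale_fin_op`, `backend` and `choose_backend` are hypothetical stand-ins. The point is that the discarded branch of `if constexpr` is never instantiated, so the CUTLASS templates would only ever see the identity final op.

```cpp
#include <cassert>
#include <type_traits>

// Simplified stand-in for raft::identity_op (hypothetical).
struct identity_op {
  template <typename T>
  constexpr T operator()(T x) const { return x; }
};

// Stand-in for a non-trivial final op such as rbf_fin_op (hypothetical).
struct scale_fin_op {
  float gain;
  constexpr float operator()(float x) const { return gain * x; }
};

enum class backend { cutlass, fallback };

template <typename FinOpT>
constexpr backend choose_backend(FinOpT)
{
  if constexpr (std::is_same_v<FinOpT, identity_op>) {
    // Only this branch would instantiate the CUTLASS-based kernel.
    return backend::cutlass;
  } else {
    // Any other final op takes the non-CUTLASS path.
    return backend::fallback;
  }
}
```

With these stand-ins, `choose_backend(identity_op{})` selects the CUTLASS path and `choose_backend(scale_fin_op{2.0f})` selects the fallback.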

Peak T ops/s = 74 T/s (1x tfloat)
Peak T ops/s = 22 T/s (3x tfloat)

This roughly corresponds to (assuming 2 flops / core op):

Peak Tflop/s = 148 Tflop/s (1x tfloat)
Peak Tflop/s =  44 Tflop/s (3x tfloat)
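
As a small worked example of the conversion, assuming one core op is one FMA (one multiply plus one add, so 2 flops), the rounded peaks of 74 and 22 T core ops/s can be converted like this. The helper names are hypothetical.

```cpp
#include <cassert>

// Assumption: one core op is one fused multiply-add, i.e. 2 flops.
constexpr double flops_per_core_op = 2.0;

// Converts tera core ops per second into Tflop/s.
constexpr double to_tflops(double core_tops) { return flops_per_core_op * core_tops; }

// Ratio between two throughput figures.
constexpr double speedup(double fast, double slow) { return fast / slow; }
```

Here `to_tflops(74.0)` gives 148 and `to_tflops(22.0)` gives 44, and `speedup(74.0, 22.0)` is a bit above 3, consistent with the "more than triples" claim.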
@github-actions github-actions bot added the cpp label May 8, 2023
@ahendriksen
Contributor Author

@cjnolet : I have implemented the 1xtfloat distance, but it is not yet exposed in the public API. The distance API is getting a bit unwieldy. I see the following options to expose the 1xtfloat in the API:

  1. Add another overload of raft::distance::distance that takes an {L2, cosine, etc..}_options struct.
  2. Option 1 and also remove many overloads of raft::distance::distance.
  3. Interrogate the NVIDIA_TF32_OVERRIDE environment variable and/or add a flag to raft::resources to enable 1xtfloat (as discussed in [FEA] Support for NVIDIA_TF32_OVERRIDE environment variable + handle #1393)

Do you have any thoughts on this? What is your preference?
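
To make option 3 a bit more concrete, here is a minimal sketch of how the environment variable could be consulted. It assumes the policy "NVIDIA_TF32_OVERRIDE=0 vetoes TF32, anything else leaves the caller's request untouched". The function names are hypothetical and nothing here is existing RAFT API.

```cpp
#include <cassert>
#include <cstdlib>
#include <cstring>

// Pure helper so the policy can be checked without touching the environment.
// Assumption: a value of "0" forbids TF32; unset or any other value has no effect.
inline bool tf32_allowed(const char* override_value, bool requested)
{
  if (override_value != nullptr && std::strcmp(override_value, "0") == 0) { return false; }
  return requested;
}

// Reads the variable once and applies the policy to the caller's request.
inline bool use_1xtfloat(bool requested)
{
  return tf32_allowed(std::getenv("NVIDIA_TF32_OVERRIDE"), requested);
}
```

A flag on raft::resources could feed the `requested` argument, so the two mechanisms in option 3 would not exclude each other.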

@ahendriksen
Contributor Author

@benfred : Related to #852, I have drafted a type to describe the L2 distance options. It describes:

  • Whether to compute the squared or true L2 distance
  • How to compute the L2 distance (expanded/unexpanded, 3xtfloat, 1xtfloat, or depending on environment variables)

The docstrings in the code explain how each option should work. Please let me know:

  • If this is how you envisioned the distance types => if so, I can expand to other distances as well.
  • If you have any comments on the current design.
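
For illustration only, an options type along the lines described above might look like the sketch below. All names and defaults are hypothetical and are not the types drafted in this PR.

```cpp
#include <cassert>

// Hypothetical sketch: how the L2 distance is computed.
enum class l2_compute_mode {
  unexpanded,         // accumulate (x - y)^2 directly
  expanded_3xtfloat,  // expanded form with 3xtfloat
  expanded_1xtfloat,  // expanded form with 1xtfloat
  from_environment    // defer the choice to environment variables
};

// Hypothetical sketch: all L2 options in one struct.
struct l2_options {
  bool squared         = false;  // false: true L2 distance (takes the sqrt)
  l2_compute_mode mode = l2_compute_mode::expanded_3xtfloat;  // assumed default
};

// True when the caller explicitly asked for 1xtfloat.
constexpr bool is_1xtfloat(const l2_options& opts)
{
  return opts.mode == l2_compute_mode::expanded_1xtfloat;
}
```

A caller would then write something like `l2_options{true, l2_compute_mode::expanded_1xtfloat}` to request the squared distance with 1xtfloat.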

Contributor

@tfeher tfeher left a comment

Thanks Allard for this proposal! Indeed it is important to discuss how we enable control over distance computation precision.

We should consider how the new arguments could be propagated to IVF methods.
It seems that ann::index_params already has a metric_arg; expanding that to a struct that has the additional options would work. The parameters would need to be passed to the kmeans clustering called from ivf_flat / ivf_pq build(). The kmeans_base_params would need to be extended with the metric_arg.

Having a global parameter that we set in the resource handle also has its appeal, e.g. if we want to enable 1xtf32 for cuML SVM. Currently (at least for some of the kernels) NVIDIA_TF32_OVERRIDE would work without passing extra args through the call chain.

Comment on lines +126 to +127
// double this number to get the flop/s. For l2 expanded, core_ops/s should
// equal flop/s (modulo the sqrt and subtracting from the norm).

Isn't this FMA/s (i.e. you would still need to multiply core_ops/s by 2 to get flops)?

Suggested change
- // double this number to get the flop/s. For l2 expanded, core_ops/s should
- // equal flop/s (modulo the sqrt and subtracting from the norm).
+ // double this number to get the flop/s. For l2 expanded, 2*core_ops/s should
+ // equal flop/s (ignoring the sqrt and subtracting from the norm).

Comment on lines +481 to +482
// Use if constexpr to prevent instantiation of CUTLASS templates with final
// operations like rbf_fin_op, which are somehow not compatible with

@mdoijade, are you aware of restrictions that we have for epilogue functions for the cutlass kernels?

Contributor

@mdoijade mdoijade May 12, 2023

Yes, AFAIK CUTLASS doesn't work with lambda functions here; in this case rbf_fin_op must have been a lambda.

- bool isRowMajor>
+ bool isRowMajor,
+ /// Whether to use 3xtfloat or 1xtfloat:
+ bool use_1xtfloat>
  struct PairwiseDistanceGemm {
  // This struct is specialized for fp32/3xTF32

Do the same tile sizes work reasonably well for 1xTF32?

Comment on lines +74 to +77
using Operator =
  std::conditional_t<use_1xtfloat,
                     cutlass::arch::OpMultiplyAdd,          // This implies tensorfloat
                     cutlass::arch::OpMultiplyAddFastF32>;  // This implies 3xtfloat

Here we decide what precision to use. IIUC, the rest of the PR is responsible for:

  1. propagating the user input parameter to this point,
  2. updating the dispatch mechanism accordingly,
  3. adding benchmarks.
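
The selection and the dispatch it requires can be sketched without CUTLASS as follows. The tag types below are stand-ins for cutlass::arch::OpMultiplyAdd and cutlass::arch::OpMultiplyAddFastF32, and `gemm_config` and `dispatch_precision` are hypothetical names, so this only illustrates the pattern and is not the PR's code.

```cpp
#include <cassert>
#include <type_traits>

// Stand-in tags so that the sketch compiles without CUTLASS.
struct op_multiply_add {};           // would imply 1xtfloat
struct op_multiply_add_fast_f32 {};  // would imply 3xtfloat

// Compile-time selection of the operator tag from a boolean template parameter.
template <bool use_1xtfloat>
struct gemm_config {
  using Operator =
    std::conditional_t<use_1xtfloat, op_multiply_add, op_multiply_add_fast_f32>;
};

// A runtime flag has to become a compile-time constant at the dispatch boundary.
// The callable receives std::true_type or std::false_type and must return the
// same type in both cases.
template <typename F>
auto dispatch_precision(bool use_1xtfloat, F&& f)
{
  if (use_1xtfloat) { return f(std::true_type{}); }
  return f(std::false_type{});
}
```

A generic lambda passed to `dispatch_precision` can then instantiate `gemm_config<decltype(flag)::value>` for whichever precision the user asked for.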

/**
* @brief Describes how precise and fast distance should be computed.
*/
enum class Compute_options {

We should consider naming / wording that would clearly describe what happens even if we enable half precision input.

@cjnolet cjnolet added improvement Improvement / enhancement to an existing function non-breaking Non-breaking change labels May 16, 2023
@cjnolet
Member

cjnolet commented Jan 10, 2024

@ahendriksen @tfeher are we still planning to make progress on this feature? I'm doing a little housekeeping on the PRs and just want to make sure the PRs we are keeping open are still valid.

@tfeher
Contributor

tfeher commented Jan 10, 2024

The main question here is what the best mechanism is for the user to opt in/out of 1xTF32 computation. @vinaydes is working on the same question related to #1892. Let's wait until that is fixed, and afterwards we shall return to this PR.

Since @ahendriksen is busy with other tasks, we need someone else to continue this. Assigning this to myself for now, we will revisit availability once #1892 is solved.

@tfeher tfeher assigned tfeher and unassigned ahendriksen Jan 10, 2024
Labels
cpp improvement Improvement / enhancement to an existing function non-breaking Non-breaking change
Projects
Status: In Progress