
Welford Scheduling Support #561

Merged
merged 48 commits into 20_12_3_devel from multi_output_scan on Feb 18, 2021

Conversation

@shmsong commented Dec 8, 2020

This PR implements Welford scheduling in the fuser pipeline.

(This PR will introduce non-trivial merge conflicts with #586; I will rebase after #586 has merged.)

  • Welford fusion IR interface
  • Serial Welford kernel generation
  • Block Parallel Welford kernel generation
  • Grid Parallel Welford kernel generation
  • Welford with Rfactor

More scheduling tests and the Welford scheduler will follow in a subsequent PR.

Example math print containing Welford:

```
%kernel_math {
T1[ iS2{i1}, iS3{i3} ] compute_at( T3, 2 )
   = T0[ iS0{i1}, iS1{i3} ]
   * double(1);
T2[ iS4{i1}, rS5{i3} ] compute_at( T3, 2 )(Var), T3[ ithreadIdx.x6{blockDim.x}, rblockIdx.x7{gridDim.x} ] compute_at( T4, 2 )(Avg), T4[ iS8{i1}, rS9{i3} ](Count) = Welford ( T1[ iS2{i1}, iS3{i3} ] compute_at( T3, 2 ) )
}
```

Example kernel containing Welford:

```cpp
__global__ void kernel1(Tensor<float, 2> T0, Tensor<float, 1> T2, Tensor<float, 1> T3, Tensor<int64_t, 1> T4, Tensor<float, 1> kT116, Tensor<float, 1> kT121, Tensor<int64_t, 1> kT126, Tensor<int64_t, 1> kT131) {
  alignas(8) extern __shared__ char array[];
  void* shared_mem = array;
  size_t block_size = blockDim.x * blockDim.y * blockDim.z;
  int64_t* shared_mem_var = static_cast<int64_t*>(shared_mem);
  int64_t* shared_mem_avg = shared_mem_var + block_size;
  int64_t* shared_mem_n = shared_mem_avg + block_size;
  T4[(threadIdx.x * T4.stride[0])] = 0;
  T3[(threadIdx.x * T3.stride[0])] = 0;
  T2[(threadIdx.x * T2.stride[0])] = 0;
  float T1[1];
  T1[0]
    = T0[(threadIdx.x * T0.stride[0]) + (blockIdx.x * T0.stride[1])]
    * 1;
  bool T3_pred;
  // Allocate global tensor kT116
  // Allocate global tensor kT121
  // Allocate global tensor kT126
  // Allocate global tensor kT131
  T3_pred = welford::gridWelford<true, false, false, true, true, true>(
    T2[(threadIdx.x * T2.stride[0])],
    T3[(threadIdx.x * T3.stride[0])],
    T4[(threadIdx.x * T4.stride[0])],
    (float) 0,
    T1[0],
    (int64_t)1,
    &kT116[0],
    &kT121[0],
    &kT126[0],
    kT131,
    reinterpret_cast<float*>(shared_mem_var),
    reinterpret_cast<float*>(shared_mem_avg),
    reinterpret_cast<int64_t*>(shared_mem_n),
    true,
    float(0));
}
```
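For readers unfamiliar with the algorithm: the avg/var/N triple the kernel maintains corresponds to Welford's running-mean update plus the pairwise combine that the block- and grid-parallel variants apply across threads and blocks. A minimal host-side sketch of that math (illustrative only; the struct and function names here are hypothetical, not the fuser's `welford::gridWelford` implementation):

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Running Welford state: mean, sum of squared deviations (M2), and count.
// The "var" buffers in the kernel above hold the un-normalized M2 quantity.
struct WelfordTriple {
  double avg = 0.0;
  double m2 = 0.0;
  int64_t n = 0;
};

// Serial update: fold a single value into the running state.
void welfordUpdate(WelfordTriple& a, double x) {
  a.n += 1;
  double delta = x - a.avg;
  a.avg += delta / a.n;
  a.m2 += delta * (x - a.avg);
}

// Pairwise combine of two partial states. This is the associative operation
// a parallel Welford performs across threads (block) and blocks (grid).
WelfordTriple welfordCombine(const WelfordTriple& a, const WelfordTriple& b) {
  WelfordTriple out;
  out.n = a.n + b.n;
  if (out.n == 0) {
    return out;
  }
  double delta = b.avg - a.avg;
  out.avg = a.avg + delta * b.n / out.n;
  out.m2 = a.m2 + b.m2 + delta * delta * a.n * b.n / out.n;
  return out;
}
```

Combining the partial results of two halves of a data set with `welfordCombine` yields the same triple as a single serial pass, which is what makes the rfactor and cross-block paths in this PR possible.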

@shmsong shmsong changed the title [WIP] Welford Scheduling Support Welford Scheduling Support Jan 8, 2021
@shmsong shmsong requested review from csarofeen and naoyam January 8, 2021 23:52
@@ -918,6 +970,15 @@ generateIndexAndExtentMap(
loops.pop_back();
}

if (tv->definition()->isA<WelfordOp>()) {
Collaborator:

Does this assume WelfordOp is the only expression type with multiple outputs? If so, would it be possible to generalize it so that it could work with any future expressions with multiple outputs?

Author (shmsong):

Yes, WelfordOp is currently the only multi-output use case we have considered so far. This PR tries to support WelfordOp with minimal generalization, but if we encounter other multi-output cases we can generalize.

On the index compute side the implementation is more temporary than architectural, due to the limitation that the loop variable can currently only be mapped to one of the outputs. This part will be refactored after we switch to index compute based on local compute-at and domain maps (@csarofeen). I'd prefer adding multi-output support at that point, if we do decide to generalize.

Collaborator:

I think it's okay to start with something specific to Welford and generalize it later, as long as it is guarded with assertion about the assumption. For example, if something is only meant to work with Welford, then it should be preceded by TORCH_INTERNAL_ASSERT.
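The suggested guard might look like the following minimal sketch. The `Expr`/`WelfordOp` stand-ins and the assert macro stub are hypothetical simplifications; the real code uses the fuser's IR classes and the actual `TORCH_INTERNAL_ASSERT` macro from c10:

```cpp
#include <cassert>
#include <cstddef>

// Stand-in for the real macro, which also formats a diagnostic message.
#define TORCH_INTERNAL_ASSERT(cond, ...) assert(cond)

// Hypothetical, heavily simplified IR node; the real fuser Expr is richer.
struct Expr {
  virtual ~Expr() = default;
  virtual bool isWelford() const { return false; }
  virtual size_t numOutputs() const { return 1; }
};

struct WelfordOp : Expr {
  bool isWelford() const override { return true; }
  size_t numOutputs() const override { return 3; }  // avg, var, N
};

// Indexing logic that currently only handles Welford among multi-output ops:
// guard the assumption so a future multi-output expr fails loudly at the
// assertion instead of silently producing wrong indexing.
void indexMultiOutput(const Expr& e) {
  if (e.numOutputs() > 1) {
    TORCH_INTERNAL_ASSERT(
        e.isWelford(),
        "Multi-output indexing is only implemented for WelfordOp");
    // ... Welford-specific index mapping would go here ...
  }
}
```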

Author (shmsong):

Added TORCH_INTERNAL_ASSERT for the multi-output case. Thanks.

shmsong commented Jan 19, 2021

> I left my comments. None of them are particularly critical, but some cleanup seems possible.

Thanks for the detailed review and helpful suggestions! 👍

@@ -46,6 +46,31 @@ TORCH_CUDA_API TensorView* reductionOp(
TensorView* v1,
bool keep_dim = false);

//! Auxiliary struct holding the result of
//! a single Welford op on a TensorView
struct TORCH_CUDA_API WelfordResult {
Collaborator:

nitpick: use class instead of struct for anything with methods.

@naoyam (Collaborator) left a comment:

LGTM. Gave a couple of minor comments.

@csarofeen (Owner) left a comment:

Looks good to me; just one comment on rfactor I'd like to see addressed, then I can approve.

namespace {

template <typename T>
kir::Allocate* allocGlobalBuffer(
Owner (csarofeen):

Can we use this to simplify the grid reduction code as well? If so, it would make more sense to do that in a follow-up.

Author (shmsong):

Yes, I think so. Thanks for pointing this out. I will put further simplifications in a follow-up.
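The reuse being discussed can be illustrated with a simplified stand-in. The `BufferInfo` record and `allocWelfordBuffers` helper below are hypothetical; the real `allocGlobalBuffer` builds `kir::Allocate` nodes in the kernel IR:

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for an allocation record (the real code produces
// kir::Allocate nodes rather than a plain struct).
struct BufferInfo {
  std::string name;
  size_t bytes;
};

// One templated helper can serve both grid reduction (a single work buffer)
// and grid Welford (avg/var/N work buffers), avoiding duplicated code paths.
template <typename T>
BufferInfo allocGlobalBuffer(const std::string& name, size_t numElements) {
  return BufferInfo{name, numElements * sizeof(T)};
}

// Grid Welford needs three work buffers with a matching element count.
std::vector<BufferInfo> allocWelfordBuffers(size_t numElements) {
  return {allocGlobalBuffer<float>("work_avg", numElements),
          allocGlobalBuffer<float>("work_var", numElements),
          allocGlobalBuffer<int64_t>("work_n", numElements)};
}
```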

@shmsong shmsong requested a review from csarofeen February 18, 2021 18:47
@csarofeen (Owner) left a comment:

LGTM

@shmsong shmsong merged commit 2bcc6a9 into 20_12_3_devel Feb 18, 2021
@csarofeen csarofeen deleted the multi_output_scan branch June 9, 2021 13:40