Commit f535330

Fix hang with cudnn layer norm by moving cudnn init to Initialize()
trevor-m committed May 6, 2024
1 parent d37ae92 commit f535330
Showing 4 changed files with 28 additions and 16 deletions.
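
Why moving the init fixes the hang: building a cuDNN execution plan can block on the device, and if a NCCL collective is already in flight the two can deadlock (see the comment added in norm_thunk.cc below). The fix relies on the thunk lifecycle: the runtime initializes every thunk before any thunk executes, so plan construction done in Initialize() cannot overlap a collective. A minimal sketch of that ordering, with a hypothetical driver loop standing in for XLA's actual executor:

// Hypothetical driver sketch, not XLA's executor: every thunk is initialized
// before any thunk executes, so work done in Initialize() cannot overlap a
// NCCL collective launched during execution.
#include <vector>

struct ThunkSketch {
  virtual ~ThunkSketch() = default;
  virtual void Initialize() {}         // runs up front, once per run
  virtual void ExecuteOnStream() = 0;  // hot path; collectives may be active
};

void RunAll(const std::vector<ThunkSketch*>& thunks) {
  for (ThunkSketch* t : thunks) t->Initialize();       // cuDNN plans built here
  for (ThunkSketch* t : thunks) t->ExecuteOnStream();  // collectives may now run
}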
18 changes: 2 additions & 16 deletions xla/service/gpu/gpu_norm_runner.cc
@@ -44,22 +44,8 @@ absl::Status RunGpuNorm(const gpu::GpuNormConfig& config,
                          se::Stream* stream, RunNormOptions options) {
   se::dnn::LazyOpRunner<se::dnn::NormOp>* lazy_runner =
       options.norm_runner->AsNormRunner();
-  std::optional<se::dnn::LazyOpRunner<se::dnn::NormOp>> local_runner;
-
-  TF_ASSIGN_OR_RETURN(se::dnn::NormKind kind,
-                      GetDNNNormKindFromCudnnNormKind(config.kind));
-
-  se::dnn::NormOp::Config ln_config{kind,
-                                    config.epsilon,
-                                    config.x_descriptor,
-                                    config.scale_descriptor,
-                                    config.y_or_dx_descriptor,
-                                    config.bias_descriptor,
-                                    config.dy_descriptor,
-                                    config.expectation_descriptor,
-                                    config.norm_factor_descriptor,
-                                    config.dscale_descriptor,
-                                    config.dbias_descriptor};
+  TF_ASSIGN_OR_RETURN(se::dnn::NormOp::Config ln_config,
+                      config.AsDnnNormOpConfig());
   TF_ASSIGN_OR_RETURN(auto* runner,
                       lazy_runner->GetOrCreateRunner(ln_config, stream));
 
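The execute path above still calls GetOrCreateRunner(), so the fix depends on that call being memoizing: the first call for a given config builds the runner (and with it the cuDNN execution plan), and later calls return the cached instance. A minimal sketch of that pattern, assuming se::dnn::LazyOpRunner behaves this way; this is illustrative, not StreamExecutor's implementation:

// Illustrative lazy-runner cache: the expensive build happens once, on the
// first GetOrCreateRunner() call. Pre-warming it in Initialize() turns the
// call made during ExecuteOnStream() into a cache hit.
#include <memory>
#include <mutex>

template <typename Config, typename Runner>
class LazyRunnerSketch {
 public:
  Runner* GetOrCreateRunner(const Config& config) {
    std::call_once(once_, [&] {
      runner_ = std::make_unique<Runner>(config);  // expensive: builds the plan
    });
    return runner_.get();
  }

 private:
  std::once_flag once_;
  std::unique_ptr<Runner> runner_;
};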
16 changes: 16 additions & 0 deletions xla/service/gpu/gpu_norm_runner.h
@@ -118,6 +118,22 @@ struct GpuNormConfig {
     return config;
   }
 
+  absl::StatusOr<se::dnn::NormOp::Config> AsDnnNormOpConfig() const {
+    TF_ASSIGN_OR_RETURN(se::dnn::NormKind norm_kind,
+                        GetDNNNormKindFromCudnnNormKind(kind));
+    return se::dnn::NormOp::Config{norm_kind,
+                                   epsilon,
+                                   x_descriptor,
+                                   scale_descriptor,
+                                   y_or_dx_descriptor,
+                                   bias_descriptor,
+                                   dy_descriptor,
+                                   expectation_descriptor,
+                                   norm_factor_descriptor,
+                                   dscale_descriptor,
+                                   dbias_descriptor};
+  }
+
   double epsilon;
   CudnnNormKind kind;
   se::dnn::AlgorithmDesc algorithm;
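Factoring the config construction into AsDnnNormOpConfig() lets the two call sites (RunGpuNorm above and NormThunk::Initialize below) build identical configs, so the runner cached during initialization is the one the execute path looks up; any drift between the two would miss the cache and rebuild the plan on the hot path. Both call sites reduce to essentially the same two steps, shown here as they appear in gpu_norm_runner.cc with surrounding code elided:

TF_ASSIGN_OR_RETURN(se::dnn::NormOp::Config ln_config,
                    config.AsDnnNormOpConfig());
TF_ASSIGN_OR_RETURN(auto* runner,
                    lazy_runner->GetOrCreateRunner(ln_config, stream));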
9 changes: 9 additions & 0 deletions xla/service/gpu/runtime/norm_thunk.cc
@@ -106,5 +106,14 @@ absl::Status NormThunk::ExecuteOnStream(const ExecuteParams& params) {
   return absl::OkStatus();
 }
 
+absl::Status NormThunk::Initialize(const InitializeParams& params) {
+  // Create the runner at initialization time to avoid hangs if we try to build
+  // the execution plan while a NCCL collective is running.
+  se::dnn::LazyOpRunner<se::dnn::NormOp>* lazy_runner =
+      GetOrCreateRunner(params.stream).AsNormRunner();
+  TF_ASSIGN_OR_RETURN(auto ln_config, config_.AsDnnNormOpConfig());
+  return lazy_runner->GetOrCreateRunner(ln_config, params.stream).status();
+}
+
 } // namespace gpu
 } // namespace xla
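
Note that Initialize() deliberately discards the runner it just built: GetOrCreateRunner() returns a StatusOr, and taking .status() keeps only the success or failure signal. The call exists for its side effect, building and caching the cuDNN execution plan, which ExecuteOnStream() later retrieves through the same lookup.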
1 change: 1 addition & 0 deletions xla/service/gpu/runtime/norm_thunk.h
@@ -49,6 +49,7 @@ class NormThunk : public Thunk {
   NormThunk& operator=(const NormThunk&) = delete;
 
   absl::Status ExecuteOnStream(const ExecuteParams& params) override;
+  absl::Status Initialize(const InitializeParams& params) override;
 
  private:
   BufferAllocation::Slice x_buffer_;
