Skip to content

Commit

Permalink
[stream_executor] NFC: Use std::optional instead of a bool + output p…
Browse files Browse the repository at this point in the history
…ointer

Modernize code base for the year 2023!

PiperOrigin-RevId: 578913249
  • Loading branch information
ezhulenev authored and copybara-github committed Nov 2, 2023
1 parent add5f59 commit 2c9dbe4
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 50 deletions.
3 changes: 1 addition & 2 deletions xla/service/gpu/stream_executor_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,7 @@ StatusOr<std::unique_ptr<se::KernelBase>> CreateKernel(
Status ExecuteKernelOnStream(const se::KernelBase& kernel,
absl::Span<const se::DeviceMemoryBase> args,
const LaunchDimensions& dims, se::Stream* stream) {
int shared_mem_bytes = 0;
kernel.metadata().shared_memory_bytes(&shared_mem_bytes);
int shared_mem_bytes = kernel.metadata().shared_memory_bytes().value_or(0);
static constexpr int kKernelArgsLimit = 1024;
std::unique_ptr<se::KernelArgsArrayBase> kernel_args;
// The KernelArgsArray structure requires at a minimum 48 * args.size()
Expand Down
17 changes: 7 additions & 10 deletions xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -444,13 +444,10 @@ void GpuExecutor::VlogOccupancyInfo(const KernelBase& kernel,
VLOG(2) << "Thread dimensions (" << thread_dims.x << ", " << thread_dims.y
<< ", " << thread_dims.z << ")";

int regs_per_thread;
if (!kernel.metadata().registers_per_thread(&regs_per_thread)) {
return;
}
auto regs_per_thread = kernel.metadata().registers_per_thread();
auto smem_per_block = kernel.metadata().shared_memory_bytes();

int smem_per_block;
if (!kernel.metadata().shared_memory_bytes(&smem_per_block)) {
if (!regs_per_thread && !smem_per_block) {
return;
}

Expand All @@ -460,13 +457,13 @@ void GpuExecutor::VlogOccupancyInfo(const KernelBase& kernel,
const GpuKernel* cuda_kernel = AsGpuKernel(&kernel);
CUfunction cufunc = cuda_kernel->AsGpuFunctionHandle();

int blocks_per_sm = CalculateOccupancy(device_description, regs_per_thread,
smem_per_block, thread_dims, cufunc);
int blocks_per_sm = CalculateOccupancy(device_description, *regs_per_thread,
*smem_per_block, thread_dims, cufunc);
VLOG(2) << "Resident blocks per SM is " << blocks_per_sm;

int suggested_threads =
CompareOccupancy(&blocks_per_sm, device_description, regs_per_thread,
smem_per_block, thread_dims, cufunc);
CompareOccupancy(&blocks_per_sm, device_description, *regs_per_thread,
*smem_per_block, thread_dims, cufunc);
if (suggested_threads != 0) {
VLOG(2) << "The cuda occupancy calculator recommends using "
<< suggested_threads
Expand Down
28 changes: 7 additions & 21 deletions xla/stream_executor/kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implementation of the pointer-to-implementation wrapper for the data-parallel
// kernel abstraction. KernelBase just delegates to the internal
// platform-specific implementation instance.

#include "xla/stream_executor/kernel.h"

#include <cstdint>
#include <optional>
#include <string>
#include <utility>

Expand All @@ -31,32 +29,20 @@ limitations under the License.

namespace stream_executor {

bool KernelMetadata::registers_per_thread(int *registers_per_thread) const {
if (has_registers_per_thread_) {
*registers_per_thread = registers_per_thread_;
return true;
}
std::optional<int64_t> KernelMetadata::registers_per_thread() const {
return registers_per_thread_;
}

return false;
std::optional<int64_t> KernelMetadata::shared_memory_bytes() const {
return shared_memory_bytes_;
}

void KernelMetadata::set_registers_per_thread(int registers_per_thread) {
registers_per_thread_ = registers_per_thread;
has_registers_per_thread_ = true;
}

bool KernelMetadata::shared_memory_bytes(int *shared_memory_bytes) const {
if (has_shared_memory_bytes_) {
*shared_memory_bytes = shared_memory_bytes_;
return true;
}

return false;
}

void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) {
shared_memory_bytes_ = shared_memory_bytes;
has_shared_memory_bytes_ = true;
}

KernelBase::KernelBase(KernelBase &&from)
Expand Down
36 changes: 19 additions & 17 deletions xla/stream_executor/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ limitations under the License.
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
Expand All @@ -81,7 +82,6 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/platform/port.h"

namespace stream_executor {

Expand All @@ -94,6 +94,10 @@ namespace internal {
class KernelInterface;
} // namespace internal

//===----------------------------------------------------------------------===//
// Kernel cache config
//===----------------------------------------------------------------------===//

// This enum represents potential configurations of L1/shared memory when
// running a particular kernel. These values represent user preference, and
// the runtime is not required to respect these choices.
Expand All @@ -111,41 +115,39 @@ enum class KernelCacheConfig {
kPreferEqual,
};

//===----------------------------------------------------------------------===//
// Kernel metadata
//===----------------------------------------------------------------------===//

// KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
// registers allocated, shared memory used, etc.
// Not all platforms support reporting of all information, so each accessor
// returns false if the associated field is not populated in the underlying
// platform.
class KernelMetadata {
public:
KernelMetadata()
: has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}
KernelMetadata() = default;

// Returns the number of registers used per thread executing this kernel.
bool registers_per_thread(int *registers_per_thread) const;

// Sets the number of registers used per thread executing this kernel.
void set_registers_per_thread(int registers_per_thread);
std::optional<int64_t> registers_per_thread() const;

// Returns the amount of [static] shared memory used per block executing this
// kernel. Note that dynamic shared memory allocations are not (and can not)
// be reported here (since they're not specified until kernel launch time).
bool shared_memory_bytes(int *shared_memory_bytes) const;
std::optional<int64_t> shared_memory_bytes() const;

// Sets the amount of [static] shared memory used per block executing this
// kernel.
void set_registers_per_thread(int registers_per_thread);
void set_shared_memory_bytes(int shared_memory_bytes);

private:
// Holds the value returned by registers_per_thread above.
bool has_registers_per_thread_;
int registers_per_thread_;

// Holds the value returned by shared_memory_bytes above.
bool has_shared_memory_bytes_;
int64_t shared_memory_bytes_;
std::optional<int64_t> registers_per_thread_;
std::optional<int64_t> shared_memory_bytes_;
};

//===----------------------------------------------------------------------===//
// Kernel
//===----------------------------------------------------------------------===//

// A data-parallel kernel (code entity) for launching via the StreamExecutor,
// analogous to a void* device function pointer. See TypedKernel for the typed
// variant.
Expand Down

0 comments on commit 2c9dbe4

Please sign in to comment.