Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions paddle/fluid/pybind/place.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,14 @@ void BindPlace(pybind11::module &m) { // NOLINT
[](const phi::CustomPlace &self) { return self.GetDeviceType(); })
.def("__repr__", string::to_string<const phi::CustomPlace &>)
.def("__str__", string::to_string<const phi::CustomPlace &>);
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
m.def("is_float16_supported", [](const phi::CustomPlace &place) -> bool {
return phi::DeviceManager::IsFloat16Supported(place);
});
m.def("is_bfloat16_supported", [](const phi::CustomPlace &place) -> bool {
return phi::DeviceManager::IsBFloat16Supported(place);
});
#endif
py::class_<phi::GPUPlace, phi::Place> cudaplace(m, "CUDAPlace", R"DOC(

CUDAPlace is a descriptor of a device.
Expand Down
20 changes: 20 additions & 0 deletions paddle/phi/backends/custom/custom_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,26 @@ class CustomDevice : public DeviceInterface {
return grid_dim_size;
}

bool IsFloat16Supported(size_t dev_id) {
const auto device = &devices_pool[dev_id];
bool supported = false;
if (pimpl_->is_float16_supported) {
pimpl_->is_float16_supported(device, &supported);
}
VLOG(10) << Type() << " is float16 supported: " << supported;
return supported;
}

bool IsBFloat16Supported(size_t dev_id) {
const auto device = &devices_pool[dev_id];
bool supported = false;
if (pimpl_->is_bfloat16_supported) {
pimpl_->is_bfloat16_supported(device, &supported);
}
VLOG(10) << Type() << " is bfloat16 supported: " << false;
return supported;
}

void* InitEigenDevice(const Place& place,
phi::stream::stream_t stream,
phi::Allocator* allocator) override {
Expand Down
12 changes: 11 additions & 1 deletion paddle/phi/backends/device_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ size_t DeviceInterface::GetComputeCapability(size_t dev_id) {
}

DeviceProp& DeviceInterface::GetDeviceProperties(size_t dev_id) {
DeviceProp prop;
static DeviceProp prop;
VLOG(10) << Type() << " get device properties " << 0;
return prop;
}
Expand Down Expand Up @@ -73,6 +73,16 @@ std::array<unsigned int, 3> DeviceInterface::GetMaxGridDimSize(size_t dev_id) {
return {0, 0, 0};
}

bool DeviceInterface::IsFloat16Supported(size_t dev_id) {
VLOG(10) << Type() << " is float16 supported: " << false;
return false;
}

bool DeviceInterface::IsBFloat16Supported(size_t dev_id) {
VLOG(10) << Type() << " is bfloat16 supported: " << false;
return false;
}

void* DeviceInterface::InitEigenDevice(const Place& place,
phi::stream::stream_t stream,
phi::Allocator* allocator) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/backends/device_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ class DeviceInterface { // Driver / Runtime

virtual std::array<unsigned int, 3> GetMaxGridDimSize(size_t dev_id);

virtual bool IsFloat16Supported(size_t dev_id);

virtual bool IsBFloat16Supported(size_t dev_id);

virtual void* InitEigenDevice(const Place& place,
phi::stream::stream_t stream,
phi::Allocator* allocator);
Expand Down
14 changes: 14 additions & 0 deletions paddle/phi/backends/device_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,20 @@ struct C_DeviceInterface {
C_Status (*get_max_grid_dim_size)(const C_Device device,
std::array<unsigned int, 3>* grid_dim_size);

/**
* @brief Is float16 supported
*
* @param[C_Device, bool*] device, supported
*/
C_Status (*is_float16_supported)(const C_Device device, bool* supported);

/**
* @brief Is bfloat16 supported
*
* @param[C_Device, bool*] device, supported
*/
C_Status (*is_bfloat16_supported)(const C_Device device, bool* supported);

/**
* @brief init eigen device
*
Expand Down
14 changes: 14 additions & 0 deletions paddle/phi/backends/device_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,20 @@ std::array<unsigned int, 3> DeviceManager::GetMaxGridDimSize(
return dev_impl->GetMaxGridDimSize(device_id);
}

bool DeviceManager::IsFloat16Supported(const Place& place) {
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
return dev_impl->IsFloat16Supported(device_id);
}

bool DeviceManager::IsBFloat16Supported(const Place& place) {
auto device_type = place.GetDeviceType();
auto device_id = place.GetDeviceId();
auto dev_impl = GetDeviceInterfaceWithType(device_type);
return dev_impl->IsBFloat16Supported(device_id);
}

void* DeviceManager::InitEigenDevice(const Place& place,
phi::stream::stream_t stream,
phi::Allocator* allocator) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/backends/device_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ class DeviceManager {

static std::array<unsigned int, 3> GetMaxGridDimSize(const Place& place);

static bool IsFloat16Supported(const Place& place);

static bool IsBFloat16Supported(const Place& place);

static void* InitEigenDevice(const Place& place,
phi::stream::stream_t stream,
phi::Allocator* allocator);
Expand Down