Skip to content

Commit

Permalink
ipu device manage (PaddlePaddle#20)
Browse files Browse the repository at this point in the history
* ipu device manage

Co-authored-by: Han <hanzhao@graphcore.ai>
Co-authored-by: jianghaicheng <haichengj@graphcore.ai>
  • Loading branch information
3 people authored Aug 6, 2021
1 parent ce3844e commit 689feb4
Show file tree
Hide file tree
Showing 20 changed files with 273 additions and 135 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ipu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ set(POPART_CANONICALIZATION_HANDLERS_SRC
cc_library(popart_canonicalization_utils SRCS popart_canonicalization_utils.cc
${POPART_CANONICALIZATION_HANDLERS_SRC} DEPS framework_proto enforce)

cc_library(ipu_device SRCS device.cc DEPS enforce popart)
cc_library(ipu_utils SRCS ipu_utils.cc DEPS memory framework_proto popart)
cc_library(ipu_build_strategy SRCS ipu_build_strategy.cc DEPS popart graph framework_proto enforce)
cc_library(ipu_backend SRCS ipu_backend.cc DEPS popart graph framework_proto enforce ipu_utils ipu_build_strategy)
cc_library(ipu_backend SRCS ipu_backend.cc DEPS popart graph framework_proto enforce ipu_utils ipu_build_strategy ipu_device graph_helper)
48 changes: 48 additions & 0 deletions paddle/fluid/framework/ipu/device.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ipu/device.h"

namespace paddle {
namespace framework {
namespace ipu {

// Translates popart's device description into the framework-local Device
// representation: copies the device id and attach state, then maps
// popart::DeviceType onto the framework's DeviceType enum.
// Throws InvalidArgument for a popart device type we do not support.
Device::Device(const popart::DeviceInfo& device_info)
    : id_(device_info.getId()), is_attached_(device_info.isAttached()) {
  popart::DeviceType popart_device_type = device_info.getType();
  switch (popart_device_type) {
    case popart::DeviceType::IpuModel:
      device_type_ = DeviceType::IpuModel;
      break;
    case popart::DeviceType::Cpu:
      device_type_ = DeviceType::Cpu;
      break;
    case popart::DeviceType::Ipu:
      device_type_ = DeviceType::Ipu;
      break;
    case popart::DeviceType::OfflineIpu:
      device_type_ = DeviceType::OfflineIpu;
      break;
    case popart::DeviceType::Sim:
      device_type_ = DeviceType::Sim;
      break;
    default:
      // Cast explicitly: an enum class is not a valid argument for %d.
      PADDLE_THROW(platform::errors::InvalidArgument(
          "popart::DeviceType:Unsupported type %d",
          static_cast<int>(popart_device_type)));
  }
}

} // namespace ipu
} // namespace framework
} // namespace paddle
44 changes: 44 additions & 0 deletions paddle/fluid/framework/ipu/device.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <popart/devicemanager.hpp>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ipu {

// Framework-local mirror of popart::DeviceType (see Device ctor mapping).
enum class DeviceType { IpuModel = 0, Cpu, Ipu, OfflineIpu, Sim };

// Lightweight, copyable description of an IPU device as reported by popart.
class Device {
 public:
  // A default-constructed Device represents "no device": invalid id,
  // not attached. (Previously the members were left uninitialized, so the
  // getters returned indeterminate values.)
  Device() = default;
  // Builds a Device from popart's DeviceInfo; throws InvalidArgument on an
  // unknown popart device type.
  explicit Device(const popart::DeviceInfo& device_info);

  int getId() const { return id_; }
  bool isAttached() const { return is_attached_; }
  DeviceType getType() const { return device_type_; }

 private:
  int id_ = -1;  // -1 marks "no device id assigned"
  bool is_attached_ = false;
  DeviceType device_type_ = DeviceType::Cpu;  // benign default for "no device"
  /* TODO:: Add more elements in the future */
};

} // namespace ipu
} // namespace framework
} // namespace paddle
88 changes: 79 additions & 9 deletions paddle/fluid/framework/ipu/ipu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,16 @@ limitations under the License. */
#include <popart/tensor.hpp>
#include <popart/tensorinfo.hpp>

#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ipu/ipu_utils.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ipu {

std::shared_ptr<IpuBackend> IpuBackend::instance_ = nullptr;

Expand Down Expand Up @@ -127,11 +126,10 @@ void IpuBackend::Prepare() {
}
auto dataFlow = popart::DataFlow(1, anchor_ids);

std::map<std::string, std::string> deviceOpts{{"numIPUs", "1"}};
auto ipuModelDevice =
popart::DeviceManager::createDeviceManager().createIpuModelDevice(
deviceOpts);
// or acquireAvailableDevice();
PADDLE_ENFORCE_NOT_NULL(
curr_device_,
platform::errors::Unavailable("IPU device isn't attached, please call "
"IpuBackend::AttachDevice(id) first."));

if (ipu_build_strategy_ != nullptr && ipu_build_strategy_->is_training_) {
VLOG(1) << "Creating TrainingSession from Onnx Model...";
Expand All @@ -142,11 +140,11 @@ void IpuBackend::Prepare() {
paddle::platform::errors::InvalidArgument(
"loss_id = %s doesn't exist in popart graph.", optimizer_.loss_));
session_ = popart::TrainingSession::createFromOnnxModel(
proto, dataFlow, it->second, *popart_optimizer, ipuModelDevice);
proto, dataFlow, it->second, *popart_optimizer, curr_device_);
} else {
VLOG(1) << "Creating InferenceSession from Onnx Model...";
session_ = popart::InferenceSession::createFromOnnxModel(proto, dataFlow,
ipuModelDevice);
curr_device_);
}
VLOG(1) << "Creating session from Onnx Model...done";

Expand Down Expand Up @@ -308,5 +306,77 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
}
}

// Returns the number of IPU devices visible to this process.
// When the POPLAR_IPUMODEL env var is truthy the poplar simulator is used,
// which always exposes exactly one (virtual) device; otherwise popart's
// device manager enumerates the real hardware.
// Throws Unavailable when no device is found.
size_t IpuBackend::GetNumDevices() {
  // IpuModel simulator: always a single virtual device.
  bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
  if (ipu_model) return 1;
  // Real devices.
  size_t num_devices =
      popart::DeviceManager::createDeviceManager().enumerateDevices().size();
  PADDLE_ENFORCE_GT(
      num_devices, 0,
      platform::errors::Unavailable(
          "Did not find any IPU devices, please make "
          "sure the Poplar sdk is enabled or enable ENV \"POPLAR_IPUMODEL=1\""));
  return num_devices;
}

// Returns the ids of all visible IPU devices. Under the POPLAR_IPUMODEL
// simulator there is exactly one virtual device with id 0.
// Throws Unavailable when no real device is found.
std::vector<int> IpuBackend::GetDeviceIds() {
  bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
  if (ipu_model) {
    return {0};
  }
  auto devices =
      popart::DeviceManager::createDeviceManager().enumerateDevices();
  PADDLE_ENFORCE_GT(
      devices.size(), 0,
      platform::errors::Unavailable("Did not find any IPU devices, please "
                                    "make sure the Poplar sdk is enabled."));

  std::vector<int> device_ids;
  device_ids.reserve(devices.size());
  // const-ref: iterating by value would copy a shared_ptr<DeviceInfo> per
  // element (atomic refcount inc/dec for nothing).
  for (const auto& device : devices) {
    device_ids.push_back(device->getId());
  }

  return device_ids;
}

// Returns a Device description for device `id`.
// Under POPLAR_IPUMODEL this creates (and caches in curr_device_) a
// single-IPU simulator device, ignoring `id`; otherwise it queries popart's
// device manager without attaching.
// Throws InvalidArgument when `id` is out of range.
Device IpuBackend::GetDevice(int id) {
  bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
  if (ipu_model) {
    // Fixed: the option value used to be "1 " (trailing space).
    std::map<std::string, std::string> deviceOpts{{"numIPUs", "1"}};
    curr_device_ =
        popart::DeviceManager::createDeviceManager().createIpuModelDevice(
            deviceOpts);
    Device device(*curr_device_);
    return device;
  }
  size_t num_devices = GetNumDevices();
  // Cast before comparing: `id` is signed, `num_devices` is size_t.
  if (id < 0 || static_cast<size_t>(id) >= num_devices) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "device id %d is invalid, number devices is %d", id, num_devices));
  }
  std::shared_ptr<popart::DeviceInfo> popart_device_info =
      popart::DeviceManager::createDeviceManager().getDevice(
          popart::SyncPattern::Full, id);
  Device device(*popart_device_info);
  return device;
}

// Acquires exclusive access to the IPU with the given id and stores it in
// curr_device_. A no-op under the POPLAR_IPUMODEL simulator, where there is
// no physical device to attach. Throws Unavailable if acquisition fails.
void IpuBackend::AttachDevice(int id) {
  if (GetBoolEnv("POPLAR_IPUMODEL")) {
    return;
  }
  curr_device_ =
      popart::DeviceManager::createDeviceManager().acquireDeviceById(id);
  PADDLE_ENFORCE_NOT_NULL(
      curr_device_,
      platform::errors::Unavailable("Can't attach IPU device id = %d.", id));
}

// True once AttachDevice() (or the IpuModel path of GetDevice()) has bound a
// popart device to this backend.
bool IpuBackend::DeviceIsAttached() { return static_cast<bool>(curr_device_); }

} // namespace ipu
} // namespace framework
} // namespace paddle
11 changes: 9 additions & 2 deletions paddle/fluid/framework/ipu/ipu_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@ limitations under the License. */

#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ipu/ipu_build_strategy.h"
#include "paddle/fluid/framework/ipu/device.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ipu {

using ipu::IpuBuildStrategy;

struct Optimizer {
std::string type_;
std::string loss_;
Expand Down Expand Up @@ -89,6 +89,11 @@ class IpuBackend {
void SetIpuBuildStrategy(const IpuBuildStrategy &strategy) {
ipu_build_strategy_ = &strategy;
}
size_t GetNumDevices();
std::vector<int> GetDeviceIds();
Device GetDevice(int id);
void AttachDevice(int id);
bool DeviceIsAttached();

static std::shared_ptr<IpuBackend> GetInstance() {
if (NULL == instance_) {
Expand All @@ -115,9 +120,11 @@ class IpuBackend {

std::unique_ptr<popart::Builder> builder_;
std::unique_ptr<popart::Session> session_;
std::shared_ptr<popart::DeviceInfo> curr_device_;

static std::shared_ptr<IpuBackend> instance_;
};

} // namespace ipu
} // namespace framework
} // namespace paddle
18 changes: 16 additions & 2 deletions paddle/fluid/framework/ipu/ipu_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

namespace paddle {
namespace framework {
namespace ipu {

popart::DataType VarType2PopartType(proto::VarType::Type type) {
switch (type) {
Expand Down Expand Up @@ -46,6 +47,19 @@ popart::DataType VarType2PopartType(proto::VarType::Type type) {
"Unsupported Paddle var type."));
}
}

// Interprets the environment variable named `str` as a boolean.
// Returns true iff the variable is set to "1", "true", "True" or "TRUE";
// an unset variable or any other value yields false.
// (Removed a leftover, unrelated comment; NULL -> nullptr; replaced the
// strcmp chain with std::string comparisons.)
bool GetBoolEnv(std::string str) {
  const char *str_val = getenv(str.c_str());
  if (str_val == nullptr) {
    return false;
  }
  const std::string val(str_val);
  return val == "1" || val == "true" || val == "True" || val == "TRUE";
}
} // namespace ipu
} // namespace framework
} // namespace paddle
} // namespace paddle
9 changes: 6 additions & 3 deletions paddle/fluid/framework/ipu/ipu_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@ limitations under the License. */

namespace paddle {
namespace framework {
namespace ipu {

popart::DataType VarType2PopartType(proto::VarType::Type type);
bool GetBoolEnv(std::string str);

template <typename T>
std::unique_ptr<popart::NDArrayWrapper<T>> Tensor2IArray(Tensor &tensor) {
std::unique_ptr<popart::NDArrayWrapper<T>> Tensor2IArray(const Tensor &tensor) {
auto dtype = VarType2PopartType(tensor.type());
auto shape = std::vector<int64_t>();
for (size_t i = 0; i < tensor.dims().size(); ++i) {
Expand All @@ -42,7 +44,7 @@ std::unique_ptr<popart::NDArrayWrapper<T>> Tensor2IArray(Tensor &tensor) {

template <typename T>
std::unique_ptr<popart::NDArrayWrapper<T>> LoDTensor2IArray(
LoDTensor &lod_tensor) {
LoDTensor const &lod_tensor) {
if (lod_tensor.lod().size() == 0) {
return Tensor2IArray<T>(lod_tensor);
} else {
Expand All @@ -51,5 +53,6 @@ std::unique_ptr<popart::NDArrayWrapper<T>> LoDTensor2IArray(
}
}

} // namespace ipu
} // namespace framework
} // namespace paddle
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ void IpuGraphBuilderPass::ApplyImpl(ir::Graph* graph) const {
std::vector<std::string> fetch_list;
fetch_list = Get<std::vector<std::string>>("fetch_list");

std::shared_ptr<IpuBackend> ipu_backend = IpuBackend::GetInstance();
std::shared_ptr<ipu::IpuBackend> ipu_backend = ipu::IpuBackend::GetInstance();

ipu_backend->Compile(graph, feed_list, fetch_list);

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << "Raw Graph: ";
VLOG(10) << DebugString(graph);

auto ipu_backend = paddle::framework::IpuBackend::GetInstance();
auto ipu_backend = paddle::framework::ipu::IpuBackend::GetInstance();

for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()) {
Expand Down
8 changes: 7 additions & 1 deletion paddle/fluid/operators/ipu_runtime_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,13 @@ class IpuRuntimeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#ifdef PADDLE_WITH_IPU
auto ipu_backend = paddle::framework::IpuBackend::GetInstance();
auto ipu_backend = framework::ipu::IpuBackend::GetInstance();
if (!ipu_backend->DeviceIsAttached()) {
const platform::IPUDeviceContext& ipu_ctx =
reinterpret_cast<const platform::IPUDeviceContext&>(
ctx.device_context());
ipu_backend->AttachDevice(ipu_ctx.DeviceId());
}
VLOG(4) << "IpuRuntime Kernel, begin to run graph";
auto inputs = ctx.MultiInput<framework::Tensor>("FeedList");
auto outputs = ctx.MultiOutput<framework::Tensor>("FetchList");
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/platform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ IF(WITH_GPU)
nv_library(gpu_info SRCS gpu_info.cc DEPS gflags glog enforce monitor dynload_cuda)
ENDIF()
IF(WITH_IPU)
set(IPU_CTX_DEPS popart)
cc_library(ipu_info SRCS ipu_info.cc DEPS popart)
set(IPU_CTX_DEPS ipu_backend)
cc_library(ipu_info SRCS ipu_info.cc DEPS ipu_backend)
ELSE()
set(IPU_CTX_DEPS)
ENDIF(WITH_IPU)
Expand Down
Loading

0 comments on commit 689feb4

Please sign in to comment.