Skip to content

Commit

Permalink
add SetScope API (PaddlePaddle#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
XBWGC authored Aug 4, 2021
1 parent fedeabf commit d02eef1
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
7 changes: 6 additions & 1 deletion paddle/fluid/framework/ipu/ipu_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ class IpuBackend {
return builder_->getTensorShape(tensors_[var_name]);
}

// SetScope, so we can get model parameters from scope
void SetScope(Scope* scope) {
scope_ = scope;
}

static std::shared_ptr<IpuBackend> GetInstance() {
if (NULL == instance_) {
instance_.reset(new IpuBackend());
Expand All @@ -88,13 +93,13 @@ class IpuBackend {
private:
Optimizer optimizer_;
IpuBuildStrategy ipu_build_strategy_;
Scope *scope_ = nullptr;

std::vector<popart::TensorId> inputs_;
std::vector<popart::TensorId> outputs_;
std::map<std::string, popart::TensorId> tensors_;

std::unique_ptr<popart::Builder> builder_;
popart::SessionOptions popart_options_;
std::unique_ptr<popart::Session> session_;

static std::shared_ptr<IpuBackend> instance_;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ipu/ipu_build_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace framework {
namespace ipu {

struct IpuBuildStrategy {
popart::SessionOptions popart_options;
popart::SessionOptions popart_options_;
};

} // namespace ipu
Expand Down
3 changes: 2 additions & 1 deletion paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3206,7 +3206,8 @@ All parameter, weight, gradient are variables in Paddle.
#ifdef PADDLE_WITH_IPU
py::class_<framework::IpuBackend, std::shared_ptr<framework::IpuBackend>>(m,
"IpuBackend")
.def(py::init(&IpuBackend::GetInstance));
.def(py::init(&IpuBackend::GetInstance))
.def("set_scope", &IpuBackend::SetScope);
#endif

BindFleetWrapper(&m);
Expand Down
17 changes: 12 additions & 5 deletions python/paddle/fluid/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,27 +488,35 @@ class IpuCompiler(object):
Args:
program(framework.Program): This argument is the Program being executed.
scope: This argument is the scope which contains model parameters.
ipu_build_strategy: This argument is used to build the program with the
specified options, such as operators' replacement, dtype, etc.
Returns:
framework.Program
"""

def __init__(self, program, ipu_build_strategy=None):
def __init__(self, program, scope=None, ipu_build_strategy=None):
if not isinstance(program, framework.Program):
raise TypeError(
"The type of program is wrong, expected Program, but got %s" %
type(program))
# import here to avoiding confused
import paddle

self._scope = None
if scope is not None:
self._scope = scope
else:
self._scope = paddle.static.global_scope()
self._program = program
self._graph = core.Graph(program.desc)
self._ipu_build_strategy = ipu_build_strategy
self._compiled = False
self._backend = core.IpuBackend()
self._graph_passes = ["optimizer_extract_pass",
"forward_graph_extract_pass"]
self._backend.set_scope(self._scope)
self._graph_passes = [
"optimizer_extract_pass", "forward_graph_extract_pass"
]

def compile(self, feed_list, fetch_list, scope=None):
for pass_name in self._graph_passes:
Expand All @@ -527,4 +535,3 @@ def compile(self, feed_list, fetch_list, scope=None):
program = framework.Program._construct_from_desc(desc)

return program

0 comments on commit d02eef1

Please sign in to comment.