Skip to content

Commit

Permalink
Merge pull request #3136 from jacquesqiao/refine-context
Browse files Browse the repository at this point in the history
add check in OperatorContext Input/Output
  • Loading branch information
gangliao authored Aug 2, 2017
2 parents 53616fd + 6b2323c commit f70e807
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
6 changes: 4 additions & 2 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr");
auto input_format = GetAttr<std::vector<int>>("input_format");
auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(input_format.at((size_t)offset + 1) <= (int)inputs_.size(),
PADDLE_ENFORCE(input_format.at(static_cast<size_t>(offset) + 1) <=
static_cast<int>(inputs_.size()),
"Input Out Of Range");

return std::vector<std::string>{
Expand All @@ -78,7 +79,8 @@ std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr");
auto output_format = GetAttr<std::vector<int>>("output_format");
auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(output_format.at((size_t)offset + 1) <= (int)outputs_.size(),
PADDLE_ENFORCE(output_format.at(static_cast<size_t>(offset) + 1) <=
static_cast<int>(outputs_.size()),
"Output Out of Range");
return std::vector<std::string>{
outputs_.begin() + output_format.at(offset),
Expand Down
32 changes: 24 additions & 8 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,22 +161,30 @@ class OperatorContext {

template <typename T>
const T* Input(const size_t index) const {
return &(InputVar(index)->Get<T>());
auto var = InputVar(index);
PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index);
return &var->Get<T>();
}

template <typename T>
T* Output(const size_t index) const {
return OutputVar(index)->GetMutable<T>();
auto var = OutputVar(index);
PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index);
return var->GetMutable<T>();
}

template <typename T>
const T* Input(const std::string& name) const {
return &(InputVar(name)->Get<T>());
auto var = InputVar(name);
PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name);
return &var->Get<T>();
}

template <typename T>
T* Output(const std::string& name) const {
return OutputVar(name)->GetMutable<T>();
auto var = OutputVar(name);
PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name);
return var->GetMutable<T>();
}

template <typename T>
Expand All @@ -185,8 +193,12 @@ class OperatorContext {
std::vector<const T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return &scope_.FindVar(name)->Get<T>();
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE(var != nullptr,
"MultiInput(%s:%s) should not be nullptr",
name, sub_name);
return &var->Get<T>();
});
return res;
}
Expand All @@ -197,8 +209,12 @@ class OperatorContext {
std::vector<const T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[this](const std::string& name) {
return scope_.FindVar(name)->GetMutable<T>();
[&](const std::string& sub_name) {
auto var = scope_.FindVar(sub_name);
PADDLE_ENFORCE(var != nullptr,
"MultiOutput(%s:%s) should not be nullptr",
name, sub_name);
return var->GetMutable<T>();
});
return res;
}
Expand Down

0 comments on commit f70e807

Please sign in to comment.