Add feed and fetch op to ProgramDesc before saving for inference #7636
Conversation
Thanks for the PR, this looks good. I am just not sure if we should merge this, since the current PR addresses 1.1 and the already existing code addresses 1.2 in #7580. So do we have a preference of 1.1 over 1.2?

@sidgoyal78 I think it's fine to merge it for now. We can have both 1.1 and 1.2 supported via a PR in the future.
I have some comments. However, we may change these in future PRs.
```diff
@@ -62,7 +80,7 @@ void InferenceEngine::LoadInferenceModel(
 }

 bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
-  if (var->Persistable()) {
+  if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
```
We should not use the name of the `Variable` to decide whether the var is the input or output of `feed_op` and `fetch_op`, because the name is not fixed, and it is possible to specify other names.
Agree. And we don't need to check `fetch`, because fetch will not be an input to an op. We can get the feed var name from the feed op's input info, as sketched below. Will fix in a future PR.
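For illustration, a minimal sketch of that lookup. The helper name `GetFeedVarName` and the header path are hypothetical, and the `ProgramDesc`/`OpDesc` accessors (`Block(0)`, `AllOps()`, `Type()`, `Input("X")`) are assumed to match the framework API of that time:

```cpp
#include <string>

#include "paddle/framework/program_desc.h"  // assumed header path

// Hypothetical helper (not part of this PR): look up the real feed var
// name from the feed op itself instead of hard-coding the string "feed".
std::string GetFeedVarName(paddle::framework::ProgramDesc* program) {
  for (auto* op : program->Block(0).AllOps()) {
    if (op->Type() == "feed") {
      // The feed op's input slot "X" is assumed to carry the feed
      // variable's name; its output "Out" is the variable being fed.
      return op->Input("X")[0];
    }
  }
  return "";  // no feed op found
}
```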
```cpp
    }
  }
}

void InferenceEngine::LoadInferenceModel(
```
I think we can remove this function. If there are no `feed_op` and `fetch_op` in the `ProgramDesc`, users can specify these when calling `Run()`.
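As a rough illustration of that idea, a hypothetical check that `Run()` could use to decide whether feed/fetch ops still need to be inserted. Neither the helper name `HasFeedFetchOps` nor this logic appears in the PR; the accessors are assumed as above:

```cpp
// Hypothetical helper: detect whether the program already carries
// feed/fetch ops, so Run() can decide whether to add them on the fly.
bool HasFeedFetchOps(paddle::framework::ProgramDesc* program) {
  bool has_feed = false;
  bool has_fetch = false;
  for (auto* op : program->Block(0).AllOps()) {
    if (op->Type() == "feed") has_feed = true;
    if (op->Type() == "fetch") has_fetch = true;
  }
  return has_feed && has_fetch;
}
```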
Sorry, I don't understand this properly. Based on the updated design, the `Run()` function does not take the vectors of feed_var_names and fetch_var_names as input. Right?

```cpp
void Run(const ProgramDesc* program,
         Scope* scope,
         std::map<std::string, Tensor>& feeds,
         std::map<std::string, Tensor>& fetchs,
         const std::string& feed_var_name = "feed",
         const std::string& fetch_var_name = "fetch") {
```

So can you please explain the idea that users can specify that information when calling `Run()`?
We can get the `feed_var_names` from the argument `std::map<std::string, Tensor>& feeds`, where the `std::string` represents a name and the `Tensor` is the input data. The argument is a `std::map` because the corresponding argument in the Python implementation is a dict. Have a look at the example, which shows the detailed usage of `Executor.Run()`.
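A minimal sketch of that extraction (the helper name `FeedVarNames` is illustrative only; the tensor type is left generic so the snippet stands alone):

```cpp
#include <map>
#include <string>
#include <vector>

// Illustrative only: pull the feed variable names out of the feeds map.
// Each key is a variable name; the mapped value is the tensor to feed.
template <typename TensorT>
std::vector<std::string> FeedVarNames(
    const std::map<std::string, TensorT>& feeds) {
  std::vector<std::string> names;
  names.reserve(feeds.size());
  for (const auto& item : feeds) {
    names.push_back(item.first);
  }
  return names;
}
```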
Okay, will take a look. Thanks for the reply.
Yes, will remove this function in the next PR.
```diff
@@ -241,6 +269,9 @@ def save_inference_model(dirname,
             "fetch_var_names": fetch_var_names
         }, f, -1)

+    prepend_feed_ops(inference_program, feeded_var_names)
```
We can remove Lines 265-270 now, and change the implementation of `load_inference_model`.
Will do this in the next PR. Thanks!
```python
def prepend_feed_ops(inference_program, feeded_var_names):
    global_block = inference_program.global_block()
    feed_var = global_block.create_var(
        name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
```
There might be some problem if the name is fixed to `feed`.
Will fix this in the next PR.
Fixes #7550