Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JitLayer] Split layer_utils.cc and polish the BaseFunction interface #43754

Merged
merged 6 commits on
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions paddle/fluid/jit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,20 @@ cc_library(
SRCS serializer.cc
DEPS lod_tensor device_context)

cc_library(
jit_layer_utils
SRCS layer_utils.cc
DEPS scope proto_desc)

cc_library(
jit_layer
SRCS layer.cc
DEPS executor parallel_executor executor_cache)

cc_library(
jit_base_function
SRCS base_function.cc
DEPS scope proto_desc)
jit_function_schema
SRCS function_schema.cc
DEPS proto_desc)

if(WITH_TESTING AND NOT WIN32)
add_custom_target(
Expand All @@ -32,7 +37,8 @@ if(WITH_TESTING AND NOT WIN32)
scale_op
jit_serializer
jit_layer
jit_base_function)
jit_layer_utils
jit_function_schema)
cc_test(
layer_test
SRCS layer_test.cc
Expand Down
139 changes: 0 additions & 139 deletions paddle/fluid/jit/base_function.cc

This file was deleted.

67 changes: 5 additions & 62 deletions paddle/fluid/jit/base_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,77 +17,20 @@
#include <ostream>
#include <string>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/common/place.h"

#include "paddle/fluid/framework/variable.h"

namespace paddle {
namespace jit {

using Variable = paddle::framework::Variable;
using VariableNameMap = std::map<std::string, Variable>;
using DenseTensor = phi::DenseTensor;

class Argument {
public:
explicit Argument(const std::string &name, bool is_out = false);

const std::string &Name() const;

private:
std::string name_;
// paddle::optional<Variable> default_val_;
bool is_output_;
};

class FunctionSchema {
public:
FunctionSchema() = default;

std::vector<std::string> GetInputArgNames();

std::vector<std::string> GetOutputArgNames();

void AddInputArg(std::string name);

void AddOutputArg(std::string name);

private:
// input_args and output_args are ordered
std::vector<Argument> input_args;
std::vector<Argument> output_args;
};

// TODO(dev): make it as abstract class
class BaseFunction {
public:
BaseFunction(const framework::ProgramDesc &program_desc,
const std::vector<std::string> &param_names,
const VariableNameMap &params_dict,
const phi::Place &place);

virtual ~BaseFunction() {}

virtual std::vector<Variable> operator()(
const std::vector<Variable> &inputs) = 0;

protected:
void FetchOutput(std::vector<Variable> *outs);

void ShareInputsIntoScope(const std::vector<Variable> &vars);

void ShareParamsIntoScope(const std::vector<std::string> &param_names,
const VariableNameMap &params_dict);

void RemoveFeedFetch();

protected:
framework::ProgramDesc program_desc_;
FunctionSchema schema_;
// global_scope place params
framework::Scope scope_;
phi::Place place_;
virtual ~BaseFunction() {}
// virtual void SetPalce(const phi::Place &place);
};

} // namespace jit
Expand Down
38 changes: 25 additions & 13 deletions paddle/fluid/jit/exector_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,52 @@

#pragma once

#include <iostream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"

#include "paddle/fluid/jit/base_function.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/layer_utils.h"

namespace paddle {
namespace jit {

class ExectorFunction : public BaseFunction {
public:
ExectorFunction(const framework::ProgramDesc &program_desc,
const std::vector<std::string> param_names,
const VariableNameMap &params_dict,
ExectorFunction(const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place)
: BaseFunction(program_desc, param_names, params_dict, place),
inner_exe_(place_) {}
: info_(info), place_(place), inner_exe_(place_) {
ShareParamsIntoScope(info_->GetParamNames(), params_dict, &scope_);
VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
}

~ExectorFunction() {}
~ExectorFunction() noexcept {}

std::vector<Variable> operator()(const std::vector<Variable> &inputs) {
// share input into scope
ShareInputsIntoScope(inputs);
// run program
inner_exe_.Run(program_desc_,
ShareInputsIntoScope(info_->GetInputArgNames(), inputs, &scope_);
inner_exe_.Run(info_->GetProgramDesc(),
&scope_,
/*blockID=*/0,
false,
true,
schema_.GetOutputArgNames());
info_->GetOutputArgNames());
VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
// fetch outputs
std::vector<Variable> res;
FetchOutput(&res);
FetchVarsByNames(info_->GetOutputArgNames(), scope_, &res);
return res;
}

private:
std::shared_ptr<FunctionInfo> info_;
framework::Scope scope_;
phi::Place place_;
framework::Executor inner_exe_;
};

Expand Down
Loading