Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[JitLayer]Split layer_utils.cc and polish interface BaseFunction #43754

Merged
merged 6 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions paddle/fluid/jit/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,25 @@ cc_library(
SRCS serializer.cc
DEPS lod_tensor device_context)

cc_library(
jit_layer_utils
SRCS layer_utils.cc
DEPS scope proto_desc)

cc_library(
jit_compilation_unit
SRCS compilation_unit.cc
DEPS proto_desc executor parallel_executor executor_cache)

cc_library(
jit_layer
SRCS layer.cc
DEPS executor parallel_executor executor_cache)
DEPS jit_compilation_unit)

cc_library(
jit_base_function
SRCS base_function.cc
DEPS scope proto_desc)
jit_function_schema
SRCS function_schema.cc
DEPS jit_layer_utils)

if(WITH_TESTING AND NOT WIN32)
add_custom_target(
Expand All @@ -32,7 +42,9 @@ if(WITH_TESTING AND NOT WIN32)
scale_op
jit_serializer
jit_layer
jit_base_function)
jit_layer_utils
jit_function_schema
jit_compilation_unit)
cc_test(
layer_test
SRCS layer_test.cc
Expand Down
139 changes: 0 additions & 139 deletions paddle/fluid/jit/base_function.cc

This file was deleted.

67 changes: 5 additions & 62 deletions paddle/fluid/jit/base_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,77 +17,20 @@
#include <ostream>
#include <string>

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/common/place.h"

#include "paddle/fluid/framework/variable.h"

namespace paddle {
namespace jit {

using Variable = paddle::framework::Variable;
using VariableNameMap = std::map<std::string, Variable>;
using DenseTensor = phi::DenseTensor;

class Argument {
public:
explicit Argument(const std::string &name, bool is_out = false);

const std::string &Name() const;

private:
std::string name_;
// paddle::optional<Variable> default_val_;
bool is_output_;
};

class FunctionSchema {
public:
FunctionSchema() = default;

std::vector<std::string> GetInputArgNames();

std::vector<std::string> GetOutputArgNames();

void AddInputArg(std::string name);

void AddOutputArg(std::string name);

private:
// input_args and output_args are ordered
std::vector<Argument> input_args;
std::vector<Argument> output_args;
};

// TODO(dev): make it as abstract class
class BaseFunction {
public:
BaseFunction(const framework::ProgramDesc &program_desc,
const std::vector<std::string> &param_names,
const VariableNameMap &params_dict,
const phi::Place &place);

virtual ~BaseFunction() {}

virtual std::vector<Variable> operator()(
const std::vector<Variable> &inputs) = 0;

protected:
void FetchOutput(std::vector<Variable> *outs);

void ShareInputsIntoScope(const std::vector<Variable> &vars);

void ShareParamsIntoScope(const std::vector<std::string> &param_names,
const VariableNameMap &params_dict);

void RemoveFeedFetch();

protected:
framework::ProgramDesc program_desc_;
FunctionSchema schema_;
// global_scope place params
framework::Scope scope_;
phi::Place place_;
virtual ~BaseFunction() {}
// virtual void SetPalce(const phi::Place &place);
};

} // namespace jit
Expand Down
43 changes: 43 additions & 0 deletions paddle/fluid/jit/compilation_unit.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/jit/compilation_unit.h"

namespace paddle {
namespace jit {

void CompilationUnit::AddExecutorFunction(
const std::string &func_name,
const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place) {
function_dict_[func_name] =
std::make_shared<ExecutorFunction>(info, params_dict, place);
}

void CompilationUnit::AddPEFunction(const std::string &func_name,
const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place) {
function_dict_[func_name] =
std::make_shared<PEFunction>(info, params_dict, place);
}

std::shared_ptr<BaseFunction> CompilationUnit::GetFunction(
const std::string &name) const {
return function_dict_.at(name);
}

} // namespace jit
} // namespace paddle
22 changes: 17 additions & 5 deletions paddle/fluid/jit/compilation_unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,35 @@

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/jit/executor_function.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/pe_function.h"

namespace paddle {
namespace jit {
class BaseFunction;

class CompilationUnit {
public:
CompilationUnit() = default;
~CompilationUnit() {}

void AddExecutorFunction(const std::string &func_name,
const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place);

void AddPEFunction(const std::string &func_name,
const std::shared_ptr<FunctionInfo> &info,
const Name2VariableMap &params_dict,
const phi::Place &place);

std::shared_ptr<BaseFunction> GetFunction(const std::string &name) const;

private:
std::vector<std::unique_ptr<BaseFunction>> functions_;
std::unordered_map<std::string, size_t> functions_idx_;
std::unordered_map<std::string, std::shared_ptr<BaseFunction>> function_dict_;
};

} // namespace jit
Expand Down
Loading