[DoubleGrad PR PaddlePaddle#8] Enabled triple grads for sigmoid and matmul
jim19930609 committed Apr 4, 2022
1 parent a44667e commit 0da4325
Showing 7 changed files with 143 additions and 32 deletions.
@@ -22,7 +22,7 @@
### Global Variables ###
########################
ops_to_fill_zero_for_empty_grads = set(
["split_grad", "rnn_grad", "matmul_double_grad"])
["split_grad", "rnn_grad", "matmul_double_grad", "matmul_triple_grad"])

# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
@@ -171,12 +171,6 @@ def GetForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"


def TransformGradVarNameForDoubleGradGeneration(string):
if IsGradName(string):
string = "grad_" + string[:-5]
return string


def GetIndent(num):
tab = " "
return "".join([tab for i in range(num)])
@@ -31,7 +31,6 @@
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import TransformGradVarNameForDoubleGradGeneration
from codegen_utils import AssertMessage, GetIndent


@@ -476,10 +475,8 @@ def ForwardsValidationCheck(self):
orig_forward_returns_list = self.orig_forward_returns_list

for i in range(len(forward_inputs_list)):
forward_input_name = forward_inputs_list[i][0]
forward_input_type = forward_inputs_list[i][1]
forward_input_pos = forward_inputs_list[i][2]
orig_input_name = orig_forward_inputs_list[i][0]
orig_input_type = orig_forward_inputs_list[i][1]
orig_input_pos = orig_forward_inputs_list[i][2]

@@ -489,11 +486,9 @@ def ForwardsValidationCheck(self):
forward_input_pos, orig_input_pos)

for i in range(len(forward_attrs_list)):
orig_attr_name = orig_forward_attrs_list[i][0]
orig_attr_type = orig_forward_attrs_list[i][1]
orig_attr_default = orig_forward_attrs_list[i][2]
orig_attr_pos = orig_forward_attrs_list[i][3]
forward_attr_name = forward_attrs_list[i][0]
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]
@@ -1125,11 +1120,20 @@ def __init__(self,
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)

# Record name mapping from forward_api_name to grad_api_names
self.to_next_grad_name_mapping = {} # {name : name}

# Generated Results
self.node_declaration_str = ""
self.node_definition_str = ""
self.next_grad_api_contents = next_grad_api_contents

def TransformToNextGradName(self, string):
name_mapping = self.to_next_grad_name_mapping
if string in name_mapping.keys():
return name_mapping[string]
return string

def ResetOptionalInputs(self):
namespace = self.namespace
grad_api_contents = self.grad_api_contents
@@ -1139,6 +1143,22 @@ def ResetOptionalInputs(self):

self.optional_inputs = base_generator.optional_inputs

def RecordGrad2NextGradNameMapping(self, next_node_generator):
next_orig_inputs_list = next_node_generator.orig_forward_inputs_list
next_orig_returns_list = next_node_generator.orig_forward_returns_list

next_forward_inputs_list = next_node_generator.forward_inputs_list
next_forward_returns_list = next_node_generator.forward_returns_list
for i in range(len(next_orig_inputs_list)):
grad_name = next_orig_inputs_list[i][0]
next_forward_name = next_forward_inputs_list[i][0]
self.to_next_grad_name_mapping[grad_name] = next_forward_name

for i in range(len(next_orig_returns_list)):
grad_ret_name = next_orig_returns_list[i][0]
next_ret_name = next_forward_returns_list[i][0]
self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name

def GenerateHigherOrderNodeCreationCode(self):
namespace = self.namespace
grad_api_contents = self.grad_api_contents
@@ -1156,6 +1176,8 @@ def GenerateHigherOrderNodeCreationCode():
next_node_generator.GenerateNodeCreationCodes()
grad_node_creation_str = next_node_generator.node_creation_str

self.RecordGrad2NextGradNameMapping(next_node_generator)

return grad_node_creation_str

def GenerateNodeDeclaration(self):
@@ -1244,8 +1266,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
for name, (_, is_fwd_input,
grad_api_position), in backward_forward_inputs_map.items():
tensor_wrapper_name = GetSavedName(name)
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

is_optional = (name in self.optional_inputs)
if is_optional:
@@ -1258,8 +1279,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
# Grad Ins from grads
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

is_optional = (name in self.optional_inputs)
if IsPlainTensorType(ttype):
@@ -1300,8 +1320,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
num_outputs = len(backward_grad_outputs_map.keys())
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

if num_outputs == 1:
get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result;"
@@ -1323,8 +1342,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
@@ -1342,8 +1360,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):

# 2. Get TensorWrapper AutoGradMeta
for name, (ttype, _, pos), in backward_forward_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
@@ -1366,8 +1383,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
outputs_autograd_meta_list = []
num_fwd_outputs = len(backward_grad_outputs_map.keys())
for name, (rtype, pos, _) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

output_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
@@ -1401,8 +1417,7 @@ def GenerateNodeDefinition(self, grad_node_creation_str):
returns_str = f"{indent}std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)

# Infer Grad API Return Type
if num_bwd_outputs == 1:
@@ -1441,11 +1456,11 @@ def run(self):
#####################
## Code Generation ##
#####################
self.GenerateNodeDeclaration()

# Higher-order GradNode generation
grad_node_creation_str = self.GenerateHigherOrderNodeCreationCode()

self.GenerateNodeDeclaration()

self.GenerateNodeDefinition(grad_node_creation_str)
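
The key change in the generator above is replacing the hard-coded string rewrite TransformGradVarNameForDoubleGradGeneration with a name mapping recorded from the next-order backward API, falling back to the original name when nothing was recorded. The sketch below is a toy, standalone illustration of that lookup pattern only; the sample mapping entries are hypothetical and are not taken from the generated code.

def transform_to_next_grad_name(name, to_next_grad_name_mapping):
    # Use the name recorded for the next-order grad node when one exists,
    # otherwise keep the original name unchanged (identity fallback).
    return to_next_grad_name_mapping.get(name, name)

# Hypothetical entries, for illustration only.
mapping = {"grad_out": "fwd_grad_out", "grad_x_grad": "fwd_grad_grad_x"}
assert transform_to_next_grad_name("grad_out", mapping) == "fwd_grad_out"
assert transform_to_next_grad_name("x", mapping) == "x"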


48 changes: 48 additions & 0 deletions paddle/phi/infermeta/backward.cc
@@ -141,6 +141,54 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
dz->share_meta(z);
}
}
void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
if (dz) {
dz->share_meta(z);
}
if (dk) {
dk->share_meta(k);
}
}

void GeneralQuinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
const MetaTensor& l,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk,
MetaTensor* dl) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
if (dz) {
dz->share_meta(z);
}
if (dk) {
dk->share_meta(k);
}
if (dl) {
dl->share_meta(l);
}
}

void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) {
if (dx) {
20 changes: 20 additions & 0 deletions paddle/phi/infermeta/backward.h
@@ -85,6 +85,26 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
MetaTensor* dy,
MetaTensor* dz);

void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk);

void GeneralQuinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
const MetaTensor& l,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk,
MetaTensor* dl);

void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx);

void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
6 changes: 3 additions & 3 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -1414,12 +1414,12 @@ struct SigmoidTripleGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev,
const DenseTensor* Out,
const DenseTensor* ddX,
const DenseTensor* dOut,
const DenseTensor* d_DDOut,
const DenseTensor* ddX,
const DenseTensor* d_dOut_New,
DenseTensor* d_d_Out,
const DenseTensor* d_DDOut,
DenseTensor* d_Out_New,
DenseTensor* d_d_Out,
DenseTensor* d_DDx) const {
auto* d = dev.eigen_device();
auto ddx = EigenVector<T>::Flatten(
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/activation_grad_impl.h
@@ -243,8 +243,8 @@ void LogitGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SigmoidDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dout_new,
DenseTensor* ddout) {
if (dout_new) {
34 changes: 34 additions & 0 deletions python/paddle/utils/code_gen/backward.yaml
@@ -534,6 +534,7 @@
param : [x, y, grad_out]
kernel :
func : matmul_double_grad
backward : matmul_triple_grad
optional : grad_x_grad, grad_y_grad

- backward_api : matmul_grad
@@ -547,6 +548,17 @@
func : matmul_grad
backward : matmul_double_grad

- backward_api : matmul_triple_grad
forward : matmul_double_grad (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, bool transpose_x=false, bool transpose_y=false) -> Tensor(grad_x), Tensor(grad_y), Tensor(grad_grad_out)
args : (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, Tensor grad_x_grad, Tensor grad_y_grad, Tensor grad_grad_out_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad), Tensor(fwd_grad_out_grad), Tensor(fwd_grad_grad_x_grad), Tensor(fwd_grad_grad_y_grad)
infer_meta :
func : GeneralQuinaryGradInferMeta
param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y]
kernel :
func : matmul_triple_grad
optional : grad_x_grad, grad_y_grad, grad_grad_out_grad

- backward_api : matrix_power_grad
forward : matrix_power (Tensor x, int n) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int n)
@@ -832,6 +844,17 @@
kernel :
func : sigmoid_cross_entropy_with_logits_grad

- backward_api : sigmoid_double_grad
forward : sigmoid_grad (Tensor out, Tensor fwd_grad_out) -> Tensor(grad_x)
args : (Tensor out, Tensor fwd_grad_out, Tensor grad_x_grad)
output : Tensor(out_grad), Tensor(fwd_grad_out_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [out, fwd_grad_out]
kernel :
func : sigmoid_double_grad
backward : sigmoid_triple_grad

- backward_api : sigmoid_grad
forward : sigmoid (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
@@ -841,6 +864,17 @@
param : [out]
kernel :
func : sigmoid_grad
backward : sigmoid_double_grad

- backward_api : sigmoid_triple_grad
forward : sigmoid_double_grad (Tensor out, Tensor fwd_grad_out, Tensor grad_grad_x) -> Tensor(grad_out), Tensor(grad_grad_out)
args : (Tensor out, Tensor fwd_grad_out, Tensor grad_grad_x, Tensor grad_out_grad, Tensor grad_grad_out_grad)
output : Tensor(out_grad), Tensor(fwd_grad_out_grad), Tensor(grad_grad_x_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [out, fwd_grad_out, grad_grad_x]
kernel :
    func : sigmoid_triple_grad

- backward_api : silu_grad
forward : silu (Tensor x) -> Tensor(out)
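
Taken together, the yaml entries above, the new infermeta helpers, and the kernel signature changes wire up third-order gradients for sigmoid and matmul. The snippet below is a rough usage sketch rather than part of this commit; it assumes a build in which these kernels are registered and simply nests paddle.grad with create_graph=True so each backward pass stays differentiable. The matmul path is exercised analogously, starting from paddle.matmul.

import paddle

x = paddle.randn([3, 4])
x.stop_gradient = False
y = paddle.nn.functional.sigmoid(x)

# First-order gradient; create_graph=True keeps the backward graph differentiable.
(dx,) = paddle.grad(y, x, create_graph=True)
# Second-order gradient; expected to route through sigmoid_double_grad.
(ddx,) = paddle.grad(dx, x, create_graph=True)
# Third-order gradient; expected to exercise the new sigmoid_triple_grad path.
(dddx,) = paddle.grad(ddx, x)

print(dddx.shape)  # [3, 4]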
