Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP]Add lstm show case #125

Open
wants to merge 25 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
58406fd
add
gongweibao May 8, 2020
3475a92
merge
gongweibao May 11, 2020
78e7c8f
merge
gongweibao May 11, 2020
cddda8c
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao May 11, 2020
a2e4abc
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao May 14, 2020
b454b4c
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao May 15, 2020
44558a7
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao May 15, 2020
d8941b6
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao May 18, 2020
593771b
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao May 18, 2020
2ecf315
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao May 20, 2020
2222d3e
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao May 26, 2020
49b650b
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao May 26, 2020
d3602c0
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao Jun 2, 2020
b61e04e
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao Jun 8, 2020
a48f911
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao Jun 8, 2020
6d79b03
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao Jun 10, 2020
def2fb7
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao Jun 12, 2020
f5a1ba3
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao Jun 19, 2020
ccd7409
Merge branch 'develop' of https://github.com/elasticdeeplearning/edl …
gongweibao Jun 28, 2020
59338c8
add lstm
gongweibao Jun 28, 2020
067dd4e
add lstm
gongweibao Jun 28, 2020
5444d00
add
gongweibao Jun 28, 2020
00da273
add
gongweibao Jun 28, 2020
c058d36
add
gongweibao Jun 28, 2020
2d5ac20
fix some test=develop
gongweibao Jun 29, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 34 additions & 27 deletions example/distill/nlp/distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import sys
from paddle_serving_client import Client
from paddle_serving_app.reader import ChineseBertReader
from model import CNN, AdamW, evaluate_student, KL, BOW, KL_T
from model import CNN, AdamW, evaluate_student, KL, BOW, KL_T, model_factory

parser = argparse.ArgumentParser(__doc__)
parser.add_argument(
Expand All @@ -53,7 +53,9 @@
parser.add_argument(
"--use_data_au", type=int, default=1, help="use data augmentation")
parser.add_argument(
"--T", type=float, default=2.0, help="weight of student in loss")
"--T", type=float, default=None, help="weight of student in loss")
parser.add_argument(
"--model", type=str, default="BOW", help="student model name")
args = parser.parse_args()
print("parsed args:", args)

Expand All @@ -63,19 +65,16 @@

def train_with_distill(train_reader, dev_reader, word_dict, test_reader,
epoch_num):
boundaries = [2250 * 2, 2250 * 4, 2250 * 6]
values = [1e-4, 1.5e-4, 2.5e-4, 4e-4]
lr = D.PiecewiseDecay(boundaries, values, 0)
model = BOW(word_dict)
model = model_factory(args.model, word_dict)
if args.opt == "Adam":
opt = F.optimizer.Adam(
learning_rate=lr,
learning_rate=model.lr(steps_per_epoch=2250),
parameter_list=model.parameters(),
regularization=F.regularizer.L2Decay(
regularization_coeff=args.weight_decay))
else:
opt = AdamW(
learning_rate=lr,
learning_rate=model.lr(steps_per_epoch=2250),
parameter_list=model.parameters(),
weight_decay=args.weight_decay)

Expand All @@ -101,30 +100,34 @@ def train_with_distill(train_reader, dev_reader, word_dict, test_reader,
) * loss_kd
else:
loss_kd = KL_T(logits_s, logits_t, args.T)
loss = args.T * args.T * (args.s_weight * loss_ce +
(1.0 - args.s_weight) * loss_kd)
loss = args.T * args.T * (loss_ce + loss_kd)
#loss_kd = KL(logits_s, logits_t)
#loss = loss_ce + loss_kd

loss = L.reduce_mean(loss)
loss.backward()
if step % 10 == 0:
if step % 100 == 0:
print("stduent logits:", logits_s)
print("teatcher logits:", logits_t)
print('[step %03d] distill train loss %.5f lr %.3e' %
(step, loss.numpy(), opt.current_step_lr()))
opt.minimize(loss)
model.clear_gradients()
f1, acc = evaluate_student(model, dev_reader)
print('student on dev f1 %.5f acc %.5f' % (f1, acc))
print('student on dev f1 %.5f acc %.5f epoch_no %d' % (f1, acc, epoch))

if max_dev_acc < acc:
max_dev_acc = acc

f1, acc = evaluate_student(model, test_reader)
print('student on test f1 %.5f acc %.5f' % (f1, acc))
print('student on test f1 %.5f acc %.5f epoch_no %d' %
(f1, acc, epoch))

if max_test_acc < acc:
max_test_acc = acc

g_max_dev_acc.append(g_max_dev_acc)
g_max_test_acc.append(g_max_test_acc)
g_max_dev_acc.append(max_dev_acc)
g_max_test_acc.append(max_test_acc)


def ernie_reader(s_reader, key_list):
Expand Down Expand Up @@ -155,10 +158,7 @@ def reader():
return reader


if __name__ == "__main__":
place = F.CUDAPlace(0)
D.guard(place).__enter__()

def train():
ds = ChnSentiCorp()
word_dict = ds.student_word_dict("./data/vocab.bow.txt")
batch_size = 16
Expand Down Expand Up @@ -195,14 +195,21 @@ def reader():
input_files, word_dict, batch_size=batch_size)
dr_t = dr.set_batch_generator(ernie_reader(dr_train_reader, feed_keys))

train_with_distill(
dr_t, dev_reader, word_dict, test_reader, epoch_num=args.epoch_num)


if __name__ == "__main__":
place = F.CUDAPlace(0)
D.guard(place).__enter__()

for i in range(args.train_range):
train_with_distill(
dr_t, dev_reader, word_dict, test_reader, epoch_num=args.epoch_num)
train()

arr = np.array(g_max_dev_acc)
print("max_dev_acc:", arr, "average:", np.average(arr), "train_args:",
args)
arr = np.array(g_max_dev_acc)
print("max_dev_acc:", arr, "average:", np.average(arr), "train_args:",
args)

arr = np.array(g_max_test_acc)
print("max_test_acc:", arr, "average:", np.average(arr), "train_args:",
args)
arr = np.array(g_max_test_acc)
print("max_test_acc:", arr, "average:", np.average(arr), "train_args:",
args)
65 changes: 65 additions & 0 deletions example/distill/nlp/lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import os

import numpy as np
import argparse
from sklearn.metrics import f1_score, accuracy_score
import paddle as P
import paddle.fluid as F
import paddle.fluid.layers as L
import paddle.fluid.dygraph as D
from reader import ChnSentiCorp, pad_batch_data
from paddle_edl.distill.distill_reader import DistillReader
import re

import os
import sys
from paddle_serving_client import Client
from paddle_serving_app.reader import ChineseBertReader
from text_basic import LSTM as basic_lstm


class LSTM(D.Layer):
def __init__(self, word_dict):
super().__init__()

self.emb = D.Embedding([len(word_dict), 300])
self.lstm = basic_lstm(input_size=300, hidden_size=150)
self.fc = D.Linear(150, 2)

def forward(self, ids, labels=None):
embbed = self.emb(ids)
#print("embed shape:", embbed.shape)

lstm_out, hidden = self.lstm(embbed)
#print("lstm_out shape:", lstm_out.shape)
#print("hiden list len:", len(hidden))

logits = self.fc(lstm_out[:, -1])
#print("logits shape:", logits.shape)

if labels is not None:
if len(labels.shape) == 1:
labels = L.reshape(labels, [-1, 1])
loss = L.softmax_with_cross_entropy(logits, labels)
else:
loss = None

return loss, logits

def lr(self, steps_per_epoch=None):
return 1e-3
32 changes: 30 additions & 2 deletions example/distill/nlp/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import sys
from paddle_serving_client import Client
from paddle_serving_app.reader import ChineseBertReader
from lstm import LSTM
from nets import GRU


class AdamW(F.optimizer.AdamOptimizer):
Expand Down Expand Up @@ -66,12 +68,15 @@ def KL_T(logits_s, logits_t, T=2.0):
return loss


def evaluate_student(model, test_reader):
def evaluate_student(model, test_reader, batch_size=None):
all_pred, all_label = [], []
with D.base._switch_tracer_mode_guard_(is_train=False):
model.eval()
for step, (ids_student, labels, _) in enumerate(test_reader()):
_, logits = model(ids_student)
if batch_size is not None:
_, logits = model(ids_student, batch_size=batch_size)
else:
_, logits = model(ids_student)
pred = L.argmax(logits, -1)
all_pred.extend(pred.numpy())
all_label.extend(labels.numpy())
Expand Down Expand Up @@ -105,6 +110,13 @@ def forward(self, ids, labels=None):
loss = None
return loss, logits

def lr(self, steps_per_epoch):
values = [1e-4, 1.5e-4, 2.5e-4, 4e-4]
boundaries = [
steps_per_epoch * 2, steps_per_epoch * 4, steps_per_epoch * 6
]
return D.PiecewiseDecay(boundaries, values, 0)


class CNN(D.Layer):
def __init__(self, word_dict):
Expand Down Expand Up @@ -133,3 +145,19 @@ def forward(self, ids, labels=None):
else:
loss = None
return loss, logits

def lr(self, steps_per_epoch=None):
return 1e-4


def model_factory(model_name, word_dict):
if model_name == "BOW":
return BOW(word_dict)
elif model_name == "CNN":
return CNN(word_dict)
elif model_name == "LSTM":
return LSTM(word_dict)
elif model_name == "GRU":
return GRU(word_dict)
else:
assert False, "not supported model name:{}".format(model_name)
Loading