[Dy2St] pir dy2st unittest verification - Part 6 (#59007)
---------

Co-authored-by: SigureMo <sigure.qaq@gmail.com>
gouzil and SigureMo authored Nov 17, 2023
1 parent 9f4d209 commit f2274dc
Showing 9 changed files with 167 additions and 117 deletions.
29 changes: 26 additions & 3 deletions python/paddle/jit/sot/infer_meta.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import cached_property
+
 import paddle
 from paddle.amp.auto_cast import amp_state
 from paddle.base import framework
@@ -20,7 +22,6 @@
     UniqueNameGenerator,
     guard as UniqueNameGuard,
 )
-from paddle.static import Program
 from paddle.utils import flatten, is_sequence
 
 from .utils import Cache, Singleton, map_if_extend, meta_str
@@ -105,8 +106,6 @@ class VariableCreator:
 
     def __init__(self):
         self.var_cache = {}
-        self.main_program = Program()
-        self.startup_program = Program()
         self.var_name_generator = UniqueNameGenerator("infer_meta_variable_")
 
     def gen_name(self, meta):
@@ -115,6 +114,30 @@ def gen_name(self, meta):
             name += f"_{l}"
         return name
 
+    @cached_property
+    def legacy_programs(self):
+        # Just for PIR and legacy IR compatibility.
+        # This can be removed after PIR become default state.
+        return (paddle.static.Program(), paddle.static.Program())
+
+    @cached_property
+    def pir_programs(self):
+        return (paddle.static.Program(), paddle.static.Program())
+
+    @property
+    def main_program(self):
+        if paddle.base.framework.use_pir_api():
+            return self.pir_programs[0]
+        else:
+            return self.legacy_programs[0]
+
+    @property
+    def startup_program(self):
+        if paddle.base.framework.use_pir_api():
+            return self.pir_programs[1]
+        else:
+            return self.legacy_programs[1]
+
     def create_var(self, meta):
         if paddle.base.framework.use_pir_api():
             with paddle.static.program_guard(
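
Note: the change above replaces two eagerly built programs with a lazy per-IR pair. Below is a minimal, self-contained sketch of that pattern; `Program`, `use_pir_api`, and the module-level flag are illustrative stand-ins, not the real Paddle APIs:

from functools import cached_property

# Illustrative stand-ins for paddle.static.Program and
# paddle.base.framework.use_pir_api(); not the real APIs.
class Program:
    pass

_use_pir = False

def use_pir_api():
    return _use_pir

class VariableCreatorSketch:
    # Each IR mode keeps its own (main, startup) pair. cached_property
    # defers construction to first access, so nothing is built for a
    # mode that is never used, and the property picks the pair that
    # matches the flag at lookup time.
    @cached_property
    def legacy_programs(self):
        return (Program(), Program())

    @cached_property
    def pir_programs(self):
        return (Program(), Program())

    @property
    def main_program(self):
        pair = self.pir_programs if use_pir_api() else self.legacy_programs
        return pair[0]

    @property
    def startup_program(self):
        pair = self.pir_programs if use_pir_api() else self.legacy_programs
        return pair[1]
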
5 changes: 5 additions & 0 deletions test/dygraph_to_static/dygraph_to_static_utils_new.py
@@ -266,6 +266,11 @@ def test_sot_only(fn):
     return fn
 
 
+def test_legacy_only(fn):
+    fn = set_ir_mode(IrMode.LEGACY_IR)(fn)
+    return fn
+
+
 def test_pir_only(fn):
     fn = set_ir_mode(IrMode.PIR_EXE)(fn)
     return fn
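
The new helper is used the same way as the file's existing `test_sot_only` / `test_pir_only` decorators. A hypothetical test case (names are illustrative) might look like:

from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_only

class TestExample(Dy2StTestBase):  # hypothetical case, for illustration
    @test_legacy_only  # restrict this test to the legacy-IR path
    def test_feature(self):
        ...
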
2 changes: 0 additions & 2 deletions test/dygraph_to_static/simnet_dygraph_model.py
@@ -16,7 +16,6 @@
 
 import paddle
 import paddle.base.param_attr as attr
-from paddle.jit.api import to_static
 from paddle.nn import Layer
 
 
@@ -484,7 +483,6 @@ def __init__(self, conf_dict):
         self.bow_layer_po = FCLayer(self.bow_dim, None, "fc").ops()
         self.softmax_layer = FCLayer(2, "softmax", "cos_sim").ops()
 
-    @to_static
     def forward(self, left, right):
         """
         Forward network
5 changes: 1 addition & 4 deletions test/dygraph_to_static/test_cast.py
@@ -18,7 +18,6 @@
 from dygraph_to_static_utils_new import (
     Dy2StTestBase,
     test_ast_only,
-    test_legacy_and_pir,
     test_legacy_and_pir_exe_and_pir_api,
 )
 
@@ -188,11 +187,9 @@ def set_func(self):
         self.func = paddle.jit.to_static(full_graph=True)(test_not_var_cast)
 
     @test_ast_only
-    @test_legacy_and_pir
+    @test_legacy_and_pir_exe_and_pir_api
     def test_cast_result(self):
         self.set_func()
-        # breakpoint()
-        # print("run once!!!")
         res = self.do_test()
         self.assertTrue(type(res) == int, msg='The casted dtype is not int.')
         ref_val = int(self.input)
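
Note the call form in `set_func` above: `paddle.jit.to_static(full_graph=True)` returns a configured decorator that is then applied to the plain function. A sketch of the same pattern, assuming `not_var_cast` as a stand-in for `test_not_var_cast`:

import paddle

def not_var_cast(x):  # hypothetical stand-in for test_not_var_cast
    return int(x)

# to_static(...) with no function argument acts as a decorator factory;
# full_graph=True requests whole-graph (AST) conversion rather than SOT.
static_cast = paddle.jit.to_static(full_graph=True)(not_var_cast)
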
54 changes: 31 additions & 23 deletions test/dygraph_to_static/test_mobile_net.py
@@ -24,8 +24,8 @@
 
 import paddle
 from paddle import base
+from paddle.base.framework import unique_name
 from paddle.base.param_attr import ParamAttr
-from paddle.jit.api import to_static
 from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
 from paddle.nn import BatchNorm, Linear
 
@@ -267,7 +267,6 @@ def __init__(self, scale=1.0, class_dim=1000):
             bias_attr=ParamAttr(name="fc7_offset"),
         )
 
-    @to_static
     def forward(self, inputs):
         y = self.conv1(inputs)
         for dws in self.dwsl:
@@ -433,7 +432,6 @@ def __init__(self, class_dim=1000, scale=1.0):
             bias_attr=ParamAttr(name="fc10_offset"),
         )
 
-    @to_static
     def forward(self, inputs):
         y = self._conv1(inputs, if_act=True)
         for inv in self._invl:
@@ -496,7 +494,9 @@ class Args:
     print_step = 1
     train_step = 10
     place = (
-        base.CUDAPlace(0) if base.is_compiled_with_cuda() else base.CPUPlace()
+        paddle.CUDAPlace(0)
+        if paddle.is_compiled_with_cuda()
+        else paddle.CPUPlace()
     )
     model_save_dir = None
     model_save_prefix = None
@@ -507,15 +507,20 @@
 
 def train_mobilenet(args, to_static):
     paddle.jit.enable_to_static(to_static)
-    with base.dygraph.guard(args.place):
+
+    with unique_name.guard():
         np.random.seed(SEED)
         paddle.seed(SEED)
         paddle.framework.random._manual_program_seed(SEED)
 
         if args.model == "MobileNetV1":
-            net = MobileNetV1(class_dim=args.class_dim, scale=1.0)
+            net = paddle.jit.to_static(
+                MobileNetV1(class_dim=args.class_dim, scale=1.0)
+            )
         elif args.model == "MobileNetV2":
-            net = MobileNetV2(class_dim=args.class_dim, scale=1.0)
+            net = paddle.jit.to_static(
+                MobileNetV2(class_dim=args.class_dim, scale=1.0)
+            )
         else:
             print(
                 "wrong model name, please try model = MobileNetV1 or MobileNetV2"
@@ -618,34 +623,37 @@ def predict_static(args, data):
         feed={feed_target_names[0]: data},
         fetch_list=fetch_targets,
     )
+    paddle.disable_static()
     return pred_res[0]
 
 
 def predict_dygraph(args, data):
     paddle.jit.enable_to_static(False)
-    with base.dygraph.guard(args.place):
-        if args.model == "MobileNetV1":
-            model = MobileNetV1(class_dim=args.class_dim, scale=1.0)
-        elif args.model == "MobileNetV2":
-            model = MobileNetV2(class_dim=args.class_dim, scale=1.0)
-        # load dygraph trained parameters
-        model_dict = paddle.load(args.dy_state_dict_save_path + '.pdparams')
-        model.set_dict(model_dict)
-        model.eval()
+    if args.model == "MobileNetV1":
+        model = paddle.jit.to_static(
+            MobileNetV1(class_dim=args.class_dim, scale=1.0)
+        )
+    elif args.model == "MobileNetV2":
+        model = paddle.jit.to_static(
+            MobileNetV2(class_dim=args.class_dim, scale=1.0)
+        )
+    # load dygraph trained parameters
+    model_dict = paddle.load(args.dy_state_dict_save_path + '.pdparams')
+    model.set_dict(model_dict)
+    model.eval()
 
-        pred_res = model(base.dygraph.to_variable(data))
+    pred_res = model(base.dygraph.to_variable(data))
 
-        return pred_res.numpy()
+    return pred_res.numpy()
 
 
 def predict_dygraph_jit(args, data):
-    with base.dygraph.guard(args.place):
-        model = paddle.jit.load(args.model_save_prefix)
-        model.eval()
+    model = paddle.jit.load(args.model_save_prefix)
+    model.eval()
 
-        pred_res = model(data)
+    pred_res = model(data)
 
-        return pred_res.numpy()
+    return pred_res.numpy()
 
 
 def predict_analysis_inference(args, data):
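
The recurring change in this file swaps the `@to_static` decorator baked into `forward` for an explicit `paddle.jit.to_static(...)` wrap at the call site, so the class itself stays a plain dygraph `Layer`. A minimal sketch of the pattern, using a hypothetical `TinyNet`:

import paddle

class TinyNet(paddle.nn.Layer):  # hypothetical model, for illustration
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = TinyNet()                          # usable directly in dygraph
static_net = paddle.jit.to_static(net)   # converted only where needed
out = static_net(paddle.rand([1, 4]))
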
133 changes: 65 additions & 68 deletions test/dygraph_to_static/test_resnet_amp.py
@@ -37,76 +37,73 @@ def train(to_static, build_strategy=None):
     """
     Tests model decorated by `dygraph_to_static_output` in static graph mode. For users, the model is defined in dygraph mode and trained in static graph mode.
     """
-    with base.dygraph.guard(place):
-        np.random.seed(SEED)
-        paddle.seed(SEED)
-        paddle.framework.random._manual_program_seed(SEED)
-
-        resnet = ResNet()
-        if to_static:
-            resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
-        optimizer = optimizer_setting(parameter_list=resnet.parameters())
-        scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
-
-        for epoch in range(epoch_num):
-            total_loss = 0.0
-            total_acc1 = 0.0
-            total_acc5 = 0.0
-            total_sample = 0
-
-            for batch_id in range(100):
-                start_time = time.time()
-                img = paddle.to_tensor(
-                    np.random.random([batch_size, 3, 224, 224]).astype(
-                        'float32'
-                    )
-                )
-                label = paddle.to_tensor(
-                    np.random.randint(0, 100, [batch_size, 1], dtype='int64')
-                )
-                img.stop_gradient = True
-                label.stop_gradient = True
-
-                with paddle.amp.auto_cast():
-                    pred = resnet(img)
-                    # FIXME(Aurelius84): The following cross_entropy seems to bring out a
-                    # precision problem, need to figure out the underlying reason.
-                    # If we remove it, the loss between dygraph and dy2stat is exactly same.
-                    loss = paddle.nn.functional.cross_entropy(
-                        input=pred,
-                        label=label,
-                        reduction='none',
-                        use_softmax=False,
-                    )
-                avg_loss = paddle.mean(x=pred)
-                acc_top1 = paddle.static.accuracy(input=pred, label=label, k=1)
-                acc_top5 = paddle.static.accuracy(input=pred, label=label, k=5)
-
-                scaled = scaler.scale(avg_loss)
-                scaled.backward()
-                scaler.minimize(optimizer, scaled)
-                resnet.clear_gradients()
-
-                total_loss += avg_loss
-                total_acc1 += acc_top1
-                total_acc5 += acc_top5
-                total_sample += 1
-
-                end_time = time.time()
-                if batch_id % 2 == 0:
-                    print(
-                        "epoch %d | batch step %d, loss %0.3f, acc1 %0.3f, acc5 %0.3f, time %f"
-                        % (
-                            epoch,
-                            batch_id,
-                            total_loss.numpy() / total_sample,
-                            total_acc1.numpy() / total_sample,
-                            total_acc5.numpy() / total_sample,
-                            end_time - start_time,
-                        )
-                    )
-                if batch_id == 10:
-                    break
+    np.random.seed(SEED)
+    paddle.seed(SEED)
+    paddle.framework.random._manual_program_seed(SEED)
+
+    resnet = ResNet()
+    if to_static:
+        resnet = paddle.jit.to_static(resnet, build_strategy=build_strategy)
+    optimizer = optimizer_setting(parameter_list=resnet.parameters())
+    scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
+
+    for epoch in range(epoch_num):
+        total_loss = 0.0
+        total_acc1 = 0.0
+        total_acc5 = 0.0
+        total_sample = 0
+
+        for batch_id in range(100):
+            start_time = time.time()
+            img = paddle.to_tensor(
+                np.random.random([batch_size, 3, 224, 224]).astype('float32')
+            )
+            label = paddle.to_tensor(
+                np.random.randint(0, 100, [batch_size, 1], dtype='int64')
+            )
+            img.stop_gradient = True
+            label.stop_gradient = True
+
+            with paddle.amp.auto_cast():
+                pred = resnet(img)
+                # FIXME(Aurelius84): The following cross_entropy seems to bring out a
+                # precision problem, need to figure out the underlying reason.
+                # If we remove it, the loss between dygraph and dy2stat is exactly same.
+                loss = paddle.nn.functional.cross_entropy(
+                    input=pred,
+                    label=label,
+                    reduction='none',
+                    use_softmax=False,
+                )
+            avg_loss = paddle.mean(x=pred)
+            acc_top1 = paddle.static.accuracy(input=pred, label=label, k=1)
+            acc_top5 = paddle.static.accuracy(input=pred, label=label, k=5)
+
+            scaled = scaler.scale(avg_loss)
+            scaled.backward()
+            scaler.minimize(optimizer, scaled)
+            resnet.clear_gradients()
+
+            total_loss += avg_loss
+            total_acc1 += acc_top1
+            total_acc5 += acc_top5
+            total_sample += 1
+
+            end_time = time.time()
+            if batch_id % 2 == 0:
+                print(
+                    "epoch %d | batch step %d, loss %0.3f, acc1 %0.3f, acc5 %0.3f, time %f"
+                    % (
+                        epoch,
+                        batch_id,
+                        total_loss.numpy() / total_sample,
+                        total_acc1.numpy() / total_sample,
+                        total_acc5.numpy() / total_sample,
+                        end_time - start_time,
+                    )
+                )
+            if batch_id == 10:
+                break
 
     return total_loss.numpy()
 
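
The dedented body keeps the same AMP recipe. For reference, the scale/backward/minimize sequence used in `train` looks like this in isolation (a minimal sketch with a hypothetical one-layer model in place of `ResNet()`):

import numpy as np
import paddle

model = paddle.nn.Linear(8, 2)  # hypothetical stand-in for ResNet()
opt = paddle.optimizer.SGD(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.to_tensor(np.random.random([4, 8]).astype('float32'))
with paddle.amp.auto_cast():   # run the forward pass in mixed precision
    loss = paddle.mean(model(x))
scaled = scaler.scale(loss)    # scale the loss to avoid fp16 underflow
scaled.backward()
scaler.minimize(opt, scaled)   # unscale gradients, then apply the step
model.clear_gradients()
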
5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_resnet_v2.py
@@ -197,7 +197,6 @@ def __init__(self, layers=50, class_dim=102):
             ),
         )
 
-    @paddle.jit.to_static
     def forward(self, inputs):
         y = self.conv(inputs)
         y = self.pool2d_max(y)
@@ -280,7 +279,7 @@ def do_train(self, to_static):
             dataset, batch_size=batch_size, drop_last=True
         )
 
-        resnet = ResNet()
+        resnet = paddle.jit.to_static(ResNet())
         optimizer = optimizer_setting(parameter_list=resnet.parameters())
 
         for epoch in range(epoch_num):
@@ -339,7 +338,7 @@ def do_train(self, to_static):
     def predict_dygraph(self, data):
         paddle.jit.enable_to_static(False)
         paddle.disable_static(place)
-        resnet = ResNet()
+        resnet = paddle.jit.to_static(ResNet())
 
         model_dict = paddle.load(self.dy_state_dict_save_path + '.pdparams')
         resnet.set_dict(model_dict)
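
`predict_dygraph` above relies on `paddle.jit.enable_to_static(False)`: with conversion globally disabled, a `to_static`-wrapped layer still executes as ordinary dygraph code. A small sketch of that behavior:

import paddle

paddle.jit.enable_to_static(False)
net = paddle.jit.to_static(paddle.nn.Linear(4, 2))
out = net(paddle.rand([1, 4]))  # runs eagerly; no graph is captured
paddle.jit.enable_to_static(True)
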