-
Notifications
You must be signed in to change notification settings - Fork 165
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PPSCI Export&Infer No.23】viv #832
【PPSCI Export&Infer No.23】viv #832
Conversation
Thanks for your contribution! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢提交PR,辛苦修改一下
examples/fsi/conf/viv.yaml
Outdated
export_path: ./inference/viv | ||
pdmodel_path: ${INFER.export_path}.pdmodel | ||
pdpiparams_path: ${INFER.export_path}.pdiparams | ||
input_keys: ["t_f"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
input_keys可以使用omegaconf的引用语法:input_keys: ${MODEL.input_keys}
,保持跟MODEL.inputs_keys字段一致
examples/fsi/viv.py
Outdated
def export(cfg: DictConfig): | ||
# set model | ||
model = ppsci.arch.MLP(**cfg.MODEL) | ||
|
||
# initialize equation | ||
equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)} | ||
|
||
# initialize solver | ||
solver = ppsci.solver.Solver( | ||
model, | ||
equation=equation, | ||
pretrained_model_path=cfg.INFER.pretrained_model_path, | ||
) | ||
# Convert equation to func | ||
funcs = ppsci.lambdify( | ||
solver.equation["VIV"].equations["f"], | ||
solver.model, | ||
list(solver.equation["VIV"].learnable_parameters), | ||
) | ||
|
||
def wrap_prediction_to_dict(instance, func): | ||
def wrapper(instance, *args, **kwargs): | ||
result = func(*args, **kwargs) | ||
return {"f": result} | ||
|
||
if hasattr(func, "__func__"): | ||
wrapper.__func__ = func.__func__ | ||
return wrapper | ||
|
||
def wrap_forward_methods(instance): | ||
instance.input_keys = cfg.MODEL.input_keys | ||
instance.output_keys = ["f"] | ||
for attr_name in dir(instance): | ||
if attr_name == "forward": | ||
attr = getattr(instance, attr_name) | ||
setattr(instance, attr_name, wrap_prediction_to_dict(instance, attr)) | ||
return instance | ||
|
||
eqn = wrap_forward_methods(funcs) | ||
# Combine the two instances | ||
models = ppsci.arch.ModelList((solver.model, eqn)) | ||
# export models | ||
from paddle.static import InputSpec | ||
|
||
input_spec = [ | ||
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, | ||
] | ||
|
||
from paddle import jit | ||
|
||
jit.enable_to_static(True) | ||
|
||
static_model = jit.to_static( | ||
models, | ||
input_spec=input_spec, | ||
full_graph=True, | ||
) | ||
|
||
jit.save(static_model, cfg.INFER.export_path, skip_prune_program=True) | ||
|
||
jit.enable_to_static(False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我尝试优化了下你的代码,可以按照如下修改,配合这个PR就能跑通了: #835
def export(cfg: DictConfig): | |
# set model | |
model = ppsci.arch.MLP(**cfg.MODEL) | |
# initialize equation | |
equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)} | |
# initialize solver | |
solver = ppsci.solver.Solver( | |
model, | |
equation=equation, | |
pretrained_model_path=cfg.INFER.pretrained_model_path, | |
) | |
# Convert equation to func | |
funcs = ppsci.lambdify( | |
solver.equation["VIV"].equations["f"], | |
solver.model, | |
list(solver.equation["VIV"].learnable_parameters), | |
) | |
def wrap_prediction_to_dict(instance, func): | |
def wrapper(instance, *args, **kwargs): | |
result = func(*args, **kwargs) | |
return {"f": result} | |
if hasattr(func, "__func__"): | |
wrapper.__func__ = func.__func__ | |
return wrapper | |
def wrap_forward_methods(instance): | |
instance.input_keys = cfg.MODEL.input_keys | |
instance.output_keys = ["f"] | |
for attr_name in dir(instance): | |
if attr_name == "forward": | |
attr = getattr(instance, attr_name) | |
setattr(instance, attr_name, wrap_prediction_to_dict(instance, attr)) | |
return instance | |
eqn = wrap_forward_methods(funcs) | |
# Combine the two instances | |
models = ppsci.arch.ModelList((solver.model, eqn)) | |
# export models | |
from paddle.static import InputSpec | |
input_spec = [ | |
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, | |
] | |
from paddle import jit | |
jit.enable_to_static(True) | |
static_model = jit.to_static( | |
models, | |
input_spec=input_spec, | |
full_graph=True, | |
) | |
jit.save(static_model, cfg.INFER.export_path, skip_prune_program=True) | |
jit.enable_to_static(False) | |
def export(cfg: DictConfig): | |
from paddle import nn | |
from paddle.static import InputSpec | |
# set model | |
model = ppsci.arch.MLP(**cfg.MODEL) | |
# initialize equation | |
equation = {"VIV": ppsci.equation.Vibration(2, -4, 0)} | |
# initialize solver | |
solver = ppsci.solver.Solver( | |
model, | |
equation=equation, | |
pretrained_model_path=cfg.INFER.pretrained_model_path, | |
) | |
# Convert equation to func | |
f_func = ppsci.lambdify( | |
solver.equation["VIV"].equations["f"], | |
solver.model, | |
list(solver.equation["VIV"].learnable_parameters), | |
) | |
class Wrapped_Model(nn.Layer): | |
def __init__(self, model, func): | |
super().__init__() | |
self.model = model | |
self.func = func | |
def forward(self, x): | |
model_out = self.model(x) | |
func_out = self.func(x) | |
return {**model_out, "f": func_out} | |
solver.model = Wrapped_Model(model, f_func) | |
# export models | |
input_spec = [ | |
{key: InputSpec([None, 1], "float32", name=key) for key in model.input_keys}, | |
] | |
solver.export(input_spec, cfg.INFER.export_path, skip_prune_program=True) |
examples/fsi/conf/viv.yaml
Outdated
|
||
# inference settings | ||
INFER: | ||
pretrained_model_path: "./viv_pretrained" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里可以改为:https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams
,我在 https://github.com/PaddlePaddle/PaddleScience/pull/834/files#diff-68effb6a6b046bfbf5b6cb8b843f284e8779eadfa5977b66870a783d6d61ebe7R105-R118 支持了自动下载方程参数文件的功能
examples/fsi/conf/viv.yaml
Outdated
pdmodel_path: ${INFER.export_path}.pdmodel | ||
pdpiparams_path: ${INFER.export_path}.pdiparams | ||
input_keys: ["t_f"] | ||
output_keys: ["eta", 'f'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"f"双引号
docs/zh/examples/viv.md
Outdated
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdeqn | ||
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/viv/viv_pretrained.pdparams |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两行在pretrained_model_path改完之后就可以删除了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的好的 改好了 辛苦老师了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
* eadd export and inference for viv * add doc * fix viv export&infer * Rewriting function * fix viv export&infer
PR types
Others
PR changes
Others
Describe
因为需要将equation导出 就在 ComposedNode 加入了paddle.nn.LayerList,不然导出找不到里面的参数
equation导出时,模型裁剪也会出问题,我就先跳过了