From bc983ee560666e98e5d2bf241cd9b5b4dd3ab9f7 Mon Sep 17 00:00:00 2001 From: xiongkun Date: Thu, 1 Jun 2023 19:10:11 +0800 Subject: [PATCH] add fallback wrapper for speed up (#100) Co-authored-by: 0x45f --- symbolic_trace/symbolic/compile_cache.py | 31 ++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/symbolic_trace/symbolic/compile_cache.py b/symbolic_trace/symbolic/compile_cache.py index d71fb04c8..176772a4b 100644 --- a/symbolic_trace/symbolic/compile_cache.py +++ b/symbolic_trace/symbolic/compile_cache.py @@ -4,6 +4,31 @@ from .interpreter import compile_sir +def clear_eager_tensor_name(output_tensors): + for output_tensor in output_tensors: + output_tensor.name = "" + + +class FallbackWrapper: + def __init__(self, compile_sir): + self.compile_sir = compile_sir + self.partial_program_layer = None + + def __call__(self, *args, **kwargs): + frame_callback = paddle.fluid.core.set_eval_frame(None) + if self.partial_program_layer is None: + outputs = self.compile_sir(*args, **kwargs) + self.partial_program_layer = self.compile_sir.get_concrete_program( + *args, **kwargs + )[1] + else: + # Speed up Resnet from 0.0068 --> 0.0057 + outputs = self.partial_program_layer(*args, **kwargs) + clear_eager_tensor_name(outputs) + paddle.fluid.core.set_eval_frame(frame_callback) + return outputs + + @Singleton class CompileSIRCache(Cache): def __init__(self): @@ -16,6 +41,8 @@ def key_fn(self, context, sir_name): return hash_key def value_fn(self, context, sir_name): - return paddle.jit.to_static( - compile_sir(context, sir_name), enable_fallback=False + return FallbackWrapper( + paddle.jit.to_static( + compile_sir(context, sir_name), enable_fallback=False + ) )