Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
add fallback wrapper for speed up (#100)
Browse files Browse the repository at this point in the history
Co-authored-by: 0x45f <wangzhen45@baidu.com>
  • Loading branch information
2742195759 and 0x45f authored Jun 1, 2023
1 parent 66589ed commit bc983ee
Showing 1 changed file with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions symbolic_trace/symbolic/compile_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,31 @@
from .interpreter import compile_sir


def clear_eager_tensor_name(output_tensors):
for output_tensor in output_tensors:
output_tensor.name = ""


class FallbackWrapper:
def __init__(self, compile_sir):
self.compile_sir = compile_sir
self.partial_program_layer = None

def __call__(self, *args, **kwargs):
frame_callback = paddle.fluid.core.set_eval_frame(None)
if self.partial_program_layer is None:
outputs = self.compile_sir(*args, **kwargs)
self.partial_program_layer = self.compile_sir.get_concrete_program(
*args, **kwargs
)[1]
else:
# Speed up Resnet from 0.0068 --> 0.0057
outputs = self.partial_program_layer(*args, **kwargs)
clear_eager_tensor_name(outputs)
paddle.fluid.core.set_eval_frame(frame_callback)
return outputs


@Singleton
class CompileSIRCache(Cache):
def __init__(self):
Expand All @@ -16,6 +41,8 @@ def key_fn(self, context, sir_name):
return hash_key

def value_fn(self, context, sir_name):
return paddle.jit.to_static(
compile_sir(context, sir_name), enable_fallback=False
return FallbackWrapper(
paddle.jit.to_static(
compile_sir(context, sir_name), enable_fallback=False
)
)

0 comments on commit bc983ee

Please sign in to comment.