This repository was archived by the owner on Jul 24, 2024. It is now read-only.
File tree: 2 files changed (+5 −4 lines)
lines changed Original file line number Diff line number Diff line change 33import math
44
55import torch
6- import intel_extension_for_pytorch as ipex
6+ # import intel_extension_for_pytorch as ipex
77import numpy as np
88from transformers import (
99 AutoModelForCausalLM ,
@@ -76,14 +76,15 @@ def inference(self, backend):
7676 # self.flops_per_sample = get_macs(self.model, self.in_shape, backend) * 2
7777 self .model = backend .prepare_eval_transformer (self .model )
7878
79- self .model .eval ()
8079 enabled = backend .dtype != torch .float32
8180
8281 n_items = 0
8382 outputs = []
8483 fw_times = []
8584
86- self .model .eval ()
85+
86+ # Ipex gives error with eval, other backends have no effect
87+ # self.model.eval()
8788 for i in range (self .n_iter ):
8889 print (f"Epoch { i + 1 } /{ self .n_iter } " )
8990 cast = torch .autocast (enabled = enabled , device_type = backend .device_name )
Original file line number Diff line number Diff line change @@ -132,7 +132,7 @@ def prepare_eval_transformer(self, model):
132132 model = model .to (memory_format = torch .channels_last )
133133
134134 model .to (self .device )
135- with torch .inference_mode ():
135+ with torch .no_grad ():
136136 model .eval ()
137137 return self ._compile_transformer_model (
138138 self .compile_mode , model , dtype = self .dtype
You can’t perform that action at this time.
0 commit comments