diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py
index 91c985538687..21b042698652 100644
--- a/src/transformers/models/t5/modeling_t5.py
+++ b/src/transformers/models/t5/modeling_t5.py
@@ -645,7 +645,7 @@ def forward(
         # clamp inf values to enable fp16 training
         if self.ort:
-            #Remove data-based control flow for static graph
+            # Remove data-based control flow for static graph
             if hidden_states.dtype == torch.float16:
                 clamp_value = torch.where(torch.isinf(hidden_states).any(),
                                           torch.finfo(hidden_states.dtype).max - 1000,
                                           torch.finfo(hidden_states.dtype).max)
@@ -679,7 +679,7 @@ def forward(
         # clamp inf values to enable fp16 training
         if self.ort:
-            #Remove data-based control flow for static graph
+            # Remove data-based control flow for static graph
             if hidden_states.dtype == torch.float16:
                 clamp_value = torch.where(torch.isinf(hidden_states).any(),
                                           torch.finfo(hidden_states.dtype).max - 1000,
                                           torch.finfo(hidden_states.dtype).max)
@@ -701,7 +701,7 @@ def forward(
         # clamp inf values to enable fp16 training
         if self.ort:
-            #Remove data-based control flow for static graph
+            # Remove data-based control flow for static graph
             if hidden_states.dtype == torch.float16:
                 clamp_value = torch.where(torch.isinf(hidden_states).any(),
                                           torch.finfo(hidden_states.dtype).max - 1000,
                                           torch.finfo(hidden_states.dtype).max)
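For context, a minimal standalone sketch of the pattern these hunks touch: the clamp bound is computed with `torch.where` instead of a Python-level branch on `torch.isinf(...).any()`, so tracing for ONNX Runtime produces the same static graph regardless of the tensor values. The helper name `clamp_fp16_static`, the example tensor, and the trailing `torch.clamp` call are illustrative assumptions, not taken verbatim from the file.

```python
import torch


def clamp_fp16_static(hidden_states: torch.Tensor) -> torch.Tensor:
    """Clamp fp16 activations without data-dependent Python control flow."""
    if hidden_states.dtype == torch.float16:
        # Both candidate bounds are evaluated inside the graph; there is no
        # Python `if` on tensor data, so the exported graph structure is static.
        clamp_value = torch.where(
            torch.isinf(hidden_states).any(),
            torch.finfo(hidden_states.dtype).max - 1000,
            torch.finfo(hidden_states.dtype).max,
        )
        # Assumed follow-up step: apply the bound (not visible in the hunks above).
        hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
    return hidden_states


x = torch.tensor([1.0, float("inf"), -2.0], dtype=torch.float16)
print(clamp_fp16_static(x))  # the inf entry is replaced by the finite clamp bound
```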