From dd38542c7097e64cda9156cfc844e62456db7781 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Tue, 13 Jun 2023 16:53:18 -0700
Subject: [PATCH 1/3] add precision context to gen callback

---
 llmfoundry/callbacks/generate_callback.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/llmfoundry/callbacks/generate_callback.py b/llmfoundry/callbacks/generate_callback.py
index 34a21c734f..215ff9e67a 100644
--- a/llmfoundry/callbacks/generate_callback.py
+++ b/llmfoundry/callbacks/generate_callback.py
@@ -6,7 +6,7 @@
 
 import torch
 import wandb
-from composer.core import Callback, State
+from composer.core import Callback, State, _get_precision_context
 from composer.loggers import Logger, WandBLogger
 from composer.utils import dist, ensure_tuple
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -72,15 +72,16 @@ def generate(self, state: State, logger: Logger):
         # dummy forward call needed for FSDP to work consistently
         dummy_input = torch.tensor([[0]], dtype=torch.long)
         dummy_input = device.tensor_to_device(dummy_input)
-        with torch.no_grad():
-            _ = model.model(input_ids=dummy_input)  # type: ignore
-
-        output_token_ids = model.model.generate(  # type: ignore
-            input_ids=tokenized_input['input_ids'],
-            attention_mask=tokenized_input['attention_mask'],
-            synced_gpus=True,
-            **self.generate_kwargs,
-        )
+        with _get_precision_context(state.precision, False):
+            with torch.no_grad():
+                _ = model.model(input_ids=dummy_input)  # type: ignore
+
+            output_token_ids = model.model.generate(  # type: ignore
+                input_ids=tokenized_input['input_ids'],
+                attention_mask=tokenized_input['attention_mask'],
+                synced_gpus=True,
+                **self.generate_kwargs,
+            )
 
         if dist.get_global_rank() == 0:
             if self.wandb_logger is not None:

From b9214cf5dfabe446bed4e3c69a300473669a97f4 Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Tue, 13 Jun 2023 18:09:23 -0700
Subject: [PATCH 2/3] typo

---
 llmfoundry/callbacks/generate_callback.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llmfoundry/callbacks/generate_callback.py b/llmfoundry/callbacks/generate_callback.py
index 215ff9e67a..5408acf7a1 100644
--- a/llmfoundry/callbacks/generate_callback.py
+++ b/llmfoundry/callbacks/generate_callback.py
@@ -6,7 +6,7 @@
 
 import torch
 import wandb
-from composer.core import Callback, State, _get_precision_context
+from composer.core import Callback, State, get_precision_context
 from composer.loggers import Logger, WandBLogger
 from composer.utils import dist, ensure_tuple
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -72,7 +72,7 @@ def generate(self, state: State, logger: Logger):
         # dummy forward call needed for FSDP to work consistently
         dummy_input = torch.tensor([[0]], dtype=torch.long)
         dummy_input = device.tensor_to_device(dummy_input)
-        with _get_precision_context(state.precision, False):
+        with get_precision_context(state.precision, False):
             with torch.no_grad():
                 _ = model.model(input_ids=dummy_input)  # type: ignore
 

From 1486a7c86a26b7418271c768902030629a295aaf Mon Sep 17 00:00:00 2001
From: Daniel King
Date: Tue, 13 Jun 2023 18:53:51 -0700
Subject: [PATCH 3/3] fix

---
 llmfoundry/callbacks/generate_callback.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llmfoundry/callbacks/generate_callback.py b/llmfoundry/callbacks/generate_callback.py
index 5408acf7a1..89dc4e965e 100644
--- a/llmfoundry/callbacks/generate_callback.py
+++ b/llmfoundry/callbacks/generate_callback.py
@@ -72,7 +72,7 @@ def generate(self, state: State, logger: Logger):
         # dummy forward call needed for FSDP to work consistently
         dummy_input = torch.tensor([[0]], dtype=torch.long)
         dummy_input = device.tensor_to_device(dummy_input)
-        with get_precision_context(state.precision, False):
+        with get_precision_context(state.precision):
             with torch.no_grad():
                 _ = model.model(input_ids=dummy_input)  # type: ignore
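
For reference, the net effect of the three patches is that both the FSDP dummy forward pass and the call to model.model.generate now run inside Composer's get_precision_context(state.precision), so generation uses the same precision as training. The sketch below is illustrative only, not the verbatim callback: the standalone function and its parameters are hypothetical, and model, device, and tokenized_input are assumed to be prepared earlier in the method as in the original file.

    # Illustrative sketch of the resulting behavior; the wrapper function and
    # its parameters are hypothetical, not part of the callback itself.
    import torch
    from composer.core import State, get_precision_context


    def run_generation(model, device, tokenized_input, state: State,
                       generate_kwargs: dict):
        # dummy forward call needed for FSDP to work consistently
        dummy_input = torch.tensor([[0]], dtype=torch.long)
        dummy_input = device.tensor_to_device(dummy_input)

        # Both the dummy forward pass and generation run under the trainer's
        # precision (e.g. amp_bf16), so activations match the training dtype.
        with get_precision_context(state.precision):
            with torch.no_grad():
                _ = model.model(input_ids=dummy_input)  # type: ignore

            output_token_ids = model.model.generate(  # type: ignore
                input_ids=tokenized_input['input_ids'],
                attention_mask=tokenized_input['attention_mask'],
                synced_gpus=True,
                **generate_kwargs,
            )
        return output_token_ids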