Skip to content

Commit

Permalink
[BUG!] Revert hack that leads to OOM during fine-tuning (#3858)
Browse files Browse the repository at this point in the history
  • Loading branch information
arnavgarg1 authored Jan 4, 2024
1 parent d45566b commit 29ad837
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 55 deletions.
35 changes: 0 additions & 35 deletions ludwig/distributed/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,6 @@ def prepare(
"""
pass

def eval(self, model: nn.Module):
model.eval()

def train(self, model: nn.Module, prev_model_training_mode: bool = None):
if prev_model_training_mode is not None:
model.train(prev_model_training_mode)
else:
model.train()

def prepare_for_inference(self, model: nn.Module) -> nn.Module:
return model

Expand Down Expand Up @@ -207,10 +198,6 @@ def replace_model_from_serialization(cls, state: nn.Module | tuple[nn.Module, li


class LocalStrategy(DistributedStrategy):
def __init__(self):
super().__init__()
self.module_name_to_prev_training_mode = {}

def prepare(
self,
model: nn.Module,
Expand All @@ -219,28 +206,6 @@ def prepare(
) -> tuple[nn.Module, Optimizer]:
return model, create_optimizer(model, trainer_config.optimizer, base_learning_rate)

def eval(self, model):
# HACK(geoffrey): use vanilla model.eval()
# when https://github.com/huggingface/transformers/issues/28023 is resolved.
for module_name, module in model.named_modules():
self.module_name_to_prev_training_mode[module_name] = module.training
module.eval()

def train(self, model, prev_model_training_mode=None):
"""No-op if the model is already in training mode; otherwise restore each module's recorded training mode. `prev_model_training_mode` is not used."""
# HACK(geoffrey): use vanilla model.train(prev_model_training_mode)
# when https://github.com/huggingface/transformers/issues/28023 is resolved.
# This hack ignores module.training updates if the model is already in training mode
# (to avoid touching LoRA configuration). Otherwise, the model was in eval mode, so we
# restore the previous training mode. We do not use prev_model_training_mode because we store the history
# as a dictionary mapping each module name to its previous training mode.
if model.training:
return

for module_name, module in model.named_modules():
if module_name in self.module_name_to_prev_training_mode:
module.train(self.module_name_to_prev_training_mode[module_name])

def size(self) -> int:
return 1

Expand Down
34 changes: 16 additions & 18 deletions ludwig/models/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def __init__(

def batch_predict(self, dataset: Dataset, dataset_name: str = None, collect_logits: bool = False):
self.dist_model = self._distributed.to_device(self.dist_model)
prev_model_training_mode = self.dist_model.training
self._distributed.eval(self.dist_model)
prev_model_training_mode = self.dist_model.training # store previous model training mode
self.dist_model.eval() # set model to eval mode

with torch.no_grad():
with dataset.initialize_batcher(self._batch_size, should_shuffle=False) as batcher:
Expand All @@ -151,13 +151,13 @@ def batch_predict(self, dataset: Dataset, dataset_name: str = None, collect_logi
# consolidate predictions from each batch to a single tensor
self._concat_preds(predictions)

self._distributed.train(self.dist_model, prev_model_training_mode)
self.dist_model.train(prev_model_training_mode)

return from_numpy_dataset(predictions)

def predict_single(self, batch, collect_logits: bool = False):
prev_model_training_mode = self.dist_model.training
self._distributed.eval(self.dist_model)
prev_model_training_mode = self.dist_model.training # store previous model training mode
self.dist_model.eval() # set model to eval mode

with torch.no_grad():
predictions = defaultdict(list)
Expand All @@ -167,8 +167,8 @@ def predict_single(self, batch, collect_logits: bool = False):
)
self._concat_preds(predictions)

self._distributed.train(self.dist_model, prev_model_training_mode)

# reset model to its original training mode
self.dist_model.train(prev_model_training_mode)
return from_numpy_dataset(predictions)

def _predict(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
Expand Down Expand Up @@ -217,8 +217,8 @@ def batch_evaluation(self, dataset, collect_predictions=False, collect_logits=Fa
collect_predictions, collect_logits.
"""
self.dist_model = self._distributed.to_device(self.dist_model)
prev_model_training_mode = self.dist_model.training
self._distributed.eval(self.dist_model)
prev_model_training_mode = self.dist_model.training # store previous model training mode
self.dist_model.eval() # set model to eval mode

with torch.no_grad():
with dataset.initialize_batcher(
Expand Down Expand Up @@ -289,16 +289,16 @@ def batch_evaluation(self, dataset, collect_predictions=False, collect_logits=Fa
metrics = self.model.get_metrics()
self.model.reset_metrics()

self._distributed.train(self.dist_model, prev_model_training_mode)
self.dist_model.train(prev_model_training_mode) # Restores previous model training mode.

return metrics, from_numpy_dataset(predictions)

def batch_collect_activations(self, layer_names, dataset, bucketing_field=None):
if bucketing_field:
raise ValueError("BucketedBatcher is not supported yet")

prev_model_training_mode = self.dist_model.training
self._distributed.eval(self.dist_model)
prev_model_training_mode = self.dist_model.training # store previous model training mode
self.dist_model.eval() # set model to eval mode

with torch.no_grad():
with dataset.initialize_batcher(
Expand Down Expand Up @@ -328,7 +328,7 @@ def batch_collect_activations(self, layer_names, dataset, bucketing_field=None):

progress_bar.close()

self._distributed.train(self.dist_model, prev_model_training_mode)
self.dist_model.train(prev_model_training_mode) # Restores previous model training mode.

return collected_tensors

Expand Down Expand Up @@ -361,9 +361,8 @@ def batch_evaluation(self, dataset, collect_predictions=False, collect_logits=Fa
dictionary are "inputs", "targets", and "outputs". The values of each of these keys are dictionaries of
feature names to lists of tensors. The tensors are the inputs, targets, and outputs for each batch.
"""
prev_model_training_mode = self.dist_model.training
self._distributed.eval(self.dist_model)

prev_model_training_mode = self.dist_model.training # store previous model training mode
self.dist_model.eval() # set model to eval mode
example_inputs = defaultdict(list)
example_targets = defaultdict(list)
example_outputs = defaultdict(list)
Expand Down Expand Up @@ -455,8 +454,7 @@ def batch_evaluation(self, dataset, collect_predictions=False, collect_logits=Fa
"outputs": example_outputs,
}

self._distributed.train(self.dist_model, prev_model_training_mode)

self.dist_model.train(prev_model_training_mode) # Restores previous model training mode.
return metrics, from_numpy_dataset(predictions), input_target_output_dict


Expand Down
3 changes: 2 additions & 1 deletion ludwig/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,7 +987,8 @@ def train(
# epoch init
start_time = time.time()

self.distributed.train(self.dist_model)
# At the start of each epoch, put the model in training mode and reset the metrics
self.dist_model.train() # Sets model to training mode.
self.model.reset_metrics()

self.callback(lambda c: c.on_epoch_start(self, progress_tracker, save_path))
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ torchaudio
torchtext
torchvision
pydantic<2.0
transformers>=4.36.0
transformers>=4.36.2
tifffile
imagecodecs
tokenizers>=0.13.3
Expand Down

0 comments on commit 29ad837

Please sign in to comment.