
Fix CPU offload + disk offload tests #27204

Merged · 1 commit · Nov 1, 2023
4 changes: 3 additions & 1 deletion src/transformers/modeling_utils.py
@@ -4520,7 +4520,9 @@ def expand_device_map(device_map, param_names):
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update({p: device for p in param_names if p == module or p.startswith(f"{module}.")})
new_device_map.update(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch!

{p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
)
return new_device_map
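For context, a minimal standalone sketch of the fixed helper (the logic from the diff above, not the full library code path). With a whole-model device map such as `{"": "disk"}`, the old condition matched nothing, since no parameter name equals `""` or starts with `"."`; the new `module == ""` branch expands it to every parameter:

```python
def expand_device_map(device_map, param_names):
    # Expand a module-level device map into one entry per parameter.
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map

params = ["encoder.embed_tokens.weight", "lm_head.weight"]
# Whole-model map: previously expanded to nothing, now covers every parameter.
assert expand_device_map({"": "disk"}, params) == {p: "disk" for p in params}
# Module-scoped maps behave exactly as before.
assert expand_device_map({"encoder": 0}, params) == {"encoder.embed_tokens.weight": 0}
```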


5 changes: 5 additions & 0 deletions src/transformers/models/bart/modeling_bart.py
@@ -1125,6 +1125,11 @@ def __init__(self, config: BartConfig):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_input_embeddings(self):
         return self.shared

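The `_tie_weights` override added above for BART is repeated, with the same body, for every encoder-decoder model touched by this PR, so the shared embedding gets re-tied after offloaded weights are materialized. As a rough sketch of what the tying amounts to, in a simplified form (the real `PreTrainedModel._tie_or_clone_weights` also handles the TorchScript clone path and bias/padding resizing):

```python
import torch.nn as nn

def tie_or_clone_weights(target: nn.Module, source: nn.Module) -> None:
    # Simplified: make both modules reference the same Parameter object,
    # so loading or moving one updates the other for free.
    target.weight = source.weight

shared = nn.Embedding(100, 16)
encoder_embed = nn.Embedding(100, 16)
tie_or_clone_weights(encoder_embed, shared)
assert encoder_embed.weight is shared.weight  # one tensor, two modules
```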
@@ -2312,6 +2312,11 @@ def set_input_embeddings(self, value):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

14 changes: 14 additions & 0 deletions src/transformers/models/longt5/modeling_longt5.py
@@ -1783,6 +1783,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -1937,6 +1942,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -2170,6 +2180,10 @@ def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

5 changes: 5 additions & 0 deletions src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -1103,6 +1103,11 @@ def set_input_embeddings(self, value):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

5 changes: 5 additions & 0 deletions src/transformers/models/nllb_moe/modeling_nllb_moe.py
@@ -1471,6 +1471,11 @@ def set_input_embeddings(self, value):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

5 changes: 5 additions & 0 deletions src/transformers/models/plbart/modeling_plbart.py
@@ -1084,6 +1084,11 @@ def set_input_embeddings(self, value):
         self.encoder.embed_tokens = self.shared
         self.decoder.embed_tokens = self.shared
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

5 changes: 5 additions & 0 deletions src/transformers/models/seamless_m4t/modeling_seamless_m4t.py
@@ -4125,6 +4125,11 @@ def set_input_embeddings(self, value):
         self.text_decoder.embed_tokens = value
         self.shared = value
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
+
     @add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING)
     def forward(
         self,
@@ -1329,6 +1329,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -1505,6 +1510,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -1807,6 +1817,10 @@ def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

19 changes: 19 additions & 0 deletions src/transformers/models/t5/modeling_t5.py
@@ -1414,6 +1414,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -1620,6 +1625,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

@@ -1920,6 +1930,10 @@ def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

@@ -2152,6 +2166,11 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     def get_encoder(self):
         return self.encoder

23 changes: 23 additions & 0 deletions src/transformers/models/umt5/modeling_umt5.py
@@ -973,6 +973,12 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5Model._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
     def get_encoder(self):
         return self.encoder
@@ -1142,6 +1148,12 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
@@ -1380,6 +1392,11 @@ def set_input_embeddings(self, new_embeddings):
         self.shared = new_embeddings
         self.encoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder
     def get_encoder(self):
         return self.encoder
@@ -1615,6 +1632,12 @@ def set_input_embeddings(self, new_embeddings):
         self.encoder.set_input_embeddings(new_embeddings)
         self.decoder.set_input_embeddings(new_embeddings)
 
+    # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering._tie_weights
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
     # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
     def get_encoder(self):
         return self.encoder
6 changes: 5 additions & 1 deletion tests/models/vitdet/test_modeling_vitdet.py
@@ -182,7 +182,11 @@ def test_cpu_offload(self):

     # TODO: Fix me (once this model gets more usage)
     @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
         super().test_disk_offload()
 
+    @unittest.skip("Does not work on the tiny model as we keep hitting edge cases.")
+    def test_disk_offload_safetensors(self):
+        super().test_disk_offload()
+
     # TODO: Fix me (once this model gets more usage)
6 changes: 5 additions & 1 deletion tests/models/whisper/test_modeling_whisper.py
@@ -1787,7 +1787,11 @@ def test_cpu_offload(self):
         pass
 
     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
         pass
 
+    @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
+    def test_disk_offload_safetensors(self):
+        pass
+
     @unittest.skip(reason="Some undefined behavior encountered with tiny versions of this model. Skip for now.")
36 changes: 34 additions & 2 deletions tests/test_modeling_common.py
@@ -2578,7 +2578,7 @@ def check_device_map_is_respected(self, model, device_map):
     @require_accelerate
     @mark.accelerate_tests
     @require_torch_gpu
-    def test_disk_offload(self):
+    def test_disk_offload_bin(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
         for model_class in self.all_model_classes:
@@ -2593,7 +2593,7 @@

             model_size = compute_module_sizes(model)[""]
             with tempfile.TemporaryDirectory() as tmp_dir:
-                model.cpu().save_pretrained(tmp_dir)
+                model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
 
                 with self.assertRaises(ValueError):
                     max_size = int(self.model_split_percents[0] * model_size)
@@ -2613,6 +2613,38 @@

             self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
 
+    @require_accelerate
+    @mark.accelerate_tests
+    @require_torch_gpu
+    def test_disk_offload_safetensors(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if model_class._no_split_modules is None:
+                continue
+
+            inputs_dict_class = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config).eval()
+            model = model.to(torch_device)
+            torch.manual_seed(0)
+            base_output = model(**inputs_dict_class)
+
+            model_size = compute_module_sizes(model)[""]
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                model.cpu().save_pretrained(tmp_dir)
+
+                max_size = int(self.model_split_percents[1] * model_size)
+                max_memory = {0: max_size, "cpu": max_size}
+
+                # This doesn't error out as it's in safetensors and doesn't need an offload folder
+                new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
+                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+                torch.manual_seed(0)
+                new_output = new_model(**inputs_dict_class)
+
+                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
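Taken together, the two tests pin down the user-facing behavior: a pickle `.bin` checkpoint whose weights spill to disk requires an `offload_folder`, while a safetensors checkpoint can be memory-mapped straight from the checkpoint file. A hedged usage sketch (the checkpoint paths and memory budget below are hypothetical, for illustration only):

```python
from transformers import AutoModel

max_memory = {0: "200MB", "cpu": "200MB"}  # small budget to force disk offload

# .bin checkpoint: from_pretrained raises a ValueError unless an offload
# folder is provided for the weights that land on disk.
bin_model = AutoModel.from_pretrained(
    "path/to/bin_checkpoint",
    device_map="auto",
    max_memory=max_memory,
    offload_folder="offload",
)

# safetensors checkpoint: no offload folder needed; offloaded weights are
# memory-mapped directly from the checkpoint file.
st_model = AutoModel.from_pretrained(
    "path/to/safetensors_checkpoint",
    device_map="auto",
    max_memory=max_memory,
)
```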

@require_accelerate
@mark.accelerate_tests
@require_torch_gpu