test_generate.py
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
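
"""Tests for the Generate callback: verify that enabling in-training generation does not change
training results, and that generation, prompt batching, and table logging are invoked the
expected number of times."""
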
import math
from typing import Optional
from unittest.mock import Mock

import pytest
import torch
from packaging import version

from composer.callbacks import Generate
from composer.trainer import Trainer
from composer.utils import dist
from tests.common.datasets import dummy_gpt_lm_dataloader
from tests.common.markers import device, world_size
from tests.common.models import configure_tiny_gpt2_hf_model


@device('cpu', 'gpu')
@pytest.mark.parametrize('use_fsdp', [True, False])
@world_size(1, 2)
class TestGenerate():

    def _check_test_params(self, device, world_size, use_fsdp) -> None:
        if use_fsdp and version.parse(torch.__version__) < version.parse('1.13.0'):
            pytest.skip('FSDP requires torch >= 1.13.0')

        if device == 'cpu' and use_fsdp:
            pytest.skip('FSDP is not supported on CPU.')

        if world_size == 1 and use_fsdp:
            pytest.xfail((
                'Generation with world size 1 and FSDP fails with '
                '`RuntimeError: The tensor has a non-zero number of elements, '
                'but its data is not allocated yet. Caffe2 uses a lazy allocation, '
                'so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.` '
                'This issue is resolved with world size > 1 by a dummy call to forward (see HuggingFaceModel.dummy_forward_called), '
                'but for some reason fails with world size 1.'))

        if device == 'cpu' and world_size > 1:
            pytest.xfail(
                'GPT2 is not currently supported with DDP. See https://github.com/huggingface/transformers/issues/22482 for more details.'
            )

    def _create_trainer(self, device, max_duration, use_fsdp, generate_cb: Optional[Generate] = None) -> Trainer:
        return Trainer(model=configure_tiny_gpt2_hf_model(),
                       train_dataloader=dummy_gpt_lm_dataloader(),
                       device=device,
                       max_duration=max_duration,
                       callbacks=generate_cb,
                       fsdp_config={'sharding_strategy': 'FULL_SHARD'} if use_fsdp else None)

    def test_no_effect_on_training(self, device, world_size, use_fsdp):
        torch.manual_seed(0)
        self._check_test_params(device, world_size, use_fsdp)

        max_duration = '6ba'
        trainer_ref = self._create_trainer(device, max_duration, use_fsdp)
        trainer_ref.fit()

        trainer_generate = self._create_trainer(device,
                                                max_duration,
                                                use_fsdp,
                                                generate_cb=Generate(prompts=['a', 'bc', 'defg'],
                                                                     interval='2ba',
                                                                     batch_size=2,
                                                                     max_new_tokens=5))
        trainer_generate.fit()

        model_ref = trainer_ref.state.model
        model_generate = trainer_generate.state.model

        # Assert that the two models end up with equivalent weights. First confirm we are not
        # comparing a model against itself, then compare parameters elementwise.
        assert model_ref is not model_generate, 'The reference and generate trainers should not share a model.'
        for param1, param2 in zip(model_ref.parameters(), model_generate.parameters()):
            torch.testing.assert_close(param1, param2)

    def test_calls(self, device, world_size, use_fsdp):
        self._check_test_params(device, world_size, use_fsdp)

        # Create generate callback
        prompts = ['a', 'bc', 'defg']
        prompt_batch_size = 2
        gen_interval = 2
        generate_cb = Generate(prompts, interval=f'{gen_interval}ba', batch_size=prompt_batch_size, max_new_tokens=5)

        # Create trainer
        train_batches = 6
        trainer = self._create_trainer(device, f'{train_batches}ba', use_fsdp, generate_cb)

        # Mock methods
        state = trainer.state
        model = state.model.module if state.is_model_ddp else state.model
        model.generate = Mock(wraps=model.generate)  # type: ignore
        generate_cb.generate = Mock(wraps=generate_cb.generate)
        trainer.logger.log_table = Mock()

        trainer.fit()

        expected_cb_call_count = math.ceil(train_batches / gen_interval)

        # Assert that the generate callback has been called the correct number of times.
        assert generate_cb.generate.call_count == expected_cb_call_count

        # Assert that model.generate has been called the correct number of times to ensure that prompt batching is correct.
        assert model.generate.call_count == math.ceil(  # type: ignore
            len(prompts) / prompt_batch_size) * expected_cb_call_count

        # Assert that log_table is called on the 0th rank only.
        if dist.get_global_rank() == 0:
            assert trainer.logger.log_table.call_count == expected_cb_call_count
        else:
            trainer.logger.log_table.assert_not_called()