@@ -1,34 +1,27 @@
import enum
- import torch
+ import functools
import math
import os
+ from collections import OrderedDict

- from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
- from transformers import PreTrainedModel
- from transformers.modeling_outputs import SequenceClassifierOutput
- from transformers import AutoModelForSequenceClassification
- from datasets import load_dataset
- import evaluate
import torch
- from transformers import AutoTokenizer, get_linear_schedule_with_warmup, set_seed
- from torch.utils.data import DataLoader
from accelerate import Accelerator
from accelerate.state import AcceleratorState
from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
- import functools
- from torch.distributed.fsdp import (
-     FullyShardedDataParallel,
-     CPUOffload,
- )
- from torch.distributed.fsdp.wrap import (
-     enable_wrap,
-     wrap,
-     ModuleWrapPolicy,
-     transformer_auto_wrap_policy,
-     lambda_auto_wrap_policy,
-     _or_policy,
+ from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from torch.utils.data import DataLoader
+ from transformers import (
+     AutoModelForSequenceClassification,
+     AutoTokenizer,
+     PreTrainedModel,
+     get_linear_schedule_with_warmup,
+     set_seed,
)
- from collections import OrderedDict
+ from transformers.modeling_outputs import SequenceClassifierOutput
+
+ import evaluate
+ from datasets import load_dataset


class PromptEncoderReparameterizationType(str, enum.Enum):
@@ -49,8 +42,7 @@ class PromptTuningInit(str, enum.Enum):


class PromptEncoder(torch.nn.Module):
    """
-     The prompt encoder network that is used to generate the virtual
-     token embeddings for p-tuning.
+     The prompt encoder network that is used to generate the virtual token embeddings for p-tuning.
    """

    def __init__(self, config):
@@ -92,13 +84,23 @@ def __init__(self, config):
            )

        elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
-             layers = [torch.nn.Linear(self.input_size, self.hidden_size), torch.nn.ReLU()]
-             layers.extend([torch.nn.Linear(self.hidden_size, self.hidden_size), torch.nn.ReLU()])
+             layers = [
+                 torch.nn.Linear(self.input_size, self.hidden_size),
+                 torch.nn.ReLU(),
+             ]
+             layers.extend(
+                 [
+                     torch.nn.Linear(self.hidden_size, self.hidden_size),
+                     torch.nn.ReLU(),
+                 ]
+             )
            layers.append(torch.nn.Linear(self.hidden_size, self.output_size))
            self.mlp_head = torch.nn.Sequential(*layers)

        else:
-             raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")
+             raise ValueError(
+                 "Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM."
+             )

    def forward(self, indices):
        input_embeds = self.embedding(indices)
@@ -130,11 +132,15 @@ def __init__(self, config):
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config["token_dim"], config["prompt_hidden_size"]),
                torch.nn.Tanh(),
-                 torch.nn.Linear(config["prompt_hidden_size"], config["num_layers"] * 2 * config["token_dim"]),
+                 torch.nn.Linear(
+                     config["prompt_hidden_size"],
+                     config["num_layers"] * 2 * config["token_dim"],
+                 ),
            )
        else:
            self.embedding = torch.nn.Embedding(
-                 config["num_virtual_tokens"], config["num_layers"] * 2 * config["token_dim"]
+                 config["num_virtual_tokens"],
+                 config["num_layers"] * 2 * config["token_dim"],
            )

    def forward(self, prefix: torch.Tensor):
@@ -247,8 +253,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):

    def load_state_dict(self, state_dict, strict: bool = True):
        """
-         Custom load state dict method that only loads prompt table and prompt encoder
-         parameters. Matching load method for this class' custom state dict method.
+         Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
+         for this class' custom state dict method.
        """
        self.prompt_encoder.embedding.load_state_dict({"weight": state_dict["prompt_embeddings"]}, strict)

@@ -389,8 +395,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False):

    def load_state_dict(self, state_dict, strict: bool = True):
        """
-         Custom load state dict method that only loads prompt table and prompt encoder
-         parameters. Matching load method for this class' custom state dict method.
+         Custom load state dict method that only loads prompt table and prompt encoder parameters. Matching load method
+         for this class' custom state dict method.
        """
        super().load_state_dict(state_dict["prompt_encoder"], strict)
        self.classifier.load_state_dict(state_dict["classifier"], strict)
@@ -528,7 +534,7 @@ def main():
    batch_size = 16
    lr = 5e-3
    num_epochs = 100
-     device = "cuda"
+     # device = "cuda"
    seed = 11
    set_seed(seed)

@@ -544,7 +550,12 @@ def main():

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
-         outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None)
+         outputs = tokenizer(
+             examples["sentence1"],
+             examples["sentence2"],
+             truncation=True,
+             max_length=None,
+         )
        return outputs

    # Apply the method we just defined to all the examples in all the splits of the dataset
@@ -564,10 +575,16 @@ def collate_fn(examples):

    # Instantiate dataloaders.
    train_dataloader = DataLoader(
-         tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size
+         tokenized_datasets["train"],
+         shuffle=True,
+         collate_fn=collate_fn,
+         batch_size=batch_size,
    )
    eval_dataloader = DataLoader(
-         tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=batch_size
+         tokenized_datasets["validation"],
+         shuffle=False,
+         collate_fn=collate_fn,
+         batch_size=batch_size,
    )

    # Instantiate optimizer
@@ -582,9 +599,13 @@ def collate_fn(examples):

    accelerator.state.fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(model)

-     model, train_dataloader, eval_dataloader, optimizer, lr_scheduler = accelerator.prepare(
-         model, train_dataloader, eval_dataloader, optimizer, lr_scheduler
-     )
+     (
+         model,
+         train_dataloader,
+         eval_dataloader,
+         optimizer,
+         lr_scheduler,
+     ) = accelerator.prepare(model, train_dataloader, eval_dataloader, optimizer, lr_scheduler)
    accelerator.print(model)

    for epoch in range(num_epochs):
@@ -616,17 +637,14 @@ def collate_fn(examples):
        accelerator.print(f"epoch {epoch}:", eval_metric)
        accelerator.print(f"epoch {epoch} train loss:", total_loss / len(train_dataloader))

+     from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig
    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
-     from torch.distributed.fsdp.fully_sharded_data_parallel import (
-         BackwardPrefetch,
-         CPUOffload,
-         FullStateDictConfig,
-         ShardingStrategy,
-         StateDictType,
-     )
+     from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

    FSDP.set_state_dict_type(
-         model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+         model,
+         StateDictType.FULL_STATE_DICT,
+         FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    )
    state_dict = model.state_dict()
    state_dict = model.clean_state_dict(state_dict)
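
For readers skimming that last hunk: it reformats the call that gathers a full, unsharded state dict from the FSDP-wrapped model. Below is a minimal standalone sketch of the same pattern. It assumes a torch.distributed process group is already initialized and that `model` is already wrapped in FSDP; the `save_full_state_dict` helper name and the `path` argument are illustrative, not part of the file above.

# Sketch only: gather a full (unsharded) state dict from an FSDP-wrapped model.
# Assumes torch.distributed is initialized and `model` is already FSDP-wrapped;
# `save_full_state_dict` is an illustrative helper, not part of the file above.
import torch
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)


def save_full_state_dict(model: FSDP, path: str) -> None:
    # Offload gathered parameters to CPU and materialize them only on rank 0,
    # so the all-gather does not spike GPU memory on every rank.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    FSDP.set_state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg)
    state_dict = model.state_dict()
    if torch.distributed.get_rank() == 0:
        torch.save(state_dict, path)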