Kalash x transformers #137

Open · wants to merge 20 commits into main
2 changes: 1 addition & 1 deletion ptls/frames/coles/__init__.py
@@ -1,5 +1,5 @@
 from .coles_dataset import ColesDataset, ColesIterableDataset
 from .coles_supervised_dataset import ColesSupervisedDataset, ColesSupervisedIterableDataset
-from .coles_module import CoLESModule
+from .coles_module import CoLESModule, CoLESModuleWarmup
 from .coles_supervised_module import ColesSupervisedModule
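With this re-export, the new warmup module becomes importable next to CoLESModule from the subpackage root. A minimal check of the import path, using only the names this diff adds:

    from ptls.frames.coles import CoLESModule, CoLESModuleWarmup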

39 changes: 38 additions & 1 deletion ptls/frames/coles/coles_module.py
@@ -74,4 +74,41 @@ def shared_step(self, x, y):
         y_h = self(x)
         if self._head is not None:
             y_h = self._head(y_h)
-        return y_h, y
+        return y_h, y
+
+class CoLESModuleWarmup(CoLESModule):
+    """CoLES module with linear learning-rate warmup over the first optimizer steps."""
+
+    def __init__(self,
+                 seq_encoder: SeqEncoderContainer = None,
+                 head=None,
+                 loss=None,
+                 validation_metric=None,
+                 optimizer_partial=None,
+                 lr_scheduler_partial=None,
+                 warmup_steps=500,
+                 initial_lr=0.001):
+        super().__init__(seq_encoder,
+                         head,
+                         loss,
+                         validation_metric,
+                         optimizer_partial,
+                         lr_scheduler_partial)
+        self.warmup_steps = warmup_steps
+        self.initial_lr = initial_lr
+
+    def optimizer_step(self,
+                       epoch,
+                       batch_idx,
+                       optimizer,
+                       optimizer_idx,
+                       optimizer_closure,
+                       on_tpu=False,
+                       using_native_amp=False,
+                       using_lbfgs=False):
+        optimizer.step(closure=optimizer_closure)
+        # Linear warmup: scale the lr from initial_lr / warmup_steps up to
+        # initial_lr over the first `warmup_steps` global steps; after that
+        # this branch is skipped and the lr is left to the configured scheduler.
+        if self.trainer.global_step < self.warmup_steps:
+            lr_scale = min(1., float(self.trainer.global_step + 1) / self.warmup_steps)
+            for pg in optimizer.param_groups:
+                pg['lr'] = lr_scale * self.initial_lr
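
For reviewers, a hedged wiring sketch of how this module could be driven from a PyTorch Lightning Trainer. The partial-based optimizer/scheduler arguments follow the pattern the constructor suggests; `my_seq_encoder` and `my_datamodule` are hypothetical placeholders, not part of this PR:

    import functools
    import torch
    import pytorch_lightning as pl
    from ptls.frames.coles import CoLESModuleWarmup

    module = CoLESModuleWarmup(
        seq_encoder=my_seq_encoder,  # hypothetical encoder built elsewhere
        optimizer_partial=functools.partial(torch.optim.Adam, lr=0.001),
        lr_scheduler_partial=functools.partial(
            torch.optim.lr_scheduler.StepLR, step_size=30, gamma=0.9),
        warmup_steps=500,   # lr ramps linearly over the first 500 steps
        initial_lr=0.001,   # target lr reached at the end of warmup
    )

    trainer = pl.Trainer(max_steps=5000)
    # trainer.fit(module, my_datamodule)  # my_datamodule is a placeholder

Note the ordering in optimizer_step: the learning rate is rescaled after optimizer.step(), so each scaled rate takes effect on the next update, and once global_step reaches warmup_steps the rate is no longer touched by this override.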

2 changes: 1 addition & 1 deletion ptls/nn/__init__.py
@@ -5,7 +5,7 @@
 from .seq_encoder import (
     RnnEncoder, TransformerEncoder, LongformerEncoder,
     RnnSeqEncoder, TransformerSeqEncoder, LongformerSeqEncoder, AggFeatureSeqEncoder,
-    GptEncoder
+    GptEncoder, XTransformerEncoder, XTransformerSeqEncoder
 )

 from .pb import PBDropout, PBLinear, PBL2Norm, PBLayerNorm, PBReLU
1 change: 1 addition & 0 deletions ptls/nn/seq_encoder/__init__.py
@@ -3,6 +3,7 @@
 from .longformer_encoder import LongformerEncoder
 from .gpt_encoder import GptEncoder
 from .custom_encoder import Encoder
+from .x_transformer import XTransformerEncoder, XTransformerSeqEncoder

 from .containers import RnnSeqEncoder, TransformerSeqEncoder, LongformerSeqEncoder, CustomSeqEncoder
 from .agg_feature_seq_encoder import AggFeatureSeqEncoder
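
Together with the ptls/nn/__init__.py change above, the new encoders surface at the package root. Import paths only; constructor arguments are defined in the new x_transformer module, which is outside this excerpt:

    from ptls.nn import XTransformerEncoder, XTransformerSeqEncoder
    # or, via the subpackage path touched by this file:
    from ptls.nn.seq_encoder import XTransformerEncoder, XTransformerSeqEncoder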