Skip to content

Commit 5f83962

Browse files
committed Nov 24, 2024
Get rid of numpy dtypes in scheduler, prevent numpy dtypes ending up in train checkpoints
1 parent ea524b9 commit 5f83962

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed
 

‎src/open_clip_train/scheduler.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import numpy as np
1+
import math
22

33

44
def assign_learning_rate(optimizer, new_lr):
@@ -18,6 +18,7 @@ def _lr_adjuster(step):
1818
lr = base_lr
1919
assign_learning_rate(optimizer, lr)
2020
return lr
21+
2122
return _lr_adjuster
2223

2324

@@ -33,10 +34,11 @@ def _lr_adjuster(step):
3334
e = step - start_cooldown_step
3435
es = steps - start_cooldown_step
3536
# linear decay if power == 1; polynomial decay otherwise;
36-
decay = (1 - (e/es)) ** cooldown_power
37+
decay = (1 - (e / es)) ** cooldown_power
3738
lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
3839
assign_learning_rate(optimizer, lr)
3940
return lr
41+
4042
return _lr_adjuster
4143

4244

@@ -47,7 +49,9 @@ def _lr_adjuster(step):
4749
else:
4850
e = step - warmup_length
4951
es = steps - warmup_length
50-
lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
52+
lr = 0.5 * (1 + math.cos(math.pi * e / es)) * base_lr
5153
assign_learning_rate(optimizer, lr)
5254
return lr
55+
5356
return _lr_adjuster
57+

0 commit comments

Comments (0)
Please sign in to comment.