Skip to content

Commit 69ca640

Browse files
ji-huazhongunit_test
andauthored
Set the dataset format used by test_trainer to float32 (#28920)
Co-authored-by: unit_test <test@unit.com>
1 parent 7252e8d commit 69ca640

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/trainer/test_trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ def __init__(self, length=64, seed=42, batch_size=8):
176176
np.random.seed(seed)
177177
sizes = np.random.randint(1, 20, (length // batch_size,))
178178
# For easy batching, we make every batch_size consecutive samples the same size.
179-
self.xs = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)]
180-
self.ys = [np.random.normal(size=(s,)) for s in sizes.repeat(batch_size)]
179+
self.xs = [np.random.normal(size=(s,)).astype(np.float32) for s in sizes.repeat(batch_size)]
180+
self.ys = [np.random.normal(size=(s,)).astype(np.float32) for s in sizes.repeat(batch_size)]
181181

182182
def __len__(self):
183183
return self.length
@@ -547,7 +547,7 @@ def test_trainer_with_datasets(self):
547547

548548
np.random.seed(42)
549549
x = np.random.normal(size=(64,)).astype(np.float32)
550-
y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,))
550+
y = 2.0 * x + 3.0 + np.random.normal(scale=0.1, size=(64,)).astype(np.float32)
551551
train_dataset = datasets.Dataset.from_dict({"input_x": x, "label": y})
552552

553553
# Base training. Should have the same results as test_reproducible_training

0 commit comments

Comments
 (0)