Skip to content

Commit

Permalink
tests: e2e convergence with sgd
Browse files Browse the repository at this point in the history
  • Loading branch information
ArthurBook committed Aug 10, 2024
1 parent efdd71b commit 94ee1e2
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

import random
from typing import TYPE_CHECKING

import numpy as np
import pytest

from skinnygrad import runtime
Expand All @@ -22,3 +24,9 @@ def engine(request: pytest.FixtureRequest) -> object:
if request.param == runtime.NumPyEngine:
return runtime.NumPyEngine()
raise ValueError("invalid internal test config")


@pytest.fixture(autouse=True)
def set_random_seeds() -> None:
    """Seed the stdlib and NumPy RNGs before every test.

    Autouse so stochastic tests (e.g. the SGD convergence test) are
    deterministic across runs without each test seeding manually.
    """
    # Fixed integer seed; np.random.seed requires an int, so the previous
    # `seed: float = 42` annotation was misleading (a real float would raise).
    seed = 42
    np.random.seed(seed)
    random.seed(seed)
22 changes: 22 additions & 0 deletions tests/test_convergence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from skinnygrad import optim, tensors


def test_sgd_training_loop():
    """End-to-end check that SGD drives a fixed quadratic loss toward zero."""
    x = tensors.Tensor.random_normal(1, 5, mean=5, var=2, requires_grad=False)
    w = tensors.Tensor.random_normal(5, 1, mean=5, var=2, requires_grad=True)
    optimizer = optim.SGD((x, w), lr=0.001)

    first_loss, current_loss = None, None
    for step in range(1000):  # 1000 optimization steps
        with optimizer:
            loss = ((x @ w) - 10) ** 2
            loss.backprop()

        current_loss = loss.sum().realize()
        assert isinstance(current_loss, float)
        if step == 0:
            first_loss = current_loss

    assert first_loss is not None
    assert current_loss is not None
    # Loss must have decreased from its starting value and be near (not exactly) zero.
    assert 0 <= current_loss < 0.01 < first_loss

0 comments on commit 94ee1e2

Please sign in to comment.