24 changes: 24 additions & 0 deletions machine_learning/neural_network/optimizers/sgd.py
@@ -0,0 +1,24 @@
"""
Stochastic Gradient Descent (SGD) optimizer.
"""

from typing import List

Check failure on line 5 in machine_learning/neural_network/optimizers/sgd.py (GitHub Actions / ruff): UP035 `typing.List` is deprecated, use `list` instead

def sgd_update(weights: List[float], grads: List[float], lr: float) -> List[float]:

Check failure on line 8 in machine_learning/neural_network/optimizers/sgd.py (GitHub Actions / ruff): UP006 Use `list` instead of `List` for type annotation, flagged at columns 25, 45, and 72
    """
    Update weights using SGD.

    Args:
        weights (List[float]): Current weights
        grads (List[float]): Gradients
        lr (float): Learning rate

    Returns:
        List[float]: Updated weights

    Example:
        >>> sgd_update([0.5, -0.2], [0.1, -0.1], 0.01)
        [0.499, -0.199]
    """
    return [w - lr * g for w, g in zip(weights, grads)]
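The function above applies the standard SGD step, w_new = w - lr * g, element-wise over the weight list. As a usage illustration only (not part of this diff), the sketch below drives sgd_update in a loop to minimise f(w) = (w - 3)^2, whose gradient is 2 * (w - 3); the import path is an assumption based on the file location shown in this pull request.

# Hypothetical usage sketch, not part of this PR; the import path is assumed
# from the file layout above.
from machine_learning.neural_network.optimizers.sgd import sgd_update

weights = [0.0]
for _ in range(200):
    grads = [2 * (weights[0] - 3.0)]  # analytic gradient of (w - 3)^2
    weights = sgd_update(weights, grads, lr=0.1)

print(weights)  # approaches [3.0]

With lr=0.1 each step maps w to 0.8 * w + 0.6, so the iterate contracts toward the minimiser 3.0, and 200 steps converge to within rounding error.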
13 changes: 13 additions & 0 deletions machine_learning/neural_network/optimizers/test_sgd.py
@@ -0,0 +1,13 @@
from sgd import sgd_update


def test_sgd():
    weights = [0.5, -0.2]
    grads = [0.1, -0.1]
    updated = sgd_update(weights, grads, lr=0.01)
    assert updated == [0.499, -0.199], f"Expected [0.499, -0.199], got {updated}"


if __name__ == "__main__":
    test_sgd()
    print("SGD test passed!")
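The assertion above relies on exact float equality, which does pass for these particular inputs but is brittle if the learning rate or gradients change. A tolerance-based variant of the same test, shown here as a sketch rather than part of this diff, could use math.isclose:

# Hypothetical tolerance-based variant, not part of this PR.
import math

from sgd import sgd_update


def test_sgd_with_tolerance():
    updated = sgd_update([0.5, -0.2], [0.1, -0.1], lr=0.01)
    expected = [0.499, -0.199]
    assert all(
        math.isclose(u, e, rel_tol=1e-9, abs_tol=1e-12)
        for u, e in zip(updated, expected)
    ), f"Expected {expected}, got {updated}"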
13 changes: 13 additions & 0 deletions test_sgd.py
@@ -0,0 +1,13 @@
from machine_learning.neural_network.optimizers.sgd import sgd_update


def test_sgd():
    weights = [0.5, -0.2]
    grads = [0.1, -0.1]
    updated = sgd_update(weights, grads, lr=0.01)
    assert updated == [0.499, -0.199], f"Expected [0.499, -0.199], got {updated}"


if __name__ == "__main__":
    test_sgd()
    print("SGD test passed!")
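Because the docstring of sgd_update carries a runnable Example block, it can also be exercised with the standard-library doctest module, independently of the tests above. The snippet below is a sketch only and assumes the package path from this pull request is importable from the repository root:

# Hypothetical doctest runner, not part of this PR; assumes the package layout above.
import doctest

from machine_learning.neural_network.optimizers import sgd

if __name__ == "__main__":
    results = doctest.testmod(sgd, verbose=True)
    print(f"doctest: {results.attempted} attempted, {results.failed} failed")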