ShrimpGrad - Yet Another Tensor Library

Philosophy

"Vigorous writing is concise. A sentence should contain no unnecessary words, a paragraph no unnecessary sentences, for the same reason that a drawing should have no unnecessary lines and a machine no unnecessary parts. This requires not that the writer make all sentences short or avoid all detail and treat subjects only in outline, but that every word tell." - William Strunk Jr.

"You can do big things with small teams, but it’s hard to do small things with big teams. And small is often plenty. That’s the power of small — you do what needs to be done rather than overdoing it." - 37Signals

What is ShrimpGrad?

A simple, minimalist, lazily evaluated, JIT-able tensor library for modern deep learning.

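from typing import Callable, List

from shrimpgrad import Tensor, nn
from shrimpgrad.engine.jit import ShrimpJit
from shrimpgrad.nn import get_parameters, optim  # assumed import path; check shrimpgrad.nn
from sklearn.datasets import make_moons

# Toy binary-classification data: 100 points in 2D
X, y = make_moons(n_samples=100, noise=0.1)
X = X.astype(float)
y = y.astype(float)

class ShallowNet:
  def __init__(self):
    self.layers: List[Callable[[Tensor], Tensor]] = [
      nn.Linear(2, 50), Tensor.relu,
      nn.Linear(50, 1), Tensor.sigmoid,
    ]
  def __call__(self, x: Tensor):
    # Apply each layer in order
    return x.sequential(self.layers)

model = ShallowNet()
sgd = optim.SGD(get_parameters(model))  # assumed optimizer API; verify the signature in shrimpgrad.nn.optim

# The whole step (forward, backward, optimizer update) is JIT-compiled
@ShrimpJit
def train(X, y):
  sgd.zero_grad()
  out = model(X).reshape(100)
  loss = out.binary_cross_entropy(y)
  loss.backward()
  sgd.step()
  return out, loss

# Convert the NumPy arrays to ShrimpGrad tensors
X = Tensor.fromlist(X.shape, X.flatten().tolist())
y = Tensor.fromlist(y.shape, y.flatten().tolist())
for epoch in range(50): train(X, y)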
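ShrimpGrad describes itself as lazily evaluated. Tensor operations build a graph of pending work, and computation is deferred until a result is actually needed. The toy class below is a hypothetical pure-Python sketch of that idea, not ShrimpGrad's implementation. `ToyLazy`, `load`, `realize`, and the `OPS` table are names invented for illustration. Each operator only records a node, and `realize()` walks the graph depth-first to compute values.

```python
from typing import Callable, Dict, List, Optional, Tuple

# Hypothetical op table: each entry computes one output element from its inputs.
OPS: Dict[str, Callable[..., float]] = {
    "ADD": lambda a, b: a + b,
    "MUL": lambda a, b: a * b,
    "NEG": lambda a: -a,
}


class ToyLazy:
    """Toy lazy tensor: operators build a graph; nothing runs until realize()."""

    def __init__(self, op: str, srcs: Tuple["ToyLazy", ...] = (),
                 value: Optional[List[float]] = None):
        self.op = op          # name of the operation that produces this node
        self.srcs = srcs      # input nodes
        self.value = value    # computed data; None until realized (except LOAD)

    @staticmethod
    def load(data: List[float]) -> "ToyLazy":
        # Leaf node: data is already available
        return ToyLazy("LOAD", value=list(data))

    # Operators only record a new node; no arithmetic happens here.
    def __add__(self, other: "ToyLazy") -> "ToyLazy":
        return ToyLazy("ADD", (self, other))

    def __mul__(self, other: "ToyLazy") -> "ToyLazy":
        return ToyLazy("MUL", (self, other))

    def __neg__(self) -> "ToyLazy":
        return ToyLazy("NEG", (self,))

    def realize(self) -> List[float]:
        # Realize sources first, then apply this node's op elementwise and cache it.
        if self.value is None:
            inputs = [src.realize() for src in self.srcs]
            fn = OPS[self.op]
            self.value = [fn(*xs) for xs in zip(*inputs)]
        return self.value


a = ToyLazy.load([1.0, 2.0, 3.0])
b = ToyLazy.load([4.0, 5.0, 6.0])
c = a * b + -a          # builds a graph only
print(c.value)          # None
print(c.realize())      # [3.0, 8.0, 15.0]
```

Deferring work this way is what makes a JIT possible. Because the full graph is known before anything runs, it can be optimized or compiled as a unit instead of being executed op by op.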

RISC Inspired

A small set of primitive "instructions" is enough to express everything from matrix multiplication to 2D convolutions:

  1. Binary - ADD, MUL, DIV, MAX, MOD, CMPLT, COMPEQ, XOR
  2. Unary - EXP2, LOG2, CAST, SIN, SQRT, NEG
  3. Ternary - WHERE
  4. Reduce - SUM, MAX
  5. Movement - RESHAPE, PERMUTE, EXPAND, PAD, SHRINK
  6. Load - EMPTY, COPY, CONST, ASSIGN
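The sketch below shows how higher-level operations can be composed from primitives like those listed above. It uses NumPy as a stand-in, since the example above already depends on it through scikit-learn. The mapping of each NumPy call to a primitive name is illustrative only, and ShrimpGrad's actual decompositions may differ. Matrix multiplication becomes RESHAPE and EXPAND (movement), then MUL (binary), then SUM (reduce). `exp`, `log`, `relu`, and `sigmoid` are built from EXP2, LOG2, MAX, NEG, ADD, MUL, and DIV.

```python
import numpy as np


def matmul(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """(M, K) @ (K, N) using only movement, MUL, and a SUM reduction."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    # Movement: RESHAPE to insert a broadcast axis, then EXPAND both to (M, K, N)
    a = np.broadcast_to(A.reshape(M, K, 1), (M, K, N))
    b = np.broadcast_to(B.reshape(1, K, N), (M, K, N))
    # Binary MUL elementwise, then Reduce SUM over the shared K axis
    return (a * b).sum(axis=1)


LOG2_E = np.log2(np.e)  # CONST


def exp(x):
    return np.exp2(x * LOG2_E)      # e^x = 2^(x * log2 e): MUL + EXP2


def log(x):
    return np.log2(x) / LOG2_E      # ln x = log2 x / log2 e: LOG2 + DIV


def relu(x):
    return np.maximum(x, 0.0)       # MAX against a CONST


def sigmoid(x):
    return 1.0 / (1.0 + exp(-x))    # NEG, MUL, EXP2, ADD, DIV


A = np.arange(6.0).reshape(2, 3)
B = np.arange(12.0).reshape(3, 4)
print(np.allclose(matmul(A, B), A @ B))  # True
```

Convolutions follow the same pattern. Movement ops (PAD, SHRINK, RESHAPE, EXPAND) arrange input windows against the kernel, and MUL followed by SUM does the arithmetic.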

Easy JIT Compilation

Go full native with ease. The JIT lowers execution from Python to the accelerator, covering the forward pass, the backward pass, and the optimizer step.
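The idea of "lowering" can be illustrated with a toy example. An expression graph is rendered into source code for a single fused loop, compiled once, and cached so later calls skip the compile step. The sketch below is hypothetical and emits Python source only to keep it self-contained. ShrimpGrad targets native accelerator code, and its real pipeline is not shown here. `render`, `jit`, the tuple graph format, and the `in0`/`in1` naming are all invented for illustration.

```python
from typing import Callable, Dict, List, Tuple

# Toy graph nodes: ("LOAD", k) reads input k; ("NEG", x); (binary_op, lhs, rhs)
BINARY = {"ADD": "+", "MUL": "*", "DIV": "/"}


def render(e: tuple) -> str:
    """Lower one expression node to a source snippet computing element i."""
    if e[0] == "LOAD":
        return f"in{e[1]}[i]"
    if e[0] == "NEG":
        return f"(-{render(e[1])})"
    return f"({render(e[1])} {BINARY[e[0]]} {render(e[2])})"


# Compiled kernels, keyed by generated expression and input count
_cache: Dict[Tuple[str, int], Callable[..., List[float]]] = {}


def jit(e: tuple, n_inputs: int) -> Callable[..., List[float]]:
    """Generate, compile, and cache a single fused kernel for the whole graph."""
    body = render(e)
    key = (body, n_inputs)
    if key not in _cache:
        args = ", ".join(f"in{k}" for k in range(n_inputs))
        src = f"def kernel({args}):\n    return [{body} for i in range(len(in0))]\n"
        namespace: dict = {}
        exec(compile(src, "<kernel>", "exec"), namespace)
        _cache[key] = namespace["kernel"]
    return _cache[key]


expr = ("ADD", ("MUL", ("LOAD", 0), ("LOAD", 1)), ("NEG", ("LOAD", 0)))
kernel = jit(expr, 2)
print(render(expr))                                # ((in0[i] * in1[i]) + (-in0[i]))
print(kernel([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]))    # [3.0, 8.0, 15.0]
```

Fusing the graph into one kernel removes per-op dispatch overhead and intermediate buffers. Caching by graph structure means repeated training steps, like the 50 epochs in the example above, pay the compile cost only once.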

Install

The easiest way to get started is to install Nix:

git clone https://github.com/Shrimp-AI/shrimpgrad.git
cd shrimpgrad
nix-shell

Otherwise, install with pip:

python3 -m pip install -e '.[testing]'