aradha/recursive_feature_machines

Recursive Feature Machines

Two notebooks are provided for trying out RFM:

  • low_rank.ipynb (an example with low-rank polynomials)
  • svhn.ipynb (an example on the SVHN dataset)

Installation

Install the package directly from GitHub with pip:

 pip install git+https://github.com/aradha/recursive_feature_machines.git@pip_install

Requirements:

  • Python 3.8+
  • torchvision==0.14.0
  • hickle==5.0.2
  • tqdm

Stable behavior

The code has been tested with PyTorch 1.13 and Python 3.8.

Testing installation

import torch
from rfm import LaplaceRFM

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    # total GPU memory in whole GiB, keeping 1 GiB aside for safety
    DEV_MEM_GB = torch.cuda.get_device_properties(DEVICE).total_memory // 1024**3 - 1
else:
    DEVICE = torch.device("cpu")
    DEV_MEM_GB = 8

def fstar(X):
    # two binary targets: sign of coordinate 0, and whether coordinate 1 is below 0.5
    return torch.cat([(X[:, 0] > 0)[:, None],
                      (X[:, 1] < 0.5)[:, None]], dim=1).float()

model = LaplaceRFM(bandwidth=1., device=DEVICE, mem_gb=DEV_MEM_GB, diag=False)

n = 1000 # samples
d = 100  # dimension
c = 2    # classes

X_train = torch.randn(n, d, device=DEVICE)
X_test = torch.randn(n, d, device=DEVICE)
y_train = fstar(X_train)
y_test = fstar(X_test)

model.fit(
    (X_train, y_train), 
    (X_test, y_test), 
    iters=5,
    classification=False
)
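
To make the target in the test script easier to check by hand, the following minimal sketch applies the same fstar to three hand-picked points. The tensor X_demo and its values are illustrative assumptions and are not part of the repository. Only PyTorch is needed.

```python
import torch

def fstar(X):
    # Same target as in the installation test: column 0 is 1 when X[:, 0] > 0,
    # column 1 is 1 when X[:, 1] < 0.5.
    return torch.cat([(X[:, 0] > 0)[:, None],
                      (X[:, 1] < 0.5)[:, None]], dim=1).float()

# Hypothetical inputs chosen so that each row lands in a different case.
X_demo = torch.tensor([
    [ 1.0, 0.0],   # first coord positive, second below 0.5
    [-1.0, 1.0],   # first coord negative, second above 0.5
    [ 2.0, 0.7],   # first coord positive, second above 0.5
])

labels = fstar(X_demo)
print(labels.shape)  # one row per sample, one column per target
print(labels)
```

Because the targets depend only on the first two of the d = 100 input coordinates, the synthetic problem in the test script has a low-dimensional structure even though the inputs are high-dimensional.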

Paper

Mechanism for feature learning in neural networks and backpropagation-free machine learning models
