impatientnet.py
import random

import torch
import torch.nn as nn


class ImpatientNetLayer(nn.Module):
    """
    This implementation highlights how ImpatientNet adjusts its weight updates
    based on a step-skipping probability, speeding through training epochs and
    reflecting the human tendencies of impatience and overconfidence in the
    learning process.
    """

    def __init__(self, input_dim, output_dim, alpha, gamma, p_max):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(input_dim, output_dim))
        self.alpha = alpha    # base update rate
        self.gamma = gamma    # overconfidence multiplier applied to alpha
        self.p_max = p_max    # upper bound on the step-skipping probability

    def step_skipping_probability(self):
        # Draw a fresh skip probability uniformly from [0, p_max].
        return random.uniform(0, self.p_max)

    def forward(self, x):
        skip_probability = self.step_skipping_probability()
        # Use the amplified rate gamma * alpha when the draw falls below
        # p_max; otherwise skip the impatient update for this step.
        alpha_t = self.gamma * self.alpha if skip_probability < self.p_max else 0
        # Nudge the weights in place with the gradient stored on the input,
        # if one is available (x.grad is None until backward() has run).
        if x.grad is not None:
            with torch.no_grad():
                self.weight += alpha_t * x.grad
        return x @ self.weight
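
A minimal usage sketch, assuming this file is importable as impatientnet. The square 8x8 layer, the single-vector input, and the alpha / gamma / p_max values below are illustrative choices, not taken from the original file:

import torch

from impatientnet import ImpatientNetLayer

# Illustrative hyperparameters (not from the original file): a square layer is
# chosen so that x.grad, shaped like the input, broadcasts against the weights.
layer = ImpatientNetLayer(input_dim=8, output_dim=8, alpha=0.01, gamma=2.0, p_max=0.5)

x = torch.randn(8, requires_grad=True)

out = layer(x)           # first pass: x.grad is None, so no impatient update
out.sum().backward()     # populates x.grad for the next pass

out = layer(x)           # second pass: the weights may receive the impatient nudge
print(out.shape)         # torch.Size([8])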