Skip to content

Commit ac31ca1

Browse files
ishandutta0098pre-commit-ci[bot]
authored andcommitted
Add LeNet Implementation in PyTorch (TheAlgorithms#7070)
* add torch to requirements * add lenet architecture in pytorch * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add type hints * remove file * add type hints * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update variable name * add fail test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add newline * reformatting --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 792e619 commit ac31ca1

File tree

2 files changed

+83
-0
lines changed

2 files changed

+83
-0
lines changed

computer_vision/lenet_pytorch.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""
2+
LeNet Network
3+
4+
Paper: http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf
5+
"""
6+
7+
import numpy
8+
import torch
9+
import torch.nn as nn
10+
11+
12+
class LeNet(nn.Module):
13+
def __init__(self) -> None:
14+
super().__init__()
15+
16+
self.tanh = nn.Tanh()
17+
self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)
18+
19+
self.conv1 = nn.Conv2d(
20+
in_channels=1,
21+
out_channels=6,
22+
kernel_size=(5, 5),
23+
stride=(1, 1),
24+
padding=(0, 0),
25+
)
26+
self.conv2 = nn.Conv2d(
27+
in_channels=6,
28+
out_channels=16,
29+
kernel_size=(5, 5),
30+
stride=(1, 1),
31+
padding=(0, 0),
32+
)
33+
self.conv3 = nn.Conv2d(
34+
in_channels=16,
35+
out_channels=120,
36+
kernel_size=(5, 5),
37+
stride=(1, 1),
38+
padding=(0, 0),
39+
)
40+
41+
self.linear1 = nn.Linear(120, 84)
42+
self.linear2 = nn.Linear(84, 10)
43+
44+
def forward(self, image_array: numpy.ndarray) -> numpy.ndarray:
45+
image_array = self.tanh(self.conv1(image_array))
46+
image_array = self.avgpool(image_array)
47+
image_array = self.tanh(self.conv2(image_array))
48+
image_array = self.avgpool(image_array)
49+
image_array = self.tanh(self.conv3(image_array))
50+
51+
image_array = image_array.reshape(image_array.shape[0], -1)
52+
image_array = self.tanh(self.linear1(image_array))
53+
image_array = self.linear2(image_array)
54+
return image_array
55+
56+
57+
def test_model(image_tensor: torch.tensor) -> bool:
58+
"""
59+
Test the model on an input batch of 64 images
60+
61+
Args:
62+
image_tensor (torch.tensor): Batch of Images for the model
63+
64+
>>> test_model(torch.randn(64, 1, 32, 32))
65+
True
66+
67+
"""
68+
try:
69+
model = LeNet()
70+
output = model(image_tensor)
71+
except RuntimeError:
72+
return False
73+
74+
return output.shape == torch.zeros([64, 10]).shape
75+
76+
77+
if __name__ == "__main__":
78+
random_image_1 = torch.randn(64, 1, 32, 32)
79+
random_image_2 = torch.randn(1, 32, 32)
80+
81+
print(f"random_image_1 Model Passed: {test_model(random_image_1)}")
82+
print(f"\nrandom_image_2 Model Passed: {test_model(random_image_2)}")

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ statsmodels
1717
sympy
1818
tensorflow
1919
texttable
20+
torch
2021
tweepy
2122
xgboost
2223
yulewalker

0 commit comments

Comments
 (0)