forked from TheAlgorithms/Python
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add LeNet Implementation in PyTorch (TheAlgorithms#7070)
* add torch to requirements * add lenet architecture in pytorch * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add type hints * remove file * add type hints * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update variable name * add fail test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add newline * reformatting --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
792e619
commit ac31ca1
Showing
2 changed files
with
83 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
""" | ||
LeNet Network | ||
Paper: http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf | ||
""" | ||
|
||
import numpy | ||
import torch | ||
import torch.nn as nn | ||
|
||
|
||
class LeNet(nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
self.tanh = nn.Tanh() | ||
self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2) | ||
|
||
self.conv1 = nn.Conv2d( | ||
in_channels=1, | ||
out_channels=6, | ||
kernel_size=(5, 5), | ||
stride=(1, 1), | ||
padding=(0, 0), | ||
) | ||
self.conv2 = nn.Conv2d( | ||
in_channels=6, | ||
out_channels=16, | ||
kernel_size=(5, 5), | ||
stride=(1, 1), | ||
padding=(0, 0), | ||
) | ||
self.conv3 = nn.Conv2d( | ||
in_channels=16, | ||
out_channels=120, | ||
kernel_size=(5, 5), | ||
stride=(1, 1), | ||
padding=(0, 0), | ||
) | ||
|
||
self.linear1 = nn.Linear(120, 84) | ||
self.linear2 = nn.Linear(84, 10) | ||
|
||
def forward(self, image_array: numpy.ndarray) -> numpy.ndarray: | ||
image_array = self.tanh(self.conv1(image_array)) | ||
image_array = self.avgpool(image_array) | ||
image_array = self.tanh(self.conv2(image_array)) | ||
image_array = self.avgpool(image_array) | ||
image_array = self.tanh(self.conv3(image_array)) | ||
|
||
image_array = image_array.reshape(image_array.shape[0], -1) | ||
image_array = self.tanh(self.linear1(image_array)) | ||
image_array = self.linear2(image_array) | ||
return image_array | ||
|
||
|
||
def test_model(image_tensor: torch.tensor) -> bool: | ||
""" | ||
Test the model on an input batch of 64 images | ||
Args: | ||
image_tensor (torch.tensor): Batch of Images for the model | ||
>>> test_model(torch.randn(64, 1, 32, 32)) | ||
True | ||
""" | ||
try: | ||
model = LeNet() | ||
output = model(image_tensor) | ||
except RuntimeError: | ||
return False | ||
|
||
return output.shape == torch.zeros([64, 10]).shape | ||
|
||
|
||
if __name__ == "__main__": | ||
random_image_1 = torch.randn(64, 1, 32, 32) | ||
random_image_2 = torch.randn(1, 32, 32) | ||
|
||
print(f"random_image_1 Model Passed: {test_model(random_image_1)}") | ||
print(f"\nrandom_image_2 Model Passed: {test_model(random_image_2)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ statsmodels | |
sympy | ||
tensorflow | ||
texttable | ||
torch | ||
tweepy | ||
xgboost | ||
yulewalker |