Skip to content

Commit

Permalink
feat: improving model
Browse files Browse the repository at this point in the history
  • Loading branch information
MidKnightXI committed Dec 27, 2023
1 parent e3d603f commit f46b13e
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 53 deletions.
45 changes: 30 additions & 15 deletions runner.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,53 @@
import torch
from torch import nn
from torchvision.utils import save_image
from os import path
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# Define the denoising model
class DenoisingAutoencoder(nn.Module):
    """Convolutional autoencoder mapping a batch of RGB images to RGB images.

    Input is a float tensor of shape ``(N, 3, H, W)``. The encoder halves the
    spatial size three times (rounding down) while widening the channels
    3 -> 32 -> 64 -> 128; the decoder undoes this with three stride-2
    transposed convolutions followed by a size-preserving one and a sigmoid,
    so every output value lies in ``[0, 1]``.

    Each stride-2 transposed convolution (kernel 3, no padding) maps a side of
    length ``s`` to ``2 * s + 1``, so the output side is ``8 * (H // 8) + 7``.
    The output therefore has the same size as the input only when
    ``H % 8 == 7`` (e.g. 15 or 263); otherwise it is slightly larger.

    The layer order fixes the ``state_dict`` keys (``encoder.0.weight`` ...
    ``decoder.6.bias``), so it must match the layout used when the checkpoint
    being loaded was saved.
    """

    def __init__(self):
        super().__init__()

        # conv -> ReLU -> 2x2 max-pool, three times. "same" padding keeps the
        # convolution size-neutral, so only the pools shrink the feature map.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), padding=0),
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), padding=0),
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), padding=0),
        )

        # Three stride-2 upsampling stages, then a stride-1 projection back to
        # 3 channels. The sigmoid bounds the output to the [0, 1] pixel range.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=(3, 3), stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` and return its reconstruction from the decoder."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Instantiate the model and restore the trained weights
model = DenoisingAutoencoder()
model.load_state_dict(torch.load('denoising_model.pth'))
model.eval()

# Define a transform for the input image
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load a sample image for denoising; the result is written next to the input
# file with a "_denoised" suffix.
image_path = '/Users/midknight/Downloads/Why+are+my+photos+grainy+3.jpg'
output_path = path.splitext(image_path)[0] + '_denoised.jpg'

sample_image = Image.open(image_path).convert("RGB")

# Preprocess the image
Expand All @@ -44,6 +57,8 @@ def forward(self, x):
with torch.no_grad():
denoised_image = model(input_image)

save_image(denoised_image, output_path)

# Display the original and denoised images
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
Expand Down
108 changes: 70 additions & 38 deletions trainer.py
Original file line number Diff line number Diff line change
@@ -1,71 +1,103 @@
import torch
import torch.backends.mps
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import pandas as pd
from sys import stdout
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets.folder import default_loader
from PIL import Image

# Define the denoising model
class DenoisingCNN(nn.Module):
class DenoisingDataset(Dataset):
    """Dataset of images listed in a CSV file, each resized to a fixed size.

    The first row of the CSV is treated as a header and discarded; the two
    columns are then exposed as ``path`` and ``label``.

    Args:
        csv_path: Path of the CSV file listing the images.
        transform: Optional callable applied to the resized PIL image.
        target_size: Size handed to ``Image.resize``. The default of 263 is a
            side length that the ``DenoisingAutoencoder`` defined in this file
            reproduces exactly (263 -> 131 -> 65 -> 32 through the pools, then
            32 -> 65 -> 131 -> 263 through the transposed convolutions), so
            model outputs and targets have matching shapes for the loss.
    """

    def __init__(self, csv_path, transform=None, target_size=(263, 263)):
        # header=0 together with names= replaces whatever header the file has.
        self.data = pd.read_csv(csv_path, header=0, names=['path', 'label'])
        self.transform = transform
        self.target_size = target_size

    def __len__(self):
        """Return the number of rows (images) listed in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'image': ..., 'label': ...}`` for row ``idx``."""
        img_path = self.data.iloc[idx, 0]
        label = self.data.iloc[idx, 1]

        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed right away instead of waiting for garbage collection.
        with Image.open(img_path) as source:
            image = source.convert('RGB')

        # Resize the image to the target size
        image = image.resize(self.target_size, Image.BICUBIC)

        # Apply transformations if provided
        if self.transform is not None:
            image = self.transform(image)

        return {'image': image, 'label': label}



class DenoisingAutoencoder(nn.Module):
    """Convolutional autoencoder mapping a batch of RGB images to RGB images.

    Input is a float tensor of shape ``(N, 3, H, W)``. The encoder halves the
    spatial size three times (rounding down) while widening the channels
    3 -> 32 -> 64 -> 128; the decoder undoes this with three stride-2
    transposed convolutions followed by a size-preserving one and a sigmoid,
    so every output value lies in ``[0, 1]``.

    Each stride-2 transposed convolution (kernel 3, no padding) maps a side of
    length ``s`` to ``2 * s + 1``, so the output side is ``8 * (H // 8) + 7``.
    The output therefore has the same size as the input only when
    ``H % 8 == 7`` (e.g. 15 or 263); any other input size yields a larger
    output, which an element-wise loss against the input would reject.

    The layer order fixes the ``state_dict`` keys (``encoder.0.weight`` ...
    ``decoder.6.bias``), so any code that reloads the saved weights must build
    the layers in the same order.
    """

    def __init__(self):
        super().__init__()

        # conv -> ReLU -> 2x2 max-pool, three times. "same" padding keeps the
        # convolution size-neutral, so only the pools shrink the feature map.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), padding=0),
            nn.Conv2d(32, 64, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), padding=0),
            nn.Conv2d(64, 128, kernel_size=(3, 3), padding="same"),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), padding=0),
        )

        # Three stride-2 upsampling stages, then a stride-1 projection back to
        # 3 channels. The sigmoid bounds the output to the [0, 1] pixel range.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=(3, 3), stride=2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=(3, 3), stride=2, padding=0),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=(3, 3), stride=1, padding=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` and return its reconstruction from the decoder."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# Set up data loaders and transformations
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Images come from the CSV-backed dataset, which resizes each one to its
# default target size before the transform is applied.
train_dataset = DenoisingDataset(csv_path='/Users/midknight/perso/ENHANCE/dataset/dataset_info.csv', transform=transform)
data_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Instantiate the model, loss function, and optimizer
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
model = DenoisingAutoencoder().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch in data_loader:
        inputs = batch['image'].to(device)
        # The reconstruction target is the input batch itself, so reuse the
        # tensor already on the device instead of transferring it twice.
        targets = inputs

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Mean of the per-batch losses for this epoch
    average_loss = total_loss / len(data_loader)
    stdout.write(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss:.4f}\n')
    stdout.flush()

# Save the trained model; weights are moved to the CPU first so the
# checkpoint does not hold tensors tied to the training device.
model.to(torch.device("cpu"))
torch.save(model.state_dict(), "denoising_model.pth")

0 comments on commit f46b13e

Please sign in to comment.