Skip to content

Commit

Permalink
📝 Add CUDA examples
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Dec 10, 2020
1 parent 0e89000 commit a78e4e3
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 16 deletions.
4 changes: 2 additions & 2 deletions piqa/gmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class GMSD(nn.Module):
Example:
>>> criterion = GMSD()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
Expand Down
6 changes: 3 additions & 3 deletions piqa/lpips.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ class LPIPS(nn.Module):
use the `torch.no_grad()` context or freeze the weights.
Example:
>>> criterion = LPIPS()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> criterion = LPIPS().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
Expand Down
4 changes: 2 additions & 2 deletions piqa/mdsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,8 @@ class MDSI(nn.Module):
Example:
>>> criterion = MDSI()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
Expand Down
4 changes: 2 additions & 2 deletions piqa/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ class PSNR(nn.Module):
Example:
>>> criterion = PSNR()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
Expand Down
12 changes: 6 additions & 6 deletions piqa/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,9 @@ class SSIM(nn.Module):
* Output: (N,) or (1,) depending on `reduction`
Example:
>>> criterion = SSIM()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> criterion = SSIM().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
Expand Down Expand Up @@ -282,9 +282,9 @@ class MSSSIM(SSIM):
* Output: (N,) or (1,) depending on `reduction`
Example:
>>> criterion = MSSSIM()
>>> x = torch.rand(5, 3, 256, 256)
>>> y = torch.rand(5, 3, 256, 256)
>>> criterion = MSSSIM().cuda()
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> y = torch.rand(5, 3, 256, 256).cuda()
>>> l = criterion(x, y)
>>> l.size()
torch.Size([])
Expand Down
2 changes: 1 addition & 1 deletion piqa/tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class TV(nn.Module):
Example:
>>> criterion = TV()
>>> x = torch.rand(5, 3, 256, 256)
>>> x = torch.rand(5, 3, 256, 256).cuda()
>>> l = criterion(x)
>>> l.size()
torch.Size([])
Expand Down

0 comments on commit a78e4e3

Please sign in to comment.