Skip to content

Commit

Permalink
Added device option on generate.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rosinality committed Jul 9, 2019
1 parent 866e506 commit c63e7c2
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from model import StyledGenerator

generator = StyledGenerator(512).cuda()
device = 'cuda'

generator = StyledGenerator(512).to(device)
generator.load_state_dict(torch.load('checkpoint/180000.model'))
generator.eval()

Expand All @@ -16,7 +18,7 @@

with torch.no_grad():
for i in range(10):
style = generator.mean_style(torch.randn(1024, 512).cuda())
style = generator.mean_style(torch.randn(1024, 512).to(device))

if mean_style is None:
mean_style = style
Expand All @@ -27,7 +29,7 @@
mean_style /= 10

image = generator(
torch.randn(15, 512).cuda(),
torch.randn(15, 512).to(device),
step=step,
alpha=alpha,
mean_style=mean_style,
Expand All @@ -37,10 +39,10 @@
utils.save_image(image, 'sample.png', nrow=5, normalize=True, range=(-1, 1))

for j in range(20):
source_code = torch.randn(5, 512).cuda()
target_code = torch.randn(3, 512).cuda()
source_code = torch.randn(5, 512).to(device)
target_code = torch.randn(3, 512).to(device)

images = [torch.ones(1, 3, shape, shape).cuda() * -1]
images = [torch.ones(1, 3, shape, shape).to(device) * -1]

source_image = generator(
source_code, step=step, alpha=alpha, mean_style=mean_style, style_weight=0.7
Expand Down

0 comments on commit c63e7c2

Please sign in to comment.