
Errors when using transforms.Normalize() instead of defining a normalisation module #8

Open
Balabala-Hong opened this issue Mar 3, 2019 · 0 comments

@Balabala-Hong

Hello, cool work! I tried to use transforms.Normalize() instead of writing a Normalization class as you did, but the loss does not seem to converge. Is it not possible to use transforms.Normalize() with your code?
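For comparison, here is a minimal sketch of the kind of Normalization class the question refers to. It assumes the repository's class follows the pattern of the PyTorch neural-style-transfer tutorial; the class name, constants, and layout below are assumptions, not the author's code. It applies the same per-channel (x - mean) / std as transforms.Normalize, but as the first layer of the network, so the image tensor that is loaded and optimized stays in [0, 1] pixel space.

```python
import torch
import torch.nn as nn

# ImageNet statistics, the same values passed to transforms.Normalize in the question
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


class Normalization(nn.Module):
    """Hypothetical sketch: per-channel (x - mean) / std applied inside the model.

    Assumes inputs of shape (N, 3, H, W) with values in [0, 1].
    """

    def __init__(self, mean, std):
        super().__init__()
        # Reshape to (C, 1, 1) so the statistics broadcast over H and W.
        # register_buffer makes them move with .to(device) along with the module.
        self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(-1, 1, 1))

    def forward(self, img):
        return (img - self.mean) / self.std


norm = Normalization(IMAGENET_MEAN, IMAGENET_STD)
x = torch.rand(1, 3, 8, 8)  # one image in [0, 1] pixel space
y = norm(x)                 # normalized copy; x itself is unchanged
```

One difference worth checking: if the optimization loop clamps the input image to [0, 1] between steps (as the PyTorch tutorial does), that clamp is only valid when normalization happens inside the model. With transforms.Normalize, the loaded tensors are already in normalized space, where values fall well outside [0, 1].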

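import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed; not shown in the issue
image_size = 512  # hypothetical value; the issue does not show how image_size is set

load_transform = transforms.Compose([
    # an int resizes the shorter edge to image_size and keeps the aspect ratio,
    # so square inputs like the two images below become image_size x image_size
    transforms.Resize(image_size),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], CHW
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])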


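def img_loader(image_name):
    img = Image.open(image_name).convert("RGB")  # ensure 3 channels (drop alpha, expand grayscale)
    img = load_transform(img).unsqueeze(0)  # add a batch dimension: (1, 3, H, W)
    return img.to(device, torch.float)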


style_img = img_loader("./datasets/images/picasso.jpg")  # 650*650
content_img = img_loader("./datasets/images/dancing.jpg")  # 444*444
assert style_img.size() == content_img.size(), \
    "The content image and the style image are not compatible in shape"


# Display a tensor: convert CHW to an HWC NumPy array and undo the normalization
def img_show(tensor, title=None):
    img = tensor.cpu().clone()  # clone so the original tensor is not modified
    img = img.squeeze(0)  # remove the batch dimension: CHW format
    img = img.detach().numpy().transpose((1, 2, 0))  # HWC format
    # undo transforms.Normalize: x * std + mean
    img = img * np.array([0.229, 0.224, 0.225]) + \
          np.array([0.485, 0.456, 0.406])
    img = img.clip(0, 1)
    # plt.imshow() expects RGB floats in [0, 1] or ints in [0, 255]
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.pause(0.5)


plt.ion()
plt.figure()
img_show(style_img, title="Style Image")
plt.figure()
img_show(content_img, title="Content Image")
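img_show undoes the normalization with x * std + mean before clipping. The NumPy-only sketch below (the function names normalize_hwc and denormalize_hwc are hypothetical) checks that this is the exact inverse of the per-channel (x - mean) / std that transforms.Normalize applies, and shows that normalized values leave the [0, 1] range.

```python
import numpy as np

# Per-channel ImageNet statistics, broadcast over the last (channel) axis of an HWC array
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize_hwc(img):
    """Hypothetical helper: (x - mean) / std on an HWC array, mirroring transforms.Normalize."""
    return (img - MEAN) / STD


def denormalize_hwc(img):
    """Hypothetical helper: inverse of normalize_hwc, x * std + mean, as in img_show."""
    return img * STD + MEAN


# A deterministic 4x4 RGB "image" with values spanning [0, 1]
pixels = np.linspace(0.0, 1.0, 48).reshape(4, 4, 3)
normalized = normalize_hwc(pixels)
restored = denormalize_hwc(normalized)
```

Because the normalized values are not confined to [0, 1], the clip(0, 1) in img_show only makes sense after denormalization, which is the order the code above uses.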