Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

增加图像增强方式(HShift、VShift、Pad 等) #2

Merged
merged 18 commits into from
Nov 7, 2021
56 changes: 56 additions & 0 deletions patta/functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import paddle
import paddle.nn.functional as F
import numpy as np
import cv2


def rot90(x, k=1):
Expand All @@ -23,6 +24,14 @@ def vflip(x):
return x.flip([2])


def hshift(x, shifts=0):
    """Roll ``x`` along axis 3 by a fraction of that axis' length.

    Args:
        x: Tensor with at least four dimensions. Axis 3 is presumably the
            width of an NCHW batch (TODO confirm against callers).
        shifts: Fraction of the axis-3 length to roll by. The element offset
            is ``int(shifts * x.shape[3])``, i.e. truncated toward zero.

    Returns:
        The result of ``paddle.roll`` applied with that offset on axis 3.
    """
    width = x.shape[3]
    offset = int(shifts * width)
    return paddle.roll(x, offset, axis=3)


def vshift(x, shifts=0):
    """Roll ``x`` along axis 2 by a fraction of that axis' length.

    Args:
        x: Tensor with at least three dimensions. Axis 2 is presumably the
            height of an NCHW batch (TODO confirm against callers).
        shifts: Fraction of the axis-2 length to roll by. The element offset
            is ``int(shifts * x.shape[2])``, i.e. truncated toward zero.

    Returns:
        The result of ``paddle.roll`` applied with that offset on axis 2.
    """
    height = x.shape[2]
    offset = int(shifts * height)
    return paddle.roll(x, offset, axis=2)


def sum(x1, x2):
    """Return ``x1 + x2``.

    Intended for two tensors, but any operands supporting ``+`` work.
    NOTE: this name shadows the builtin ``sum`` inside this module.
    """
    total = x1 + x2
    return total
Expand Down Expand Up @@ -124,6 +133,19 @@ def keypoints_vflip(keypoints):
return _assemble_keypoints(x, 1. - y)


def keypoints_hshift(keypoints, shifts):
    """Shift the x coordinate of every keypoint by ``shifts``, modulo 1.

    Args:
        keypoints: Keypoints accepted by ``_disassemble_keypoints``;
            presumably normalised coordinates (TODO confirm).
        shifts: Offset added to x before wrapping with ``% 1``.

    Returns:
        Keypoints rebuilt by ``_assemble_keypoints`` with the wrapped x and
        the untouched y.
    """
    xs, ys = _disassemble_keypoints(keypoints)
    wrapped_xs = (xs + shifts) % 1
    return _assemble_keypoints(wrapped_xs, ys)


def keypoints_vshift(keypoints, shifts):
    """Shift the y coordinate of every keypoint by ``shifts``, modulo 1.

    Args:
        keypoints: Keypoints accepted by ``_disassemble_keypoints``;
            presumably normalised coordinates (TODO confirm).
        shifts: Offset added to y before wrapping with ``% 1``.

    Returns:
        Keypoints rebuilt by ``_assemble_keypoints`` with the untouched x and
        the wrapped y.
    """
    xs, ys = _disassemble_keypoints(keypoints)
    wrapped_ys = (ys + shifts) % 1
    return _assemble_keypoints(xs, wrapped_ys)

def keypoints_pad(keypoints, pad ):
    """Rescale keypoint coordinates using the first entry of ``pad``.

    Args:
        keypoints: Keypoints accepted by ``_disassemble_keypoints``.
        pad: Indexable padding spec; only ``pad[0]`` is read.

    Returns:
        Keypoints rebuilt by ``_assemble_keypoints`` from
        ``x*x/(x+pad[0])`` and ``y*y/(y+pad[0])``.
    """
    x, y = _disassemble_keypoints(keypoints)
    # NOTE(review): both axes use pad[0]; the y axis presumably should use a
    # different entry of `pad` -- confirm the intended pad layout.
    # NOTE(review): v*v/(v+pad) does not look like a pad/unpad coordinate
    # mapping and divides by zero when v + pad[0] == 0 -- verify the formula
    # before relying on this function.
    return _assemble_keypoints(x*x/(x+pad[0]), y*y/(y + pad[0]))

def keypoints_rot90(keypoints, k=1):

if k not in {0, 1, 2, 3}:
Expand All @@ -140,3 +162,37 @@ def keypoints_rot90(keypoints, k=1):
xy = [1. - y, x]

return _assemble_keypoints(*xy)



def adjust_contrast(x, contrast_factor=0.5):
    """Adjust image contrast through a 256-entry lookup table.

    Every 8-bit intensity ``i`` is mapped to
    ``(i - 74) * contrast_factor + 74`` (74 is the hard-coded pivot),
    clipped to ``[0, 255]`` and truncated to ``uint8``.

    Args:
        x: Image data. Objects exposing ``.numpy()`` (e.g. paddle tensors)
            are converted through it; anything else goes through
            ``np.asarray``. Values are clipped to ``[0, 255]`` and cast to
            ``uint8`` before the lookup, so fractional parts are discarded.
        contrast_factor: Scale applied around the pivot; ``1`` leaves the
            table as the identity mapping.

    Returns:
        A paddle tensor of dtype float32 holding the remapped values.
    """
    levels = np.arange(256)
    table = ((levels - 74) * contrast_factor + 74).clip(0, 255).astype('uint8')

    # The original tried `x.paddle.to_tensor(x)` (an attribute typo that
    # always raised) inside a bare `except:`, so NumPy input crashed on the
    # fallback `x.numpy()`. Convert explicitly instead.
    if hasattr(x, "numpy"):
        x = x.numpy()
    else:
        x = np.asarray(x)

    x = x.clip(0, 255).astype('uint8')
    x = cv2.LUT(x, table)
    return paddle.to_tensor(x.astype(np.float32))



def adjust_brightness(x, brightness_factor=1):
    """Adjust image brightness through a 256-entry lookup table.

    Every 8-bit intensity ``i`` is mapped to ``i * brightness_factor``,
    clipped to ``[0, 255]`` and truncated to ``uint8``.

    Args:
        x: Image data. Objects exposing ``.numpy()`` (e.g. paddle tensors)
            are converted through it; anything else goes through
            ``np.asarray``. Values are clipped to ``[0, 255]`` and cast to
            ``uint8`` before the lookup, so fractional parts are discarded.
        brightness_factor: Multiplier for each intensity; ``1`` leaves the
            table as the identity mapping.

    Returns:
        A paddle tensor of dtype float32 holding the remapped values.
    """
    levels = np.arange(256)
    table = (levels * brightness_factor).clip(0, 255).astype('uint8')

    # The original tried `x.paddle.to_tensor(x)` (an attribute typo that
    # always raised) inside a bare `except:`, so NumPy input crashed on the
    # fallback `x.numpy()`. Convert explicitly instead.
    if hasattr(x, "numpy"):
        x = x.numpy()
    else:
        x = np.asarray(x)

    x = x.clip(0, 255).astype('uint8')
    x = cv2.LUT(x, table)
    return paddle.to_tensor(x.astype(np.float32))



def pad(x, pad=0, mode='constant', value=0):
    """Forward to ``paddle.nn.functional.pad`` with the same arguments.

    Args:
        x: Tensor to pad.
        pad: Padding specification, passed through unchanged.
        mode: Padding mode, passed through unchanged.
        value: Fill value, passed through unchanged.

    Returns:
        Whatever ``F.pad(x, pad, mode, value)`` returns.
    """
    padded = F.pad(x, pad, mode, value)
    return padded

109 changes: 109 additions & 0 deletions patta/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,112 @@ def apply_deaug_mask(self, mask, **kwargs):

def apply_deaug_keypoints(self, keypoints, **kwargs):
    """Always raise: this transform cannot map keypoints back.

    Raises:
        ValueError: unconditionally.
    """
    message = "`FiveCrop` augmentation is not suitable for keypoints!"
    raise ValueError(message)


class HorizontallyShift(DualTransform):
    """Roll the image tensor along axis 3 and undo the roll on outputs.

    The identity shift ``0`` is always part of the parameter list handed to
    the base class under the name ``"shifts"``.
    """
    identity_param = 0

    def __init__(self, shifts: List[float]):
        """Register ``shifts``, prepending the identity when it is absent."""
        has_identity = self.identity_param in shifts
        if not has_identity:
            shifts = [self.identity_param, *shifts]
        super().__init__("shifts", shifts)

    def apply_aug_image(self, image, shifts=0, **kwargs):
        """Roll ``image`` horizontally by the fraction ``shifts``."""
        return F.hshift(image, shifts)

    def apply_deaug_mask(self, mask, shifts=0, **kwargs):
        """Undo the augmentation on ``mask`` by rolling with ``-shifts``."""
        inverse = -shifts
        return self.apply_aug_image(mask, inverse)

    def apply_deaug_label(self, label, shifts=0, **kwargs):
        """Return ``label`` unchanged; this method does not modify it."""
        return label

    def apply_deaug_keypoints(self, keypoints, shifts=0, **kwargs):
        """Shift keypoint x coordinates back by ``-shifts``."""
        inverse = -shifts
        return F.keypoints_hshift(keypoints, inverse)


class VerticalShift(DualTransform):
    """Roll the image tensor along axis 2 and undo the roll on outputs.

    The identity shift ``0`` is always part of the parameter list handed to
    the base class under the name ``"shifts"``.
    """
    identity_param = 0

    def __init__(self, shifts: List[float]):
        """Register ``shifts``, prepending the identity when it is absent."""
        has_identity = self.identity_param in shifts
        if not has_identity:
            shifts = [self.identity_param, *shifts]
        super().__init__("shifts", shifts)

    def apply_aug_image(self, image, shifts=0, **kwargs):
        """Roll ``image`` vertically by the fraction ``shifts``."""
        return F.vshift(image, shifts)

    def apply_deaug_mask(self, mask, shifts=0, **kwargs):
        """Undo the augmentation on ``mask`` by rolling with ``-shifts``."""
        inverse = -shifts
        return self.apply_aug_image(mask, inverse)

    def apply_deaug_label(self, label, shifts=0, **kwargs):
        """Return ``label`` unchanged; this method does not modify it."""
        return label

    def apply_deaug_keypoints(self, keypoints, shifts=0, **kwargs):
        """Shift keypoint y coordinates back by ``-shifts``."""
        inverse = -shifts
        return F.keypoints_vshift(keypoints, inverse)



class Pad(DualTransform):
    """Pad the image and crop the padding off predicted masks.

    Args:
        pads: Candidate paddings; the identity ``0`` is prepended when absent.
        mode: Padding mode forwarded to ``F.pad``.
        original_pad: Padding already present on the input. A mask whose
            ``pad`` equals this value is returned uncropped.
        value: Fill value forwarded to ``F.pad``.

    NOTE(review): the layout of one ``pad`` entry is not defined in this
    file. The de-augmentation below assumes ``(pad_h, pad_w)`` appended at
    the bottom/right edge -- confirm against what ``F.pad`` expects.
    """
    identity_param = 0

    def __init__(
        self,
        pads: List[Tuple[int, int]],
        mode: str,
        original_pad: Tuple[int, int] = None,
        value: int = 0,
    ):
        if self.identity_param not in pads:
            pads = [self.identity_param] + list(pads)
        # Plain assignments: the previous trailing commas stored 1-tuples, so
        # `original_pad is None` was never true and `mode`/`value` reached
        # F.pad wrapped in tuples.
        self.original_pad = original_pad
        self.mode = mode
        self.value = value
        # Registered as "pad" to match the keyword the apply_* methods take
        # (the sibling shift transforms register the same name they accept).
        super().__init__("pad", pads)

    @staticmethod
    def _pad_hw(pad):
        """Return ``(pad_h, pad_w)``; a scalar (the identity 0) covers both."""
        if isinstance(pad, int):
            return pad, pad
        return pad[0], pad[1]

    def apply_aug_image(self, image, pad=(0, 0), **kwargs):
        """Pad ``image`` with the configured mode and fill value."""
        return F.pad(image, pad, self.mode, self.value)

    def apply_deaug_mask(self, mask, pad=(0, 0), **kwargs):
        """Crop the padded rows/columns off ``mask``.

        Raises:
            ValueError: if ``original_pad`` was not given to the constructor.
        """
        if self.original_pad is None:
            raise ValueError(
                "Provide original image size to make mask backward transformation"
            )
        if pad != self.original_pad:
            pad_h, pad_w = self._pad_hw(pad)
            height = mask.shape[2]
            width = mask.shape[3]
            # Reviewer feedback: the crop must be a range per axis (the old
            # code indexed a single element) and must not reuse pad[0] for
            # both axes.
            mask = mask[:, :, :height - pad_h, :width - pad_w]
        return mask

    def apply_deaug_label(self, label, **kwargs):
        """Return ``label`` unchanged; this method does not modify it."""
        return label

    def apply_deaug_keypoints(self, keypoints, pad=(0, 0), **kwargs):
        """Map keypoints back through ``F.keypoints_pad``.

        TODO(review): the keypoint mapping itself was still marked TODO in
        review; only the missing ``pad`` argument is supplied here.
        """
        return F.keypoints_pad(keypoints, self._pad_hw(pad))


class AdjustContrast(ImageOnlyTransform):
    """Apply ``F.adjust_contrast`` with each of the given factors.

    The identity factor ``1`` is prepended to ``factors`` when absent, and
    the list is registered with the base class as ``"contrast_factor"``.
    """
    identity_param = 1

    def __init__(self, factors: List[float]):
        if self.identity_param not in factors:
            factors = [self.identity_param] + list(factors)
        # `super` must be called; `super.__init__(...)` addressed the builtin
        # type itself and raised TypeError on construction.
        super().__init__("contrast_factor", factors)

    def apply_aug_image(self, image, contrast_factor=0.5, factors=None, **kwargs):
        """Return ``image`` with its contrast adjusted.

        Args:
            image: Image forwarded to ``F.adjust_contrast``.
            contrast_factor: Factor to apply; named like the parameter
                registered in ``__init__``.
            factors: Legacy keyword for the same value; takes precedence
                when given so older call sites keep working.
        """
        if factors is not None:
            contrast_factor = factors
        return F.adjust_contrast(image, contrast_factor)


class AdjustBrightness(ImageOnlyTransform):
    """Apply ``F.adjust_brightness`` with each of the given factors.

    The identity factor ``1`` is prepended to ``factors`` when absent, and
    the list is registered with the base class as ``"brightness_factor"``.
    """
    identity_param = 1

    def __init__(self, factors: List[float]):
        if self.identity_param not in factors:
            factors = [self.identity_param] + list(factors)
        # `super` must be called; `super.__init__(...)` addressed the builtin
        # type itself and raised TypeError on construction.
        super().__init__("brightness_factor", factors)

    def apply_aug_image(self, image, brightness_factor=0.5, **kwargs):
        """Return ``image`` with its brightness scaled by ``brightness_factor``."""
        return F.adjust_brightness(image, brightness_factor)