Skip to content

Commit

Permalink
v3 = hr boxes
Browse files Browse the repository at this point in the history
log/2017-10-25T15:04:45/
  • Loading branch information
pesser committed Oct 25, 2017
1 parent b6c7b46 commit 82a0830
Show file tree
Hide file tree
Showing 5 changed files with 499 additions and 127 deletions.
196 changes: 195 additions & 1 deletion batches_pg2.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ def tile(X, rows, cols):

def plot_batch(X, out_path):
"""Save batch of images tiled."""
n_channels = X.shape[3]
if n_channels > 3:
X = X[:,:,:,np.random.choice(n_channels, size = 3)]
X = postprocess(X)
rc = math.sqrt(X.shape[0])
rows = cols = math.ceil(rc)
Expand Down Expand Up @@ -166,6 +169,193 @@ def make_joint_img(img_shape, jo, joints):
return img


def valid_joints(*joints):
j = np.stack(joints)
return (j >= 0).all()


def zoom(img, factor, center = None):
shape = img.shape[:2]
if center is None or not valid_joints(center):
center = np.array(shape) / 2
e1 = np.array([1,0])
e2 = np.array([0,1])

dst_center = np.array(center)
dst_e1 = e1 * factor
dst_e2 = e2 * factor

src = np.float32([center, center+e1, center+e2])
dst = np.float32([dst_center, dst_center+dst_e1, dst_center+dst_e2])
M = cv2.getAffineTransform(src, dst)

return cv2.warpAffine(img, M, shape, flags = cv2.INTER_AREA, borderMode = cv2.BORDER_REPLICATE)


def get_crop(bpart, joints, jo, wh, o_w, o_h, ar = 1.0):
bpart_indices = [jo.index(b) for b in bpart]
part_src = np.float32(joints[bpart_indices])

# fall backs
if not valid_joints(part_src):
if bpart[0] == "lhip" and bpart[1] == "lknee":
bpart = ["lhip"]
bpart_indices = [jo.index(b) for b in bpart]
part_src = np.float32(joints[bpart_indices])
elif bpart[0] == "rhip" and bpart[1] == "rknee":
bpart = ["rhip"]
bpart_indices = [jo.index(b) for b in bpart]
part_src = np.float32(joints[bpart_indices])

if not valid_joints(part_src):
return None

if part_src.shape[0] == 1:
# leg fallback
a = part_src[0]
b = np.float32([a[0],o_h - 1])
part_src = np.float32([a,b])

if part_src.shape[0] == 4:
pass
elif part_src.shape[0] == 3:
# lshoulder, rshoulder, cnose
segment = part_src[1] - part_src[0]
normal = np.array([-segment[1],segment[0]])
if normal[1] > 0.0:
normal = -normal

a = part_src[0] + normal
b = part_src[0]
c = part_src[1]
d = part_src[1] + normal
part_src = np.float32([a,b,c,d])
else:
assert part_src.shape[0] == 2

segment = part_src[1] - part_src[0]
normal = np.array([-segment[1],segment[0]])
alpha = ar / 2.0
a = part_src[0] + alpha*normal
b = part_src[0] - alpha*normal
c = part_src[1] - alpha*normal
d = part_src[1] + alpha*normal
part_src = np.float32([a,b,c,d])

dst = np.float32([[0.0,0.0],[0.0,1.0],[1.0,1.0],[1.0,0.0]])
part_dst = np.float32(wh * dst)

M = cv2.getPerspectiveTransform(part_src, part_dst)
return M


def normalize(imgs, coords, stickmen, jo):

out_imgs = list()
out_stickmen = list()

bs = len(imgs)
for i in range(bs):
img = imgs[i]
joints = coords[i]
stickman = stickmen[i]

h,w = img.shape[:2]
o_h = h
o_w = w
h = h // 4
w = w // 4
wh = np.array([w,h])
wh = np.expand_dims(wh, 0)

bparts = [
["lshoulder","lhip","rhip","rshoulder"],
["lshoulder", "rshoulder", "rshoulder"],
["lshoulder","lelbow"],
["lelbow", "lwrist"],
["rshoulder","relbow"],
["relbow", "rwrist"],
["lhip", "lknee"],
["lknee", "lankle"],
["rhip", "rknee"],
["rknee", "rankle"]]
ar = 0.5

part_imgs = list()
part_stickmen = list()
for bpart in bparts:
part_img = np.zeros((h,w,3))
part_stickman = np.zeros((h,w,3))
M = get_crop(bpart, joints, jo, wh, o_w, o_h, ar)

if M is not None:
part_img = cv2.warpPerspective(img, M, (h,w), borderMode = cv2.BORDER_REPLICATE)
part_stickman = cv2.warpPerspective(stickman, M, (h,w), borderMode = cv2.BORDER_REPLICATE)

part_imgs.append(part_img)
part_stickmen.append(part_stickman)
img = np.concatenate(part_imgs, axis = 2)
stickman = np.concatenate(part_stickmen, axis = 2)

"""
bpart = ["lshoulder","lhip","rhip","rshoulder"]
dst = np.float32([[0.0,0.0],[0.0,1.0],[1.0,1.0],[1.0,0.0]])
bpart_indices = [jo.index(b) for b in bpart]
part_src = np.float32(joints[bpart_indices])
part_dst = np.float32(wh * dst)
M = cv2.getPerspectiveTransform(part_src, part_dst)
img = cv2.warpPerspective(img, M, (h,w), borderMode = cv2.BORDER_REPLICATE)
stickman = cv2.warpPerspective(stickman, M, (h,w), borderMode = cv2.BORDER_REPLICATE)
"""

"""
# center of possible rescaling
c = joints[jo.index("cneck")]
# find valid body part for scale estimation
a = joints[jo.index("lshoulder")]
b = joints[jo.index("lhip")]
target_length = 33.0
if not valid_joints(a,b):
a = joints[jo.index("rshoulder")]
b = joints[jo.index("rhip")]
target_length = 33.0
if not valid_joints(a,b):
a = joints[jo.index("rshoulder")]
b = joints[jo.index("relbow")]
target_length = 33.0 / 2
if not valid_joints(a,b):
a = joints[jo.index("lshoulder")]
b = joints[jo.index("lelbow")]
target_length = 33.0 / 2
if not valid_joints(a,b):
a = joints[jo.index("lwrist")]
b = joints[jo.index("lelbow")]
target_length = 33.0 / 2
if not valid_joints(a,b):
a = joints[jo.index("rwrist")]
b = joints[jo.index("relbow")]
target_length = 33.0 / 2
if valid_joints(a,b):
body_length = np.linalg.norm(b - a)
factor = target_length / body_length
img = zoom(img, factor, center = c)
stickman = zoom(stickman, factor, center = c)
else:
factor = 0.25
img = zoom(img, factor, center = c)
stickman = zoom(stickman, factor, center = c)
"""

out_imgs.append(img)
out_stickmen.append(stickman)
out_imgs = np.stack(out_imgs)
out_stickmen = np.stack(out_stickmen)
return out_imgs, out_stickmen


def make_mask_img(img_shape, jo, joints):
scale_factor = img_shape[1] / 128
masks = 3*[None]
Expand Down Expand Up @@ -316,6 +506,10 @@ def __next__(self):
batch["imgs"] = batch["imgs"] * batch["masks"]


imgs, joints = normalize(batch["imgs"], batch["joints_coordinates"], batch["joints"], self.jo)
batch["norm_imgs"] = imgs
batch["norm_joints"] = joints

batch_list = [batch[k] for k in self.return_keys]
return batch_list

Expand All @@ -333,7 +527,7 @@ def get_batches(
mask,
fill_batches = True,
shuffle = True,
return_keys = ["imgs", "joints"]):
return_keys = ["imgs", "joints", "norm_imgs", "norm_joints"]):
"""Buffered IndexFlow."""
flow = IndexFlow(shape, index_path, train, mask, fill_batches, shuffle, return_keys)
return BufferedWrapper(flow)
Expand Down
Loading

0 comments on commit 82a0830

Please sign in to comment.