Camera back-side problem solved and visualization resolution control added #4
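This PR adds `apply_to_video.py`, which runs TVCalib on every frame of an input video, smooths the per-frame camera parameters with a Savitzky-Golay filter, and renders the reprojected pitch model onto the original-resolution frames. It also patches `draw_reprojection` to handle pitch points that lie behind the camera and to accept `ratio_width`/`ratio_height` factors so the overlay can be drawn at a resolution different from the model input.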

Open

wants to merge 12 commits into base: main
8 changes: 8 additions & 0 deletions README.md
@@ -1,3 +1,11 @@
# Project O

[![Video Label](http://img.youtube.com/vi/ySSuvN-O8wo/0.jpg)](https://www.youtube.com/watch?v=ySSuvN-O8wo)

The code used in the video above can be found in the "apply_to_video.py" file.
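Run `python apply_to_video.py` from the repository root; it reads the hard-coded input `demo_video_trim.mp4` and the segmentation checkpoint `data/segment_localization/train_59.pt`, and writes the overlaid output to `result.mp4`.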

<hr>

# TVCalib: Camera Calibration for Sports Field Registration in Soccer

<div align="center">
165 changes: 165 additions & 0 deletions apply_to_video.py
@@ -0,0 +1,165 @@
from collections import defaultdict
from functools import partial

import numpy as np
import torch
# import matplotlib
# matplotlib.use("Qt5Agg")
# matplotlib.use("TkAgg")

import matplotlib.pyplot as plt
plt.ioff()

from tqdm.auto import tqdm
import pandas as pd
from SoccerNet.Evaluation.utils_calibration import SoccerPitch

from tvcalib.module import TVCalibModule
from tvcalib.cam_distr.tv_main_center import get_cam_distr, get_dist_distr
from sn_segmentation.src.custom_extremities import generate_class_synthesis, get_line_extremities
from tvcalib.sncalib_dataset import custom_list_collate, split_circle_central
from tvcalib.utils.io import detach_dict, tensor2list
from tvcalib.utils.objects_3d import SoccerPitchLineCircleSegments, SoccerPitchSNCircleCentralSplit
from tvcalib.inference import InferenceDatasetCalibration, InferenceDatasetSegmentation, InferenceSegmentationModel
from tvcalib.inference import get_camera_from_per_sample_output
from tvcalib.utils import visualization_mpl_min as viz

import imageio
from skimage import img_as_ubyte
from skimage.transform import resize
from scipy.signal import savgol_filter


# def show(img):
#     plt.figure()
#     plt.imshow(img)
#     plt.axis("off")
#     plt.tight_layout()
#     plt.show()


# read input video
# frame_raw_list stores the frames at the video's original resolution
# frame_list stores frames resized to (256, 455) so that the model can consume them
input_video_name = 'demo_video_trim.mp4'

reader = imageio.get_reader(input_video_name)
fps = reader.get_meta_data()['fps']
frame_list = []
frame_raw_list = []
for im in tqdm(reader, desc="reading video", total=reader.count_frames()):
    frame_list.append(resize(im, (256, 455), order=3))
    frame_raw_list.append(im / 255.0)  # normalize uint8 frames to [0, 1] floats
reader.close()
frame_shape = frame_list[0].shape[:2]
frame_raw_shape = frame_raw_list[0].shape[:2]


# detecting keypoints from the frames
device = "cuda"

object3d = SoccerPitchLineCircleSegments(
    device=device, base_field=SoccerPitchSNCircleCentralSplit()
)
object3dcpu = SoccerPitchLineCircleSegments(
    device="cpu", base_field=SoccerPitchSNCircleCentralSplit()
)

lines_palette = [0, 0, 0]
for line_class in SoccerPitch.lines_classes:
    lines_palette.extend(SoccerPitch.palette[line_class])

fn_generate_class_synthesis = partial(generate_class_synthesis, radius=4)
fn_get_line_extremities = partial(get_line_extremities, maxdist=30, width=frame_shape[1], height=frame_shape[0],
                                  num_points_lines=4, num_points_circles=8)
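# generate_class_synthesis thins each predicted segmentation mask into a per-class
# skeleton; get_line_extremities then samples up to 4 keypoints per line class and
# 8 per circle class from those skeletons.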


model_seg = InferenceSegmentationModel("data/segment_localization/train_59.pt", device)


keypoints_raw_li = []
for data in tqdm(frame_list, desc="extracting keypoints"):
    data = torch.FloatTensor(np.transpose(data, (2, 0, 1))).to(device)
    with torch.no_grad():
        sem_lines = model_seg.inference(data.unsqueeze(0))
    sem_lines = sem_lines.cpu().numpy().astype(np.uint8)

    skeletons = fn_generate_class_synthesis(np.squeeze(sem_lines, axis=0))
    keypoints_raw = fn_get_line_extremities(skeletons)

    keypoints_raw_li.append(keypoints_raw)


# Extract camera parameters based on the keypoints obtained above
lens_dist = False

batch_size = 256
optim_steps = 2000
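# TVCalib casts calibration as batched optimization: self_optim_batch runs
# optim_steps gradient steps per batch so that the projected pitch model
# lines up with the detected keypoints.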

model_calib = TVCalibModule(
    object3d,
    get_cam_distr(1.96, batch_dim=batch_size, temporal_dim=1),
    get_dist_distr(batch_dim=batch_size, temporal_dim=1) if lens_dist else None,
    frame_shape,
    optim_steps,
    device,
    log_per_step=False,
    tqdm_kwqargs=None,  # parameter name as spelled in TVCalibModule's signature
)


dataset_calib = InferenceDatasetCalibration(keypoints_raw_li, frame_shape[1], frame_shape[0], object3d)
dataloader_calib = torch.utils.data.DataLoader(dataset_calib, batch_size, collate_fn=custom_list_collate)


per_sample_output = defaultdict(list)
per_sample_output["image_id"] = [[x] for x in range(len(frame_list))]
for x_dict in dataloader_calib:
    _batch_size = x_dict["lines__ndc_projected_selection_shuffled"].shape[0]

    points_line = x_dict["lines__px_projected_selection_shuffled"]
    points_circle = x_dict["circles__px_projected_selection_shuffled"]

    per_sample_loss, cam, _ = model_calib.self_optim_batch(x_dict)
    output_dict = tensor2list(detach_dict({**cam.get_parameters(_batch_size), **per_sample_loss}))

    output_dict["points_line"] = points_line
    output_dict["points_circle"] = points_circle
    for k in output_dict.keys():
        per_sample_output[k].extend(output_dict[k])


df = pd.DataFrame.from_dict(per_sample_output)

df = df.explode(column=[k for k, v in per_sample_output.items() if isinstance(v, list)])
df.set_index("image_id", inplace=True, drop=False)


# Smooth the camera parameters over time with a Savitzky-Golay filter
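# A cubic polynomial is fit over a sliding window of 31 frames; the window length
# must not exceed the number of frames (and must be odd in older SciPy versions).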
window = 31
df.aov_radian = savgol_filter(df.aov_radian, window, 3)
df.pan_degrees = savgol_filter(df.pan_degrees, window, 3)
df.roll_degrees = savgol_filter(df.roll_degrees, window, 3)
df.tilt_degrees = savgol_filter(df.tilt_degrees, window, 3)
df.position_meters = pd.Series(savgol_filter(np.array(df.position_meters.to_list()), window, 3, axis=0).tolist(), index=df.index)


plt.ioff()
result = []
for i in tqdm(range(len(frame_list)), desc="generating final video"):
    sample = df.iloc[i]

    image_raw = torch.tensor(np.transpose(frame_raw_list[i], (2, 0, 1)))

    cam = get_camera_from_per_sample_output(sample, lens_dist)

    fig, ax = viz.init_figure(frame_raw_shape[1], frame_raw_shape[0])
    ax = viz.draw_image(ax, image_raw)
    # scale the overlay from the model resolution (256, 455) to the raw video resolution
    ax = viz.draw_reprojection(ax, object3dcpu, cam, ratio_width=frame_raw_shape[1] / frame_shape[1],
                               ratio_height=frame_raw_shape[0] / frame_shape[0])

    fig.canvas.draw()
    result.append(np.array(fig.canvas.renderer._renderer))
    plt.close(fig)

imageio.mimsave("result.mp4", [img_as_ubyte(p) for p in result], fps=fps)
45 changes: 35 additions & 10 deletions tvcalib/utils/visualization_mpl_min.py
@@ -4,6 +4,7 @@
 import matplotlib.pyplot as plt
 from matplotlib.lines import Line2D
 from matplotlib.patches import Polygon, Rectangle
+import kornia


 def init_figure(img_width, img_height, img_delta_w=0.2, img_delta_h=0.1):
@@ -23,21 +24,43 @@ def init_figure(img_width, img_height, img_delta_w=0.2, img_delta_h=0.1):
     return fig, ax


-def draw_reprojection(ax, object3d, cam, dist_circles=0.25, kwargs={"alpha": 0.5, "linewidth": 5}):
+def draw_reprojection(ax, object3d, cam, ratio_width=1, ratio_height=1, dist_circles=0.25, kwargs={"alpha": 0.5, "linewidth": 5}):
+    # NDC is an abbreviation for Normalized Device Coordinates; for more details, see
+    # https://carmencincotti.com/2022-05-02/homogeneous-coordinates-clip-space-ndc/#ndc
+    # Unproject the NDC principal point (0, 0, 1) to a 3D point on the optical axis.
+    ndc_point = torch.tensor([0., 0., 1.]).view(1, 3, 1)
+    point_3d = torch.bmm(cam.P_ndc.pinverse(), ndc_point)
+    principal_3d = kornia.geometry.conversions.convert_points_from_homogeneous(point_3d.transpose(1, 2))
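+    # principal_3d lies on the optical axis in front of the camera, so a 3D point P is
+    # in front of the camera exactly when dot(P - position, principal_3d - position) > 0;
+    # this is the test used by the back-side handling below.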

+    # original projection
     lines3d = object3d.line_segments.transpose(0, 1).transpose(-1, -2)
-    for lidx, line_name in enumerate(object3d.line_segments_names):
+    points_px = torch.zeros((23, 2, 2))
+    for lidx in range(23):
         with torch.no_grad():
             line3d = lines3d[lidx]
-            points_px = cam.project_point2pixel(line3d, lens_distortion=False).cpu()[0, 0]
-        line2d = Line2D(
-            points_px[:, 0],
-            points_px[:, 1],
-            color=object3d.cmap_01[line_name],
-            **kwargs,
-        )
-        ax.add_line(line2d)
+            points_px[lidx, :, :] = cam.project_point2pixel(line3d, lens_distortion=False).cpu()[0, 0]

+    # Exception handling for the camera back-side problem: endpoints behind the
+    # camera project to mirrored pixel positions. Segments entirely behind the
+    # camera are skipped; if only one endpoint is behind, it is pushed far beyond
+    # the visible endpoint, away from its mirrored projection.
+    for lidx in range(23):
+        line3d = lines3d[lidx]
+        if (torch.dot(line3d[0] - cam.position.view(3), principal_3d.view(3) - cam.position.view(3)) < 0) and \
+           (torch.dot(line3d[1] - cam.position.view(3), principal_3d.view(3) - cam.position.view(3)) < 0):
+            continue
+        elif torch.dot(line3d[0] - cam.position.view(3), principal_3d.view(3) - cam.position.view(3)) < 0:
+            points_px[lidx][0] = points_px[lidx][1] + 10 * (points_px[lidx][1] - points_px[lidx][0])
+        elif torch.dot(line3d[1] - cam.position.view(3), principal_3d.view(3) - cam.position.view(3)) < 0:
+            points_px[lidx][1] = points_px[lidx][0] + 10 * (points_px[lidx][0] - points_px[lidx][1])

+    # draw lines
+    for lidx, line_name in enumerate(object3d.line_segments_names):
+        line2d = Line2D(
+            points_px[lidx, :, 0] * ratio_width,
+            points_px[lidx, :, 1] * ratio_height,
+            color=object3d.cmap_01[line_name],
+            **kwargs,
+        )
+        ax.add_line(line2d)

+    # draw circles
     points3d_circle = {
         k: torch.from_numpy(np.stack(v, axis=0)).float()
         for k, v in object3d._field_sncalib.sample_field_points(
@@ -48,6 +71,8 @@ def draw_reprojection(ax, object3d, cam, dist_circles=0.25, kwargs={"alpha": 0.5
     for circle_name, circle3d in points3d_circle.items():
         with torch.no_grad():
             points_px = cam.project_point2pixel(circle3d, lens_distortion=False).cpu()[0, 0]
+        points_px[:, 0] *= ratio_width   # x scales with the width ratio
+        points_px[:, 1] *= ratio_height  # y scales with the height ratio
         # print(circle_name, circle3d.shape, points_px.shape)
         circle2d = Polygon(
             points_px[:, :2],