Skip to content

Commit

Permalink
lint test_models
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad committed Nov 10, 2023
1 parent 6f83300 commit 908f679
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,14 @@ def test_transformer_basic():
for i in range(num_frames):
instances = []
for j in range(num_detected):
instances.append(Instance(bbox=torch.rand(size=(1, 4)),
features=torch.rand(size=(1, feats))))
frames.append(Frame(video_id = 0, frame_id=i,
instances=instances))

instances.append(
Instance(
bbox=torch.rand(size=(1, 4)), features=torch.rand(size=(1, feats))
)
)
frames.append(Frame(video_id=0, frame_id=i, instances=instances))

asso_preds,_ = transformer(frames)
asso_preds, _ = transformer(frames)

assert asso_preds[0].size() == (num_detected * num_frames,) * 2

Expand Down Expand Up @@ -274,10 +275,12 @@ def test_transformer_embedding():
for i in range(num_frames):
instances = []
for j in range(num_detected):
instances.append(Instance(bbox=torch.rand(size=(1, 4)),
features=torch.rand(size=(1, feats))))
frames.append(Frame(video_id = 0, frame_id=i,
instances=instances))
instances.append(
Instance(
bbox=torch.rand(size=(1, 4)), features=torch.rand(size=(1, feats))
)
)
frames.append(Frame(video_id=0, frame_id=i, instances=instances))

embedding_meta = {
"embedding_type": "learned_pos_temp",
Expand Down Expand Up @@ -316,13 +319,14 @@ def test_tracking_transformer():
for i in range(num_frames):
instances = []
for j in range(num_detected):
instances.append(Instance(bbox=torch.rand(size=(1, 4)),
crop=torch.rand(size=(1, 1, 64, 64))
))
frames.append(Frame(video_id=0,
frame_id=i,
img_shape=img_shape,
instances=instances))
instances.append(
Instance(
bbox=torch.rand(size=(1, 4)), crop=torch.rand(size=(1, 1, 64, 64))
)
)
frames.append(
Frame(video_id=0, frame_id=i, img_shape=img_shape, instances=instances)
)

embedding_meta = {
"embedding_type": "fixed_pos",
Expand Down

0 comments on commit 908f679

Please sign in to comment.