Skip to content

Commit

Permalink
update configs and tests to reflect changes
Browse files Browse the repository at this point in the history
* add tests for empty embedding
* use better tests for overriding embedding `kwargs`
  • Loading branch information
aaprasad committed Apr 25, 2024
1 parent 5fbda51 commit c6f2606
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 50 deletions.
4 changes: 2 additions & 2 deletions biogtr/training/configs/base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ model:
dropout_attn_head: 0.1
embedding_meta:
pos:
type: "learned"
mode: "learned"
emb_num: 16
over_boxes: false
temp:
type: "learned"
mode: "learned"
emb_num: 16
return_embedding: False
decoder_self_attn: False
Expand Down
5 changes: 3 additions & 2 deletions biogtr/training/configs/params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ model:
num_decoder_layers: 2
embedding_meta:
pos:
type: learned
mode: learned
emb_num: 16
over_boxes: True
temp: null
temp:
mode: "off"
dataset:
train_dataset:
slp_files: ['190612_110405_wt_18159111_rig2.2@11730.slp']
Expand Down
4 changes: 3 additions & 1 deletion tests/configs/params.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ model:
mode: 'learned'
emb_num: 16
over_boxes: True
temp: null
temp:
mode: "off"

dataset:
train_dataset:
Expand All @@ -29,3 +30,4 @@ trainer:
limit_test_batches: 1
limit_val_batches: 1
max_epochs: 1
enable_checkpointing: true
107 changes: 63 additions & 44 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,41 +66,41 @@ def test_embedding_validity():

# this would throw assertion since embedding should be "pos"
with pytest.raises(Exception):
_ = Embedding(type="position", mode="learned")
_ = Embedding(emb_type="position", mode="learned", features=128)
with pytest.raises(Exception):
_ = Embedding(type="position", mode="fixed")
_ = Embedding(emb_type="position", mode="fixed", features=128)

with pytest.raises(Exception):
_ = Embedding(type="temporal", mode="learned")
_ = Embedding(emb_type="temporal", mode="learned", features=128)
with pytest.raises(Exception):
_ = Embedding(type="position", mode="fixed")
_ = Embedding(emb_type="position", mode="fixed", features=128)

with pytest.raises(Exception):
_ = Embedding(type="pos", mode="learn")
_ = Embedding(emb_type="pos", mode="learn", features=128)
with pytest.raises(Exception):
_ = Embedding(type="temp", mode="learn")
_ = Embedding(emb_type="temp", mode="learn", features=128)

with pytest.raises(Exception):
_ = Embedding(type="pos", mode="fix")
_ = Embedding(emb_type="pos", mode="fix", features=128)
with pytest.raises(Exception):
_ = Embedding(type="temp", mode="fix")
_ = Embedding(emb_type="temp", mode="fix", features=128)

with pytest.raises(Exception):
_ = Embedding(type="position", mode="learn")
_ = Embedding(emb_type="position", mode="learn", features=128)
with pytest.raises(Exception):
_ = Embedding(type="temporal", mode="learn")
_ = Embedding(emb_type="temporal", mode="learn", features=128)
with pytest.raises(Exception):
_ = Embedding(type="position", mode="fix")
_ = Embedding(emb_type="position", mode="fix", features=128)
with pytest.raises(Exception):
_ = Embedding(type="temporal", mode="learn")
_ = Embedding(emb_type="temporal", mode="learn", features=128)

with pytest.raises(Exception):
_ = Embedding(type="temp", mode="fixed")
_ = Embedding(emb_type="temp", mode="fixed", features=128)

_ = Embedding(type="temp", mode="learned")
_ = Embedding(type="pos", mode="learned")
_ = Embedding(emb_type="temp", mode="learned", features=128)
_ = Embedding(emb_type="pos", mode="learned", features=128)

_ = Embedding(type="pos", mode="learned")
_ = Embedding(emb_type="pos", mode="learned", features=128)


def test_embedding():
Expand All @@ -116,7 +116,7 @@ def test_embedding():
times = torch.rand(size=(N,))

pos_emb = Embedding(
type="pos",
emb_type="pos",
mode="fixed",
features=d_model,
temperature=objects,
Expand All @@ -126,12 +126,30 @@ def test_embedding():

sine_pos_emb = pos_emb(boxes)

pos_emb = Embedding(type="pos", mode="learned", features=d_model, emb_num=100)
pos_emb = Embedding(emb_type="pos", mode="learned", features=d_model, emb_num=100)
learned_pos_emb = pos_emb(boxes)

temp_emb = Embedding(type="temp", mode="learned", features=d_model, emb_num=16)
temp_emb = Embedding(emb_type="temp", mode="learned", features=d_model, emb_num=16)
learned_temp_emb = temp_emb(times)

pos_emb_off = Embedding(emb_type="pos", mode="off", features=d_model)
off_pos_emb = pos_emb_off(boxes)

temp_emb_off = Embedding(emb_type="temp", mode="off", features=d_model)
off_temp_emb = temp_emb_off(times)

learned_emb_off = Embedding(emb_type="off", mode="learned", features=d_model)
off_learned_emb_boxes = learned_emb_off(boxes)
off_learned_emb_times = learned_emb_off(times)

fixed_emb_off = Embedding(emb_type="off", mode="fixed", features=d_model)
off_fixed_emb_boxes = fixed_emb_off(boxes)
off_fixed_emb_times = fixed_emb_off(times)

off_emb = Embedding(emb_type="off", mode="off", features=d_model)
off_emb_boxes = off_emb(boxes)
off_emb_times = off_emb(times)

assert sine_pos_emb.size() == (N, d_model)
assert learned_pos_emb.size() == (N, d_model)
assert learned_temp_emb.size() == (N, d_model)
Expand All @@ -140,6 +158,24 @@ def test_embedding():
assert not torch.equal(sine_pos_emb, learned_temp_emb)
assert not torch.equal(learned_pos_emb, learned_temp_emb)

assert off_pos_emb.size() == (N, d_model)
assert off_temp_emb.size() == (N, d_model)
assert off_learned_emb_boxes.size() == (N, d_model)
assert off_learned_emb_times.size() == (N, d_model)
assert off_fixed_emb_boxes.size() == (N, d_model)
assert off_fixed_emb_times.size() == (N, d_model)
assert off_emb_boxes.size() == (N, d_model)
assert off_emb_times.size() == (N, d_model)

assert not off_pos_emb.any()
assert not off_temp_emb.any()
assert not off_learned_emb_boxes.any()
assert not off_learned_emb_times.any()
assert not off_fixed_emb_boxes.any()
assert not off_fixed_emb_times.any()
assert not off_emb_boxes.any()
assert not off_emb_times.any()


def test_embedding_kwargs():
"""Test embedding config logic."""
Expand All @@ -154,52 +190,35 @@ def test_embedding_kwargs():

# sine embedding

sine_no_args = Embedding("pos", "fixed")(boxes)
sine_no_args = Embedding("pos", "fixed", 128)(boxes)

sine_args = {
"temperature": objects,
"scale": frames,
"normalize": True,
}

sine_with_args = Embedding("pos", "fixed", **sine_args)(boxes)
sine_with_args = Embedding("pos", "fixed", 128, **sine_args)(boxes)

assert not torch.equal(sine_no_args, sine_with_args)

# learned pos embedding

lp_no_args = Embedding("pos", "learned")(boxes)
lp_no_args = Embedding("pos", "learned", 128)

lp_args = {"emb_num": 100, "over_boxes": False}

lp_with_args = Embedding("pos", "learned", **lp_args)(boxes)
lp_with_args = Embedding("pos", "learned", 128, **lp_args)
assert lp_no_args.lookup.weight.shape != lp_with_args.lookup.weight.shape

assert not torch.equal(lp_no_args, lp_with_args)
assert not torch.equal(lp_no_args, sine_no_args)
assert not torch.equal(lp_no_args, sine_with_args)
assert not torch.equal(lp_with_args, sine_no_args)
assert not torch.equal(lp_with_args, sine_with_args)
# learned temp embedding

lt_no_args = Embedding("temp", "learned")(times)
lt_no_args = Embedding("temp", "learned", 128)

lt_args = {"emb_num": 100}

lt_with_args = Embedding("temp", "learned", **lt_args)(times)

assert not torch.equal(lt_no_args, lt_with_args)

assert not torch.equal(lt_no_args, lp_no_args)
assert not torch.equal(lt_no_args, lp_with_args)

assert not torch.equal(lt_no_args, sine_no_args)
assert not torch.equal(lt_no_args, sine_with_args)

assert not torch.equal(lt_with_args, lp_no_args)
assert not torch.equal(lt_with_args, lp_with_args)

assert not torch.equal(lt_with_args, sine_no_args)
assert not torch.equal(lt_with_args, sine_with_args)
lt_with_args = Embedding("temp", "learned", 128, **lt_args)
assert lt_no_args.lookup.weight.shape != lt_with_args.lookup.weight.shape


def test_transformer_encoder():
Expand Down
3 changes: 2 additions & 1 deletion tests/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from biogtr.config import Config
from biogtr.training.train import main

# todo: add named tensor tests
# TODO: add named tensor tests
# TODO: use temp dir and cleanup after tests (https://docs.pytest.org/en/7.1.x/how-to/tmp_path.html)


def test_asso_loss():
Expand Down

0 comments on commit c6f2606

Please sign in to comment.