@@ -1025,136 +1025,79 @@ def test_multi_images(self):
10251025class TruncateWithProtectedTokensTester (TrlTestCase ):
10261026 def test_basic_example (self ):
10271027 """Test the basic example from the problem description."""
1028- prompt_ids = torch .tensor ([[1 , 2 , 3 , 4 , 5 ], [6 , 7 , 8 , 9 , 10 ]])
1029- prompt_mask = torch .ones_like (prompt_ids )
1030- protected_tokens = [2 , 3 , 6 ]
1028+ prompt_ids = [1 , 2 , 3 , 4 , 5 ]
1029+ protected_tokens = [2 , 3 ]
10311030 target_length = 3
10321031
1033- new_ids , new_mask = truncate_with_protected_tokens (prompt_ids , prompt_mask , target_length , protected_tokens )
1034-
1035- expected_ids = torch .tensor ([[2 , 3 , 5 ], [6 , 9 , 10 ]])
1036- expected_mask = torch .ones_like (expected_ids )
1032+ new_ids = truncate_with_protected_tokens (prompt_ids , target_length , protected_tokens )
10371033
1038- self . assertTrue ( torch . equal ( new_ids , expected_ids ))
1039- self .assertTrue ( torch . equal ( new_mask , expected_mask ) )
1034+ expected_ids = [ 2 , 3 , 5 ]
1035+ self .assertEqual ( new_ids , expected_ids )
10401036
10411037 def test_no_truncation_needed (self ):
10421038 """Test when target length equals current length."""
1043- prompt_ids = torch .tensor ([[1 , 2 , 3 ]])
1044- prompt_mask = torch .ones_like (prompt_ids )
1039+ prompt_ids = [1 , 2 , 3 ]
10451040 protected_tokens = [2 ]
10461041 target_length = 3
10471042
1048- new_ids , new_mask = truncate_with_protected_tokens (prompt_ids , prompt_mask , target_length , protected_tokens )
1043+ new_ids = truncate_with_protected_tokens (prompt_ids , target_length , protected_tokens )
10491044
1050- self .assertTrue (torch .equal (new_ids , prompt_ids ))
1051- self .assertTrue (torch .equal (new_mask , prompt_mask ))
1045+ self .assertEqual (new_ids , prompt_ids )
10521046
10531047 def test_no_protected_tokens (self ):
10541048 """Test truncation with no protected tokens (keeps only the last `target_length` tokens)."""
1055- prompt_ids = torch .tensor ([[1 , 2 , 3 , 4 , 5 ]])
1056- prompt_mask = torch .ones_like (prompt_ids )
1049+ prompt_ids = [1 , 2 , 3 , 4 , 5 ]
10571050 protected_tokens = []
10581051 target_length = 3
10591052
1060- new_ids , new_mask = truncate_with_protected_tokens (prompt_ids , prompt_mask , target_length , protected_tokens )
1053+ new_ids = truncate_with_protected_tokens (prompt_ids , target_length , protected_tokens )
10611054
1062- expected_ids = torch . tensor ([[ 3 , 4 , 5 ]]) # Last 3 tokens
1063- self .assertTrue ( torch . equal ( new_ids , expected_ids ) )
1055+ expected_ids = [ 3 , 4 , 5 ] # Last 3 tokens
1056+ self .assertEqual ( new_ids , expected_ids )
10641057
10651058 def test_all_tokens_protected (self ):
10661059 """Test when all remaining tokens are protected."""
1067- prompt_ids = torch .tensor ([[1 , 2 , 3 , 4 , 5 ]])
1068- prompt_mask = torch .ones_like (prompt_ids )
1060+ prompt_ids = [1 , 2 , 3 , 4 , 5 ]
10691061 protected_tokens = [3 , 4 , 5 ]
10701062 target_length = 3
10711063
1072- new_ids , new_mask = truncate_with_protected_tokens (prompt_ids , prompt_mask , target_length , protected_tokens )
1064+ new_ids = truncate_with_protected_tokens (prompt_ids , target_length , protected_tokens )
10731065
1074- expected_ids = torch . tensor ([[ 3 , 4 , 5 ]])
1075- self .assertTrue ( torch . equal ( new_ids , expected_ids ) )
1066+ expected_ids = [ 3 , 4 , 5 ]
1067+ self .assertEqual ( new_ids , expected_ids )
10761068
10771069 def test_too_many_protected_tokens (self ):
10781070 """Test error when too many protected tokens for target length."""
1079- prompt_ids = torch .tensor ([[1 , 2 , 3 , 4 , 5 ]])
1080- prompt_mask = torch .ones_like (prompt_ids )
1071+ prompt_ids = [1 , 2 , 3 , 4 , 5 ]
10811072 protected_tokens = [1 , 2 , 3 , 4 ]
10821073 target_length = 3
10831074
10841075 with self .assertRaises (ValueError ):
1085- truncate_with_protected_tokens (prompt_ids , prompt_mask , target_length , protected_tokens )
1076+ truncate_with_protected_tokens (prompt_ids , target_length , protected_tokens )
10861077
10871078 def test_single_batch_single_token (self ):
10881079 """Test edge case with a single-token prompt whose only token is protected."""
1089- prompt_ids = torch .tensor ([[5 ]])
1090- prompt_mask = torch .ones_like (prompt_ids )
1080+ prompt_ids = [5 ]
10911081 protected_tokens = [5 ]
10921082 target_length = 1
10931083
1094- new_ids , new_mask = truncate_with_protected_tokens (prompt_ids , prompt_mask , target_length , protected_tokens )
1095-
1096- self .assertTrue (torch .equal (new_ids , prompt_ids ))
1097-
1098- def test_mask_preservation (self ):
1099- """Test that mask values are correctly preserved."""
1100- prompt_ids = torch .tensor ([[1 , 2 , 3 , 4 , 5 ]])
1101- prompt_mask = torch .tensor ([[1 , 0 , 1 , 0 , 1 ]]) # Mixed mask values
1102- protected_tokens = [2 , 4 ]
1103- target_length = 3
1104-
1105- new_ids , new_mask = truncate_with_protected_tokens (prompt_ids , prompt_mask , target_length , protected_tokens )
1106-
1107- expected_ids = torch .tensor ([[2 , 4 , 5 ]])
1108- expected_mask = torch .tensor ([[0 , 0 , 1 ]]) # Corresponding mask values
1109-
1110- self .assertTrue (torch .equal (new_ids , expected_ids ))
1111- self .assertTrue (torch .equal (new_mask , expected_mask ))
1112-
1113- def test_multiple_batches_different_protected (self ):
1114- """Test multiple batches where protected tokens appear differently."""
1115- prompt_ids = torch .tensor ([[1 , 2 , 3 , 4 , 5 ], [2 , 6 , 7 , 8 , 9 ], [10 , 11 , 12 , 2 , 13 ]])
1116- prompt_mask = torch .ones_like (prompt_ids )
1117- protected_tokens = [2 ]
1118- target_length = 3
1119-
1120- new_ids , new_mask = truncate_with_protected_tokens (prompt_ids , prompt_mask , target_length , protected_tokens )
1084+ new_ids = truncate_with_protected_tokens (prompt_ids , target_length , protected_tokens )
11211085
1122- expected_ids = torch .tensor (
1123- [
1124- [2 , 4 , 5 ], # 2 is protected, keep last 2 non-protected (4,5)
1125- [2 , 8 , 9 ], # 2 is protected, keep last 2 non-protected (8,9)
1126- [12 , 2 , 13 ], # 2 is protected, keep last 2 non-protected (12,13)
1127- ]
1128- )
1129-
1130- self .assertTrue (torch .equal (new_ids , expected_ids ))
1086+ self .assertEqual (new_ids , prompt_ids )
11311087
11321088 def test_order_preservation (self ):
11331089 """Test that relative order is preserved."""
1134- prompt_ids = torch .tensor ([[10 , 2 , 20 , 3 , 30 , 40 ]])
1135- prompt_mask = torch .ones_like (prompt_ids )
1090+ prompt_ids = [10 , 2 , 20 , 3 , 30 , 40 ]
11361091 protected_tokens = [2 , 3 ]
11371092 target_length = 4
11381093
1139- new_ids , new_mask = truncate_with_protected_tokens (prompt_ids , prompt_mask , target_length , protected_tokens )
1094+ new_ids = truncate_with_protected_tokens (prompt_ids , target_length , protected_tokens )
11401095
1141- # Should keep protected tokens 2,3 and last 2 non-protected tokens 30,40
1096+ # Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
11421097 # Order should be: 2, 3, 30, 40 (maintaining original relative positions)
1143- expected_ids = torch .tensor ([[2 , 3 , 30 , 40 ]])
1144-
1145- self .assertTrue (torch .equal (new_ids , expected_ids ))
1146-
1147- def test_empty_protected_tokens_list (self ):
1148- """Test with empty protected tokens list."""
1149- prompt_ids = torch .tensor ([[1 , 2 , 3 , 4 , 5 ]])
1150- prompt_mask = torch .ones_like (prompt_ids )
1151- protected_tokens = []
1152- target_length = 2
1153-
1154- new_ids , new_mask = truncate_with_protected_tokens (prompt_ids , prompt_mask , target_length , protected_tokens )
1098+ expected_ids = [2 , 3 , 30 , 40 ]
11551099
1156- expected_ids = torch .tensor ([[4 , 5 ]]) # Last 2 tokens
1157- self .assertTrue (torch .equal (new_ids , expected_ids ))
1100+ self .assertEqual (new_ids , expected_ids )
11581101
11591102
11601103class UnsplitPixelValuesByGridTester (TrlTestCase ):
0 commit comments