2 files changed: +1 -18 lines changed
Original file line number | Diff line number | Diff line change
@@ -910,7 +910,6 @@ def numpy_mask_tokens(
910910 word_ids , no_mask_mask = self ._calc_word_ids_and_prob_mask (
911911 to_numpy (offset_mapping ), to_numpy (special_tokens_mask )
912912 )
913- no_mask_mask = np .array (no_mask_mask , dtype = bool )
914913 else :
915914 no_mask_mask = (
916915 special_tokens_mask .astype (bool )
@@ -998,7 +997,7 @@ def _calc_word_ids_and_prob_mask(
998997 word_ids = np .cumsum (is_new_word , axis = 1 )
999998 word_ids [special_tokens_mask ] = - 1
1000999
1001- prob_mask = ( ~ is_new_word ). astype ( int )
1000+ prob_mask = ~ is_new_word
10021001
10031002 return word_ids , prob_mask
10041003
@@ -1049,8 +1048,6 @@ def tolist(x) -> list[Any]:
10491048def to_numpy (x ) -> np .ndarray [Any ]:
10501049 if isinstance (x , np .ndarray ):
10511050 return x
1052- elif hasattr (x , "numpy" ):
1053- return x .numpy ()
10541051 elif hasattr (x , "detach" ):
10551052 return x .detach ().cpu ().numpy ()
10561053 else :
Original file line number | Diff line number | Diff line change
3030 DataCollatorWithFlattening ,
3131 DataCollatorWithPadding ,
3232 default_data_collator ,
33- is_tf_available ,
3433 is_torch_available ,
3534 set_seed ,
3635)
@@ -558,19 +557,6 @@ def test_data_collator_for_whole_word_mask(self):
558557 self .assertEqual (batch ["input_ids" ].shape , torch .Size ((2 , 10 )))
559558 self .assertEqual (batch ["labels" ].shape , torch .Size ((2 , 10 )))
560559
561- if is_tf_available ():
562- import tensorflow as tf
563-
564- # Features can already be tensors
565- features = [
566- tokenizer (" " .join (input_tokens ), return_offsets_mapping = True ).convert_to_tensors ("tf" )
567- for _ in range (2 )
568- ]
569- data_collator = DataCollatorForWholeWordMask (tokenizer , return_tensors = "tf" )
570- batch = data_collator (features )
571- self .assertEqual (batch ["input_ids" ].shape , tf .TensorShape ((2 , 10 )))
572- self .assertEqual (batch ["labels" ].shape , tf .TensorShape ((2 , 10 )))
573-
574560 def test_data_collator_for_whole_word_mask_with_seed (self ):
575561 tokenizer = BertTokenizerFast (self .vocab_file )
576562
You can’t perform that action at this time.
0 commit comments