@@ -73,14 +73,15 @@ def create_model(is_training, input_ids, input_mask, segment_ids, labels,
 
   logits = tf.matmul(output_layer, output_weights, transpose_b=True)
   logits = tf.nn.bias_add(logits, output_bias)
+  probabilities = tf.nn.softmax(logits, axis=-1)
   log_probs = tf.nn.log_softmax(logits, axis=-1)
 
   one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
 
   per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
   loss = tf.reduce_mean(per_example_loss)
 
-  return (loss, per_example_loss, logits)
+  return (loss, per_example_loss, logits, probabilities)
 
 
 def model_fn_builder(num_labels, learning_rate, num_train_steps,
@@ -101,7 +102,7 @@ def model_fn(features, labels, mode, params): # pylint: disable=unused-argument
 
     is_training = (mode == tf.estimator.ModeKeys.TRAIN)
 
-    (total_loss, per_example_loss, logits) = create_model(
+    (total_loss, per_example_loss, logits, probabilities) = create_model(
         is_training, input_ids, input_mask, segment_ids, label_ids, num_labels,
         bert_hub_module_handle)
 
@@ -130,8 +131,12 @@ def metric_fn(per_example_loss, label_ids, logits):
           mode=mode,
           loss=total_loss,
           eval_metrics=eval_metrics)
+    elif mode == tf.estimator.ModeKeys.PREDICT:
+      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
+          mode=mode, predictions={"probabilities": probabilities})
     else:
-      raise ValueError("Only TRAIN and EVAL modes are supported: %s" % (mode))
+      raise ValueError(
+          "Only TRAIN, EVAL and PREDICT modes are supported: %s" % (mode))
 
     return output_spec
 
@@ -215,7 +220,8 @@ def main(_):
       model_fn=model_fn,
       config=run_config,
       train_batch_size=FLAGS.train_batch_size,
-      eval_batch_size=FLAGS.eval_batch_size)
+      eval_batch_size=FLAGS.eval_batch_size,
+      predict_batch_size=FLAGS.predict_batch_size)
 
   if FLAGS.do_train:
     train_features = run_classifier.convert_examples_to_features(
@@ -265,6 +271,40 @@ def main(_):
       tf.logging.info("  %s = %s", key, str(result[key]))
       writer.write("%s = %s\n" % (key, str(result[key])))
 
+  if FLAGS.do_predict:
+    predict_examples = processor.get_test_examples(FLAGS.data_dir)
+    if FLAGS.use_tpu:
+      # Discard batch remainder if running on TPU
+      n = len(predict_examples)
+      predict_examples = predict_examples[:(n - n % FLAGS.predict_batch_size)]
+
+    predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
+    run_classifier.file_based_convert_examples_to_features(
+        predict_examples, label_list, FLAGS.max_seq_length, tokenizer,
+        predict_file)
+
+    tf.logging.info("***** Running prediction*****")
+    tf.logging.info("  Num examples = %d", len(predict_examples))
+    tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)
+
+    predict_input_fn = run_classifier.file_based_input_fn_builder(
+        input_file=predict_file,
+        seq_length=FLAGS.max_seq_length,
+        is_training=False,
+        drop_remainder=FLAGS.use_tpu)
+
+    result = estimator.predict(input_fn=predict_input_fn)
+
+    output_predict_file = os.path.join(FLAGS.output_dir, "test_results.tsv")
+    with tf.gfile.GFile(output_predict_file, "w") as writer:
+      tf.logging.info("***** Predict results *****")
+      for prediction in result:
+        probabilities = prediction["probabilities"]
+        output_line = "\t".join(
+            str(class_probability)
+            for class_probability in probabilities) + "\n"
+        writer.write(output_line)
+
 
 if __name__ == "__main__":
   flags.mark_flag_as_required("data_dir")
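For reference, a minimal post-processing sketch (not part of this change): the `do_predict` branch above writes one tab-separated row of class probabilities per test example to `test_results.tsv`, so predicted label indices can be recovered with an argmax over each row. The file name and layout come from the code above; everything else here is illustrative.

```python
# Illustrative only; assumes test_results.tsv was produced by the do_predict
# branch above (one tab-separated row of class probabilities per example).
import numpy as np

with open("test_results.tsv") as f:
    probs = np.array([[float(p) for p in line.strip().split("\t")]
                      for line in f if line.strip()])

predicted_labels = probs.argmax(axis=-1)  # index of the most probable class
print(predicted_labels[:10])
```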