diff --git a/examples/run_glue_tpu.py b/examples/run_glue_tpu.py index 2274826f737808..d16d1d111b25ff 100644 --- a/examples/run_glue_tpu.py +++ b/examples/run_glue_tpu.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" Finetuning the library models for sequence classification on GLUE (Bert, XLNet).""" +""" Finetuning the library models for sequence classification on GLUE (Bert, XLNet, RoBERTa).""" from __future__ import absolute_import, division, print_function @@ -62,11 +62,13 @@ logger = logging.getLogger(__name__) script_start_time = time.strftime("%Y%m%d_%H%M%S", time.gmtime()) -ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig)), ()) +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in ( + BertConfig, XLNetConfig, RobertaConfig)), ()) MODEL_CLASSES = { 'bert': (BertConfig, BertForSequenceClassification, BertTokenizer), 'xlnet': (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer), + 'roberta': (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer), }