Commit

cleanup the codebase
Impavidity authored and pnpnpn committed Aug 23, 2021
1 parent c90d4e0 commit 83af3f0
Showing 13 changed files with 2,148 additions and 146 deletions.
176 changes: 101 additions & 75 deletions relogic/tabart-pretraining.py → relogic/logical-tabart-pretraining.py
@@ -26,27 +26,27 @@
from dataclasses import dataclass, field
from typing import Optional
import torch

from transformers import (
MODEL_WITH_LM_HEAD_MAPPING,
AutoTokenizer,
HfArgumentParser,
PreTrainedTokenizer,
set_seed,
)
from relogic.pretrainkit.trainer import Trainer
from relogic.pretrainkit.multitask_trainer import Trainer
from relogic.pretrainkit.datasets.semparse.tabart import DataCollatorForTaBART, TaBARTDataset
from relogic.pretrainkit.datasets.semparse.text2sql import DataCollatorForQuerySchema2SQL, QuerySchema2SQLDataset
from relogic.pretrainkit.scorers.match_sequence import MatchSequenceScorer
from relogic.pretrainkit.models.semparse.tabart import TaBARTModel
from relogic.pretrainkit.models.semparse.logical_tabart import LogicalTaBARTModel
from relogic.pretrainkit.training_args import TrainingArguments
import relogic.utils.crash_on_ipy

logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

is_sagemaker = 'SM_MODEL_DIR' in os.environ

@dataclass
class ModelArguments:
@@ -73,17 +73,26 @@ class ModelArguments:
cache_dir: Optional[str] = field(
default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
)
task: Optional[str] = field(
default="mlm", metadata={"help": "Learning target. mlm, col_pred, mlm+col_pred"}
pretraining_model: Optional[str] = field(
default=None, metadata={"help": "What is the model to use for pretraining."}
)
load_from_pretrained_ckpt: Optional[str] = field(
default=None, metadata={"help": "Initialize the model with pretrained checkpoint"}
)
pretrained_ckpt_dir: Optional[str] = field(
default="pretrained_checkpoint", metadata={"help": "Pretrained Checkpoint"}
)



@dataclass
class DataTrainingArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
"""

task_names: Optional[str] = field(
default=None, metadata={"help": "The name of tasks which are separated by ,"}
)
train_data_file: Optional[str] = field(
default=None, metadata={"help": "The input training data file (a text file)."}
)
@@ -114,11 +123,47 @@ class DataTrainingArguments:
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
not_use_text: bool = field(
default=False, metadata={"help": "To use text in pretraining or not"}
)
only_use_text: bool = field(
default=False, metadata={"help": "To only use text in pretraining or not"}
)
cross_lingual: bool = field(
default=False,
metadata={"help": "Whether to use Cross-lingual Tabart Training"},
)
dump_file_name: str = field(
default="eval_dump.json",
metadata={"help": "The file name of evaluation dumping."}
)

def get_dataset_by_name(pretraining_model, task_name, cross_lingual, tokenizer, file_path, use_text, only_use_text):
if task_name != "text2sql":
return TaBARTDataset(tokenizer=tokenizer, file_path=file_path, col_token="<col>",
task_name=task_name, use_text=use_text, only_use_text=only_use_text)
if task_name == "text2sql":
return QuerySchema2SQLDataset(tokenizer=tokenizer, file_path=file_path, task_name=task_name)

def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
file_path = args.eval_data_file if evaluate else args.train_data_file
return TaBARTDataset(tokenizer=tokenizer, file_path=file_path, col_token="<col>")

def get_datasets(pretraining_model, args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
file_paths = args.eval_data_file.split(",") if evaluate else args.train_data_file.split(",")
task_names = args.task_names.split(",")
datasets = [get_dataset_by_name(pretraining_model, task_name, args.cross_lingual, tokenizer, file_path, not args.not_use_text, args.only_use_text)
for task_name, file_path in zip(task_names, file_paths)]
return datasets

def get_data_collator_by_name(pretraining_model, task_name, cross_lingual, tokenizer):
if task_name != "text2sql":
return DataCollatorForTaBART(tokenizer=tokenizer, task=task_name, col_token="<col>")
if task_name == "text2sql":
return DataCollatorForQuerySchema2SQL(tokenizer=tokenizer)


def get_data_collators(pretraining_model, args: DataTrainingArguments, tokenizer: PreTrainedTokenizer):
task_names = args.task_names.split(",")
collators = [get_data_collator_by_name(pretraining_model, task_name, args.cross_lingual, tokenizer) for task_name in task_names]
return collators
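
# --- Illustrative sketch (not part of this commit) ---------------------------
# The helpers above pair each entry of --task_names positionally with the
# corresponding entry of the comma-separated data-file argument: one dataset
# and one collator per task, with "text2sql" routed to the QuerySchema2SQL
# classes and every other task to the TaBART classes. The task names and file
# paths below are made up for illustration; this snippet only mirrors the
# pairing logic and does not touch the repo's classes.
def _illustrate_task_file_pairing():
    task_names = "mlm,col_pred,text2sql".split(",")
    file_paths = "tables.json,tables.json,spider_train.json".split(",")
    assert len(task_names) == len(file_paths), "one data file per task"
    for task_name, file_path in zip(task_names, file_paths):
        kind = "QuerySchema2SQL" if task_name == "text2sql" else "TaBART"
        print(f"{task_name}: {file_path} -> {kind} dataset + collator")
# ------------------------------------------------------------------------------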


def main():
@@ -129,21 +174,34 @@ def main():
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()

if is_sagemaker:
training_args.do_train = training_args.do_train_str == "True"
training_args.do_eval = training_args.do_eval_str == "True"
training_args.evaluate_during_training = training_args.evaluate_during_training_str == "True"
data_args.train_data_file = ",".join([os.path.join(os.environ['SM_CHANNEL_TRAIN'], item) for item in data_args.train_data_file.split(",")])
data_args.eval_data_file = ",".join([os.path.join(os.environ['SM_CHANNEL_TRAIN'], item) for item in data_args.eval_data_file.split(",")])
training_args.output_dir = os.environ['SM_MODEL_DIR']
model_args.pretrained_ckpt_dir = os.environ.get("SM_CHANNEL_PRETRAINED_CKPT_DIR", None)

if model_args.pretrained_ckpt_dir is not None and model_args.load_from_pretrained_ckpt is not None:
model_args.load_from_pretrained_ckpt = os.path.join(model_args.pretrained_ckpt_dir, model_args.load_from_pretrained_ckpt)

if data_args.eval_data_file is None and training_args.do_eval:
raise ValueError(
"Cannot do evaluation without an evaluation data file. Either supply a file to --eval_data_file "
"or remove the --do_eval argument."
)

if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)
if not is_sagemaker:
if (
os.path.exists(training_args.output_dir)
and os.listdir(training_args.output_dir)
and training_args.do_train
and not training_args.overwrite_output_dir
):
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
)

# Setup logging
logging.basicConfig(
@@ -164,91 +222,59 @@ def main():
# Set seed
set_seed(training_args.seed)

# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.

# if model_args.config_name:
# config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
# elif model_args.model_name_or_path:
# config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
# else:
# config = CONFIG_MAPPING[model_args.model_type]()
# logger.warning("You are instantiating a new config instance from scratch.")
#
# if model_args.tokenizer_name:
# tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
# elif model_args.model_name_or_path:
# tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
# else:
# raise ValueError(
# "You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
# "and load it from here, using --tokenizer_name"
# )
#
# if model_args.model_name_or_path:
# model = AutoModelWithLMHead.from_pretrained(
# model_args.model_name_or_path,
# from_tf=bool(".ckpt" in model_args.model_name_or_path),
# config=config,
# cache_dir=model_args.cache_dir,
# )
# else:
# logger.info("Training new model from scratch")
# model = AutoModelWithLMHead.from_config(config)
#
# model.resize_token_embeddings(len(tokenizer))
#
# if config.model_type in ["bert", "roberta", "distilbert", "camembert"] and not data_args.mlm:
# raise ValueError(
# "BERT and RoBERTa-like models do not have LM heads but masked LM heads. They must be run using the --mlm "
# "flag (masked language modeling)."
# )

"""Initialize models and tokenizer"""
if model_args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir)
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=False)
elif model_args.model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=False)
else:
raise ValueError(
"You are instantiating a new tokenizer from scratch. This is not supported, but you can do it from another script, save it,"
"and load it from here, using --tokenizer_name"
)
tokenizer.add_special_tokens({"additional_special_tokens": ["<col>"]})
model = TaBARTModel()

model = LogicalTaBARTModel(data_args.task_names)
model.bert.resize_token_embeddings(len(tokenizer))
model.bert_for_texttosql.resize_token_embeddings(len(tokenizer))
model.bert.model.shared.weight = model.bert_for_texttosql.model.shared.weight
model.bert.model.encoder.embed_tokens.weight = model.bert_for_texttosql.model.encoder.embed_tokens.weight
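
# --- Illustrative sketch (not part of this commit) ---------------------------
# The two assignments above tie the shared token embeddings of the two BART
# models so both objectives train a single embedding matrix. The toy function
# below shows the same parameter-tying idea with plain nn.Embedding modules;
# it is never called by the script.
def _sketch_embedding_tying():
    import torch
    import torch.nn as nn

    emb_a = nn.Embedding(100, 16)
    emb_b = nn.Embedding(100, 16)
    emb_b.weight = emb_a.weight  # both modules now share one Parameter
    assert emb_a.weight.data_ptr() == emb_b.weight.data_ptr()

    loss = emb_b(torch.tensor([1, 2, 3])).sum()
    loss.backward()
    assert emb_a.weight.grad is emb_b.weight.grad  # gradients accumulate in one tensor
# ------------------------------------------------------------------------------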

if training_args.do_eval and not training_args.do_train:
model_param = torch.load(os.path.join(model_args.model_name_or_path, "pytorch_model.bin"))
model.load_state_dict(model_param)
print("All key matched and load successfully.")

if data_args.block_size <= 0:
data_args.block_size = tokenizer.max_len
data_args.block_size = tokenizer.model_max_length
# Our input block size will be the max possible for the model
else:
data_args.block_size = min(data_args.block_size, tokenizer.max_len)
data_args.block_size = min(data_args.block_size, tokenizer.model_max_length)

# Get datasets

train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
train_datasets = get_datasets(model_args.pretraining_model, data_args, tokenizer=tokenizer) if training_args.do_train else None
eval_datasets = get_datasets(model_args.pretraining_model, data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
# data_collator = DataCollatorForLanguageModeling(
# tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
# )
data_collator = DataCollatorForTaBART(tokenizer=tokenizer, task=model_args.task)

match_sequence_scorer = MatchSequenceScorer(bos_id=data_collator.label_bos_id, eos_id=data_collator.label_eos_id, output_path=os.path.join(training_args.output_dir, "eval_dump.json"))
data_collators = get_data_collators(model_args.pretraining_model, data_args, tokenizer=tokenizer)

eos_id = None
for data_collator in data_collators:
if eos_id is None:
eos_id = data_collator.label_eos_id
else:
assert eos_id == data_collator.label_eos_id
match_sequence_scorer = MatchSequenceScorer(
eos_id=eos_id, output_path=os.path.join(training_args.output_dir, "eval_dump.json"))
# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collators=data_collators,
train_datasets=train_datasets,
eval_datasets=eval_datasets,
compute_metrics=match_sequence_scorer
)
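
# --- Illustrative sketch (not part of this commit) ---------------------------
# The multi-task Trainer is handed parallel lists: one dataset and one collator
# per task. One common way such trainers consume them (NOT necessarily what
# relogic's multitask_trainer does) is to round-robin batches across tasks, as
# the toy function below illustrates; it is never called by the script.
def _sketch_round_robin_batching():
    import itertools
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    loaders = [
        DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2),  # task 0
        DataLoader(TensorDataset(torch.randn(6, 4)), batch_size=2),  # task 1
    ]
    iters = [iter(itertools.cycle(loader)) for loader in loaders]
    for step in range(6):
        task_id = step % len(iters)      # step 0 -> task 0, step 1 -> task 1, ...
        (batch,) = next(iters[task_id])  # each loader yields [tensor] batches
        print(f"step {step}: task {task_id}, batch shape {tuple(batch.shape)}")
# ------------------------------------------------------------------------------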

Empty file added relogic/logickit/__init__.py
Empty file.
Empty file.
67 changes: 67 additions & 0 deletions relogic/logickit/base/utils.py
@@ -0,0 +1,67 @@
import sys
import os
try:
import cPickle as pickle
except ImportError:
import pickle


class Memoize(object):
def __init__(self, f):
self.f = f
self.cache = {}

def __call__(self, *args):
if args not in self.cache:
self.cache[args] = self.f(*args)
return self.cache[args]

def load_pickle(path, memoized=True):
return _load_pickle_memoize(path) if memoized else _load_pickle(path)

def _load_pickle(path):
with open(path, 'rb') as f:
return pickle.load(f)

@Memoize
def _load_pickle_memoize(path):
return _load_pickle(path)


def write_pickle(o, path):
dir = path.rsplit('/', 1)[0]
if not os.path.exists(dir):
os.mkdir(dir)
with open(path, 'wb') as f:
pickle.dump(o, f, -1)
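
# --- Illustrative usage (not part of this commit) -----------------------------
# write_pickle() creates the target directory if it is missing, and
# load_pickle() memoizes by path via the Memoize decorator above, so repeated
# loads of the same file return the identical object. The path and data below
# are made up; this helper is never called by the module.
def _sketch_pickle_helpers():
    data = {"tables": ["concert_singer", "pets_1"], "size": 2}
    write_pickle(data, "cache/example.pkl")
    first = load_pickle("cache/example.pkl")   # read from disk, then cached
    second = load_pickle("cache/example.pkl")  # served from the Memoize cache
    assert first is second                     # same object: memoized by path
# ------------------------------------------------------------------------------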

def log(*args):
msg = ' '.join(map(str, args))
sys.stdout.write(msg + '\n')
sys.stdout.flush()


def heading(*args):
log()
log(80 * '=')
log(*args)
log(80 * '=')


import torch
def print_rank_0(message, **kwargs):
"""If distributed is initialized print only on rank 0."""
# if torch.distributed.is_initialized():
# if torch.distributed.get_rank() == 0:
# print(message, flush=True, **kwargs)
# else:
# print(message, flush=True, **kwargs)
print(message, flush=True, **kwargs)
def is_rank_0():
# if torch.distributed.is_initialized():
# if torch.distributed.get_rank() == 0:
# return True
# else:
# return True
# return False
return True
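
# --- Illustrative usage (not part of this commit) -----------------------------
# With the distributed guards commented out above, print_rank_0() prints on
# every process and is_rank_0() always returns True; the intended pattern is to
# gate per-process side effects such as logging or checkpointing, e.g.:
def _sketch_rank0_logging():
    for step in (0, 100, 200):
        if is_rank_0():
            print_rank_0(f"step {step}: only rank 0 would log/save here")
# ------------------------------------------------------------------------------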
Empty file.
3 changes: 3 additions & 0 deletions relogic/logickit/modules/span_extractors/__init__.py
@@ -0,0 +1,3 @@
from relogic.logickit.modules.span_extractors.endpoint_span_extractor import EndpointSpanExtractor
from relogic.logickit.modules.span_extractors.self_attentive_span_extractor import SelfAttentiveSpanExtractor
from relogic.logickit.modules.span_extractors.attentive_span_extractor import AttentiveSpanExtractor
