Commit 648fc87

Fixed the issue where token IDs were not converted to word-piece IDs for BERT value linking. Closes #4.
alexpolozov committed Aug 15, 2020
1 parent 4a013a9 commit 648fc87
Showing 3 changed files with 21 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -21,6 +21,7 @@ If you use RAT-SQL in your work, please cite it as follows:
 **2020-08-14:**
 - The Docker image now inherits from a CUDA-enabled base image.
 - Clarified memory and dataset requirements on the image.
+- Fixed the issue where token IDs were not converted to word-piece IDs for BERT value linking.
 
 ## Usage
 
2 changes: 1 addition & 1 deletion ratsql/datasets/spider.py
@@ -150,7 +150,7 @@ def __init__(self, paths, tables_paths, db_path, demo_path=None, limit=None):
         for db_id, schema in tqdm(self.schemas.items(), desc="DB connections"):
             sqlite_path = Path(db_path) / db_id / f"{db_id}.sqlite"
             source: sqlite3.Connection
-            with sqlite3.connect(sqlite_path) as source:
+            with sqlite3.connect(str(sqlite_path)) as source:
                 dest = sqlite3.connect(':memory:')
                 dest.row_factory = sqlite3.Row
                 source.backup(dest)
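For context, a standalone sketch of the connection pattern this hunk touches; the database path below is hypothetical. Python's `sqlite3.connect()` documents path-like arguments only from Python 3.7, so the explicit `str()` conversion keeps the call working wherever a plain string path is expected.

```python
import sqlite3
from pathlib import Path

# Hypothetical database path, for illustration only.
sqlite_path = Path("data/database/concert_singer/concert_singer.sqlite")

# sqlite3.connect() documents path-like arguments only from Python 3.7,
# so converting the Path to str avoids type errors in other environments.
with sqlite3.connect(str(sqlite_path)) as source:
    dest = sqlite3.connect(":memory:")   # fresh in-memory database
    dest.row_factory = sqlite3.Row       # rows addressable by column name
    source.backup(dest)                  # copy the on-disk DB into memory (Python 3.7+)
```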
22 changes: 19 additions & 3 deletions ratsql/models/spider/spider_enc.py
@@ -546,6 +546,7 @@ def __init__(self, pieces):
         self.pieces = pieces
 
         self.normalized_pieces = None
+        self.recovered_pieces = None
         self.idx_map = None
 
         self.normalize_toks()
@@ -605,6 +606,7 @@ def normalize_toks(self):
             normalized_toks.append(lemma_word)
 
         self.normalized_pieces = normalized_toks
+        self.recovered_pieces = new_toks
 
     def bert_schema_linking(self, columns, tables):
         question_tokens = self.normalized_pieces
@@ -624,6 +626,21 @@
             new_sc_link[m_type] = _match
         return new_sc_link
 
+    def bert_cv_linking(self, schema):
+        question_tokens = self.recovered_pieces  # Not using normalized tokens here because values usually match exactly
+        cv_link = compute_cell_value_linking(question_tokens, schema)
+
+        new_cv_link = {}
+        for m_type in cv_link:
+            _match = {}
+            for ij_str in cv_link[m_type]:
+                q_id_str, col_tab_id_str = ij_str.split(",")
+                q_id, col_tab_id = int(q_id_str), int(col_tab_id_str)
+                real_q_id = self.idx_map[q_id]
+                _match[f"{real_q_id},{col_tab_id}"] = cv_link[m_type][ij_str]
+            new_cv_link[m_type] = _match
+        return new_cv_link
+
 
 class SpiderEncoderBertPreproc(SpiderEncoderV2Preproc):
 
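To see what the new `bert_cv_linking` does, consider a toy remapping; all data below is made up for illustration. Suppose the word-pieces `["base", "##ball", "players"]` were merged into the tokens `["baseball", "players"]` during `normalize_toks`, so `idx_map` sends each merged-token index to the index of its first word-piece, and cell-value matches found on merged tokens get re-keyed to word-piece IDs:

```python
# Hypothetical data: merged token 0 ("baseball") starts at word-piece 0,
# merged token 1 ("players") starts at word-piece 2.
idx_map = {0: 0, 1: 2}

# A cell-value match keyed "<question_token_id>,<column_id>" over merged tokens.
cv_link = {"cell_match": {"1,5": "CELLMATCH"}}

# Same remapping loop as bert_cv_linking above.
new_cv_link = {}
for m_type, matches in cv_link.items():
    remapped = {}
    for ij_str, label in matches.items():
        q_id_str, col_id_str = ij_str.split(",")
        remapped[f"{idx_map[int(q_id_str)]},{col_id_str}"] = label
    new_cv_link[m_type] = remapped

print(new_cv_link)  # {'cell_match': {'2,5': 'CELLMATCH'}} -- now keyed by word-piece ID
```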
@@ -667,8 +684,8 @@ def add_item(self, item, section, validation_info):
     def preprocess_item(self, item, validation_info):
         question = self._tokenize(item.text, item.orig['question'])
         preproc_schema = self._preprocess_schema(item.schema)
+        question_bert_tokens = Bertokens(question)
         if self.compute_sc_link:
-            question_bert_tokens = Bertokens(question)
             sc_link = question_bert_tokens.bert_schema_linking(
                 preproc_schema.normalized_column_names,
                 preproc_schema.normalized_table_names
@@ -677,8 +694,7 @@
             sc_link = {"q_col_match": {}, "q_tab_match": {}}
 
         if self.compute_cv_link:
-            question_bert_tokens = Bertokens(question)
-            cv_link = compute_cell_value_linking(question_bert_tokens.normalized_pieces, item.schema)
+            cv_link = question_bert_tokens.bert_cv_linking(item.schema)
         else:
             cv_link = {"num_date_match": {}, "cell_match": {}}
 
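Taken together, the last two hunks build a single `Bertokens` instance and share it between schema linking and cell-value linking, with the latter now routed through `bert_cv_linking` so its keys land in word-piece space. A condensed sketch of the resulting flow (a paraphrase of the diff, not the full method):

```python
# Condensed paraphrase of the post-fix preprocess_item control flow.
question_bert_tokens = Bertokens(question)  # built once, reused below

if self.compute_sc_link:
    sc_link = question_bert_tokens.bert_schema_linking(
        preproc_schema.normalized_column_names,
        preproc_schema.normalized_table_names,
    )
else:
    sc_link = {"q_col_match": {}, "q_tab_match": {}}

if self.compute_cv_link:
    # Matches on recovered (merged) tokens, then remaps question indices
    # back to word-piece IDs via idx_map.
    cv_link = question_bert_tokens.bert_cv_linking(item.schema)
else:
    cv_link = {"num_date_match": {}, "cell_match": {}}
```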
