Skip to content

Commit

Permalink
Merge pull request #49 from naver/ws
Browse files Browse the repository at this point in the history
Ws
  • Loading branch information
whwang299 authored Sep 30, 2019
2 parents 1c08a23 + 0e3a83b commit 523c006
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 114 deletions.
77 changes: 67 additions & 10 deletions add_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,20 @@
# All columns are treated as text - no attempt is made to sniff the type of value
# stored in the column.

import argparse, csv, json, os
import argparse, csv, json, os, re
from sqlalchemy import Column, create_engine, MetaData, String, Table


def get_table_name(table_id):
    """Return the SQLite table name used for the given table id."""
    return f'table_{table_id}'

def csv_to_sqlite(table_id, csv_file_name, sqlite_file_name):

def csv_to_sqlite(table_id, csv_file_name, sqlite_file_name, working_folder='.'):
sqlite_file_name = os.path.join(working_folder, sqlite_file_name)
csv_file_name = os.path.join(working_folder, csv_file_name)

engine = create_engine('sqlite:///{}'.format(sqlite_file_name))

with open(csv_file_name) as f:
metadata = MetaData(bind=engine)
cf = csv.DictReader(f, delimiter=',')
Expand All @@ -30,29 +36,80 @@ def csv_to_sqlite(table_id, csv_file_name, sqlite_file_name):
table.insert().values(**row).execute()
return engine

def csv_to_json(table_id, csv_file_name, json_file_name):

def is_num(val):
    """Return True if *val* is a string that parses as a number.

    The previous regex used ``search``, so any string merely *containing*
    a digit (e.g. "12 apples") was classified as numeric, which made the
    later ``float()`` conversion in get_refined_rows raise ValueError.
    Probing with ``float()`` itself guarantees that every value reported
    as numeric can actually be converted.
    """
    try:
        float(val)
    except ValueError:
        return False
    return True


def get_types(rows):
    """Infer a per-column type list ('real' or 'text') for a table.

    Types are inferred from the first row only: a numeric-looking value
    marks its column as 'real', everything else as 'text'.  Returns an
    empty list when *rows* is empty (the original indexed ``rows[0]``
    unconditionally and crashed on empty tables; it also assigned
    ``types = []`` twice).
    """
    if not rows:
        return []
    return ['real' if is_num(val) else 'text' for val in rows[0]]


def get_refined_rows(rows, types):
    """Return rows with values in 'real'-typed columns converted to float.

    Columns whose entry in *types* is 'real' are converted with
    ``float``; all other columns are left untouched.  When no column is
    'real' the input list is returned unchanged.

    Fix: the original did ``rr = row`` and then assigned into ``rr``,
    mutating the caller's row lists in place.  Each returned row is now
    a copy, leaving *rows* unmodified.
    """
    real_idx = [i for i, col_type in enumerate(types) if col_type == 'real']
    if not real_idx:
        return rows

    refined = []
    for row in rows:
        new_row = list(row)  # copy so the caller's data is not mutated
        for i in real_idx:
            new_row[i] = float(new_row[i])
        refined.append(new_row)
    return refined





def csv_to_json(table_id, csv_file_name, json_file_name, working_folder='.'):
    """Append one JSON-lines record describing *csv_file_name* to *json_file_name*.

    Both file names are resolved relative to *working_folder*.  The record
    carries the WikiSQL-style table metadata: header, types, id, caption,
    rows, and name.  Column types are inferred from the first data row
    (see get_types) and values in 'real' columns are converted to float
    (see get_refined_rows) before the record is written.

    Fixes relative to the diffed version:
    - the record was dumped to *json_file_name* twice (once before type
      inference with all-'text' types and once after); it is now written
      exactly once, after inference;
    - an empty CSV (no data rows) no longer crashes type inference — the
      default all-'text' types are kept.
    """
    csv_file_name = os.path.join(working_folder, csv_file_name)
    json_file_name = os.path.join(working_folder, json_file_name)

    with open(csv_file_name) as f:
        cf = csv.DictReader(f, delimiter=',')
        record = {}
        # Unnamed columns fall back to a positional name ('col0', 'col1', ...).
        record['header'] = [(name or 'col{}'.format(i)) for i, name in enumerate(cf.fieldnames)]
        record['page_title'] = None
        record['types'] = ['text'] * len(cf.fieldnames)
        record['id'] = table_id
        record['caption'] = None
        record['rows'] = [list(row.values()) for row in cf]
        record['name'] = get_table_name(table_id)

    # Infer column types from the first data row, then coerce 'real' columns.
    if record['rows']:
        record['types'] = get_types(rows=record['rows'])
        record['rows'] = get_refined_rows(rows=record['rows'], types=record['types'])

    # Append a single record per call (jsonl format: one JSON object per line).
    with open(json_file_name, 'a+') as fout:
        json.dump(record, fout)
        fout.write('\n')

if __name__ == '__main__':
    # CLI entry point: register one CSV file as a table in both the
    # SQLite database and the .tables.jsonl file of the given split.
    parser = argparse.ArgumentParser()
    parser.add_argument('split')
    parser.add_argument('file', metavar='file.csv')
    working_folder = './data_and_model'
    args = parser.parse_args()

    # The table id is the CSV file's base name without its extension.
    base_name = os.path.basename(args.file)
    table_id, _ = os.path.splitext(base_name)

    db_file = '{}.db'.format(args.split)
    jsonl_file = '{}.tables.jsonl'.format(args.split)
    csv_to_sqlite(table_id, args.file, db_file, working_folder)
    csv_to_json(table_id, args.file, jsonl_file, working_folder)
    print("Added table with id '{id}' (name '{name}') to {split}.db and {split}.tables.jsonl".format(
        id=table_id, name=get_table_name(table_id), split=args.split))

1 change: 1 addition & 0 deletions run_infer.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 train.py --do_infer --infer_loop --trained --bert_type_abb uL --max_seq_leng 222
2 changes: 2 additions & 0 deletions run_make_table.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Register two demo CSV tables under the split id 'ctable'
# (add_csv.py arguments: <split> <file.csv>).
python3 add_csv.py ctable ftable1.csv
python3 add_csv.py ctable ftable2.csv
1 change: 1 addition & 0 deletions run_train.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
python3 train.py --do_train --seed 1 --bS 16 --accumulate_gradients 2 --bert_type_abb uS --fine_tune --lr 0.001 --lr_bert 0.00001 --max_seq_leng 222
24 changes: 23 additions & 1 deletion sqlova/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# Apache License v2.0

# Wonseok Hwang
import os
import os, json
import random as python_random
from matplotlib.pylab import *


Expand Down Expand Up @@ -65,3 +66,24 @@ def json_default_type_checker(o):
"""
if isinstance(o, int64): return int(o)
raise TypeError


def load_jsonl(path_file, toy_data=False, toy_size=4, shuffle=False, seed=1):
    """Load a list of records from a JSON-lines file.

    When *toy_data* is True only *toy_size* records are returned: the
    first ones if *shuffle* is False (reading stops early), otherwise a
    seeded random sample taken after loading the whole file.
    """
    records = []
    with open(path_file, "r", encoding="utf-8") as f:
        for line_no, raw_line in enumerate(f):
            # Early exit is only valid when no shuffling is requested;
            # a shuffled toy sample must be drawn from the full file.
            if toy_data and line_no >= toy_size and (not shuffle):
                break
            records.append(json.loads(raw_line.strip()))

    if shuffle and toy_data:
        # Load everything, shuffle deterministically, then truncate.
        print(
            f"If the toy-data is used, the whole data loaded first and then shuffled before get the first {toy_size} data")
        python_random.Random(seed).shuffle(records)  # fixed
        records = records[:toy_size]

    return records
Loading

0 comments on commit 523c006

Please sign in to comment.