-
Notifications
You must be signed in to change notification settings - Fork 27
/
util.py
79 lines (60 loc) · 2.21 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from __future__ import print_function
import os
from os import path
import numpy as np
from keras import backend as K
from keras.utils import get_file
from tqdm import tqdm
def get_snli_file_path():
    """
    Locate the SNLI 1.0 corpus directory, fetching the archive if needed.

    The corpus is expected in ``data/snli_1.0/`` next to this module. When
    that directory is absent, the SNLI zip is requested through keras'
    ``get_file`` with extraction enabled and the module's ``data`` directory
    as the cache directory.

    :returns path of the SNLI directory (string ending with '/')
    """
    module_dir = path.abspath(path.dirname(__file__))
    cache_dir = path.join(module_dir, 'data')
    snli_dir = cache_dir + '/snli_1.0/'

    # Only hit the network when the extracted corpus is not there yet.
    if not os.path.exists(snli_dir):
        get_file('/tmp/snli_1.0.zip',
                 origin='https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
                 cache_dir=cache_dir,
                 cache_subdir='',
                 extract=True)

    return snli_dir
def get_word2vec_file_path(file_path):
    """
    Resolve the word-vector file to use, falling back to GloVe 840B.300d.

    :param file_path optional path to an existing word-vector file; may be None
    :returns ``file_path`` unchanged when it points at something that exists,
             otherwise the path of ``glove.840B.300d.txt`` inside the ``data``
             directory next to this module (the GloVe archive is requested
             through keras' ``get_file`` first if that file is missing)
    """
    caller_file_usable = file_path is not None and path.exists(file_path)
    if caller_file_usable:
        return file_path

    module_dir = path.abspath(path.dirname(__file__))
    cache_dir = path.join(module_dir, 'data')
    glove_file_path = cache_dir + '/glove.840B.300d.txt'

    if not path.exists(glove_file_path):
        archive = '/tmp/glove.zip'
        get_file(archive,
                 origin='http://nlp.stanford.edu/data/glove.840B.300d.zip',
                 cache_dir=cache_dir,
                 cache_subdir='',
                 extract=True)
        # The archive is no longer needed once get_file has returned.
        os.remove(archive)

    return glove_file_path
class ChunkDataManager(object):
    """
    Persist a list of numpy arrays as numbered ``.npy`` files and read it back.

    ``save`` writes item ``i`` of the list to ``<save_data_path>/<i>.npy``;
    ``load`` reads every ``.npy`` file in ``load_data_path`` into a list,
    ordered so that chunks written by ``save`` come back in their original
    order.
    """

    def __init__(self, load_data_path=None, save_data_path=None):
        """
        :param load_data_path directory that ``load`` reads from
        :param save_data_path directory that ``save`` writes to
        """
        self.load_data_path = load_data_path
        self.save_data_path = save_data_path

    @staticmethod
    def _chunk_sort_key(file_name):
        """Sort key placing numerically named chunks first, in numeric order."""
        stem = file_name[:-len('.npy')]
        try:
            return (0, int(stem), file_name)
        except ValueError:
            # Not a chunk index: keep such files after the numbered ones,
            # in plain lexicographic order.
            return (1, 0, file_name)

    def load(self):
        """
        :returns list with one array per ``.npy`` file in ``load_data_path``
        """
        # A plain string sort would put '10.npy' before '2.npy' and thereby
        # shuffle the chunks as soon as more than ten of them were saved.
        chunk_names = sorted(
            (name for name in os.listdir(self.load_data_path)
             if name.endswith('.npy')),
            key=self._chunk_sort_key)
        return [np.load(os.path.join(self.load_data_path, name))
                for name in tqdm(chunk_names)]

    def save(self, data):
        """
        :param data sequence of arrays; item ``i`` is written to ``<i>.npy``
        """
        # makedirs (unlike mkdir) also creates missing parent directories.
        if not os.path.isdir(self.save_data_path):
            os.makedirs(self.save_data_path)
        print('Saving data of shapes:', [item.shape for item in data])
        # Wrapping ``data`` itself (not the enumerate object) lets tqdm see
        # the length and show real progress.
        for i, item in enumerate(tqdm(data)):
            np.save(os.path.join(self.save_data_path, str(i) + '.npy'), item)
def broadcast_last_axis(x):
    """
    Repeat x along a new trailing axis whose length equals x's second axis.

    :param x tensor of shape (batch, a, b)
    :returns broadcasted tensor of shape (batch, a, b, a)
    """
    # Shape-only helper of shape (batch, 1, b, a): x is multiplied by 0 so
    # that (for finite values) it adds nothing to the result.
    # NOTE(review): non-finite entries of x would not become 0 here - confirm
    # inputs are always finite.
    shape_carrier = K.permute_dimensions(K.expand_dims(x, 1) * 0, (0, 1, 3, 2))

    # (batch, a, b, 1) + (batch, 1, b, a) broadcasts to (batch, a, b, a).
    return K.expand_dims(x) + shape_carrier