Add BPEmbSerializable back in for serialized models #1232

Merged: 1 commit, Oct 20, 2019
flair/embeddings.py: 174 changes (131 additions, 43 deletions)
@@ -454,6 +454,7 @@ def __str__(self):
def extra_repr(self):
return f"'{self.embeddings}'"


class OneHotEmbeddings(TokenEmbeddings):
"""One-hot encoded embeddings."""

@@ -516,11 +517,13 @@ def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:

if self.field == "text":
context_idxs = [
self.vocab_dictionary.get_idx_for_item(t.text) for t in sentence.tokens
self.vocab_dictionary.get_idx_for_item(t.text)
for t in sentence.tokens
]
else:
context_idxs = [
self.vocab_dictionary.get_idx_for_item(t.get_tag(self.field).value) for t in sentence.tokens
self.vocab_dictionary.get_idx_for_item(t.get_tag(self.field).value)
for t in sentence.tokens
]

one_hot_sentences.extend(context_idxs)
@@ -3171,77 +3174,130 @@ def __str__(self):


class ConvTransformNetworkImageEmbeddings(ImageEmbeddings):

def __init__(self, feats_in, convnet_parms, posnet_parms, transformer_parms):
super(ConvTransformNetworkImageEmbeddings, self).__init__()

adaptive_pool_func_map = {'max': AdaptiveMaxPool2d,
'avg': AdaptiveAvgPool2d}

convnet_arch = [] if convnet_parms['dropout'][0] <=0 else [Dropout2d(convnet_parms['dropout'][0])]
convnet_arch.extend([Conv2d(in_channels=feats_in, out_channels=convnet_parms['n_feats_out'][0],
kernel_size=convnet_parms['kernel_sizes'][0],
padding=convnet_parms['kernel_sizes'][0][0] // 2,
stride=convnet_parms['strides'][0],
groups=convnet_parms['groups'][0]),
ReLU()])
if '0' in convnet_parms['pool_layers_map']:
convnet_arch.append(MaxPool2d(kernel_size=convnet_parms['pool_layers_map']['0']))
adaptive_pool_func_map = {"max": AdaptiveMaxPool2d, "avg": AdaptiveAvgPool2d}

convnet_arch = (
[]
if convnet_parms["dropout"][0] <= 0
else [Dropout2d(convnet_parms["dropout"][0])]
)
convnet_arch.extend(
[
Conv2d(
in_channels=feats_in,
out_channels=convnet_parms["n_feats_out"][0],
kernel_size=convnet_parms["kernel_sizes"][0],
padding=convnet_parms["kernel_sizes"][0][0] // 2,
stride=convnet_parms["strides"][0],
groups=convnet_parms["groups"][0],
),
ReLU(),
]
)
if "0" in convnet_parms["pool_layers_map"]:
convnet_arch.append(
MaxPool2d(kernel_size=convnet_parms["pool_layers_map"]["0"])
)
for layer_id, (kernel_size, n_in, n_out, groups, stride, dropout) in enumerate(
zip(convnet_parms['kernel_sizes'][1:], convnet_parms['n_feats_out'][:-1],
convnet_parms['n_feats_out'][1:], convnet_parms['groups'][1:],
convnet_parms['strides'][1:], convnet_parms['dropout'][1:])):
zip(
convnet_parms["kernel_sizes"][1:],
convnet_parms["n_feats_out"][:-1],
convnet_parms["n_feats_out"][1:],
convnet_parms["groups"][1:],
convnet_parms["strides"][1:],
convnet_parms["dropout"][1:],
)
):
if dropout > 0:
convnet_arch.append(Dropout2d(dropout))
convnet_arch.append(
Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=kernel_size, padding=kernel_size[0] // 2,
stride=stride, groups=groups))
Conv2d(
in_channels=n_in,
out_channels=n_out,
kernel_size=kernel_size,
padding=kernel_size[0] // 2,
stride=stride,
groups=groups,
)
)
convnet_arch.append(ReLU())
if str(layer_id+1) in convnet_parms['pool_layers_map']:
convnet_arch.append(MaxPool2d(kernel_size=convnet_parms['pool_layers_map'][str(layer_id+1)]))
convnet_arch.append(adaptive_pool_func_map[convnet_parms['adaptive_pool_func']](output_size=convnet_parms['output_size']))
if str(layer_id + 1) in convnet_parms["pool_layers_map"]:
convnet_arch.append(
MaxPool2d(
kernel_size=convnet_parms["pool_layers_map"][str(layer_id + 1)]
)
)
convnet_arch.append(
adaptive_pool_func_map[convnet_parms["adaptive_pool_func"]](
output_size=convnet_parms["output_size"]
)
)
self.conv_features = Sequential(*convnet_arch)
conv_feat_dim = convnet_parms['n_feats_out'][-1]
conv_feat_dim = convnet_parms["n_feats_out"][-1]
if posnet_parms is not None and transformer_parms is not None:
self.use_transformer = True
if posnet_parms['nonlinear']:
posnet_arch = [Linear(2, posnet_parms['n_hidden']), ReLU(), Linear(posnet_parms['n_hidden'], conv_feat_dim)]
if posnet_parms["nonlinear"]:
posnet_arch = [
Linear(2, posnet_parms["n_hidden"]),
ReLU(),
Linear(posnet_parms["n_hidden"], conv_feat_dim),
]
else:
posnet_arch = [Linear(2, conv_feat_dim)]
self.position_features = Sequential(*posnet_arch)
transformer_layer = TransformerEncoderLayer(d_model=conv_feat_dim, **transformer_parms['transformer_encoder_parms'])
self.transformer = TransformerEncoder(transformer_layer, num_layers=transformer_parms['n_blocks'])
transformer_layer = TransformerEncoderLayer(
d_model=conv_feat_dim, **transformer_parms["transformer_encoder_parms"]
)
self.transformer = TransformerEncoder(
transformer_layer, num_layers=transformer_parms["n_blocks"]
)
# <cls> token initially set to 1/D, so it attends to all image features equally
self.cls_token = Parameter(torch.ones(conv_feat_dim, 1) / conv_feat_dim)
self._feat_dim = conv_feat_dim
else:
self.use_transformer = False
self._feat_dim = convnet_parms['output_size'][0] * convnet_parms['output_size'][1] * conv_feat_dim

self._feat_dim = (
convnet_parms["output_size"][0]
* convnet_parms["output_size"][1]
* conv_feat_dim
)

def forward(self, x):
x = self.conv_features(x) # [b, d, h, w]
x = self.conv_features(x) # [b, d, h, w]
b, d, h, w = x.shape
if self.use_transformer:
# add positional encodings
y = torch.stack([torch.cat([torch.arange(h).unsqueeze(1)] * w, dim=1),
torch.cat([torch.arange(w).unsqueeze(0)] * h, dim=0)]) # [2, h, w]
y = y.view([2, h*w]).transpose(1, 0) # [h*w, 2]
y = torch.stack(
[
torch.cat([torch.arange(h).unsqueeze(1)] * w, dim=1),
torch.cat([torch.arange(w).unsqueeze(0)] * h, dim=0),
]
) # [2, h, w]
y = y.view([2, h * w]).transpose(1, 0) # [h*w, 2]
y = y.type(torch.float32).to(flair.device)
y = self.position_features(y).transpose(1, 0).view([d, h, w]) # [h*w, d] => [d, h, w]
y = y.unsqueeze(dim=0) # [1, d, h, w]
x = x + y # [b, d, h, w] + [1, d, h, w] => [b, d, h, w]
y = (
self.position_features(y).transpose(1, 0).view([d, h, w])
) # [h*w, d] => [d, h, w]
y = y.unsqueeze(dim=0) # [1, d, h, w]
x = x + y # [b, d, h, w] + [1, d, h, w] => [b, d, h, w]
# reshape the pixels into the sequence
x = x.view([b, d, h*w]) # [b, d, h*w]
x = x.view([b, d, h * w]) # [b, d, h*w]
# layer norm after convolution and positional encodings
x = F.layer_norm(x.permute([0,2,1]), (d,)).permute([0,2,1])
x = F.layer_norm(x.permute([0, 2, 1]), (d,)).permute([0, 2, 1])
# add <cls> token
x = torch.cat([x, torch.stack([self.cls_token] * b)], dim=2) # [b, d, h*w+1]
x = torch.cat(
[x, torch.stack([self.cls_token] * b)], dim=2
) # [b, d, h*w+1]
# transformer requires input in the shape [h*w+1, b, d]
x = x.view([b*d, h*w+1]).transpose(1, 0).view([h*w+1, b, d]) # [b, d, h*w+1] => [b*d, h*w+1] => [h*w+1, b*d] => [h*w+1, b, d]
x = self.transformer(x) # [h*w+1, b, d]
x = (
x.view([b * d, h * w + 1]).transpose(1, 0).view([h * w + 1, b, d])
) # [b, d, h*w+1] => [b*d, h*w+1] => [h*w+1, b*d] => [h*w+1, b, d]
x = self.transformer(x) # [h*w+1, b, d]
# the output is an embedding of <cls> token
x = x[-1, :, :] # [b, d]
x = x[-1, :, :] # [b, d]
else:
x = x.view([-1, self._feat_dim])
x = F.layer_norm(x, (self._feat_dim,))
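
A minimal sketch (not part of this diff) of the positional grid that forward() builds before adding it to the conv features; h and w are small illustrative values, everything else follows the lines above:

import torch

h, w = 2, 3
rows = torch.cat([torch.arange(h).unsqueeze(1)] * w, dim=1)  # [h, w], row index of each pixel
cols = torch.cat([torch.arange(w).unsqueeze(0)] * h, dim=0)  # [h, w], column index of each pixel
grid = torch.stack([rows, cols])                             # [2, h, w]
coords = grid.view([2, h * w]).transpose(1, 0)               # [h*w, 2] (row, col) pairs
print(coords.tolist())
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]

These (row, col) pairs are what self.position_features maps to d-dimensional positional encodings before they are reshaped back to [d, h, w] and added to x.
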
@@ -3287,3 +3343,35 @@ def replace_with_language_code(string: str):
string = string.replace("spanish-", "es-")
string = string.replace("swedish-", "sv-")
return string


# TODO: keep for backwards compatibility, but remove in future
class BPEmbSerializable(BPEmb):
def __getstate__(self):
state = self.__dict__.copy()
# save the SentencePiece model as a binary blob (not as a path, which may change)
state["spm_model_binary"] = open(self.model_file, mode="rb").read()
state["spm"] = None
return state

def __setstate__(self, state):
from bpemb.util import sentencepiece_load

model_file = self.model_tpl.format(lang=state["lang"], vs=state["vs"])
self.__dict__ = state

# write out the binary SentencePiece model into the expected directory
self.cache_dir: Path = Path(flair.cache_root) / "embeddings"
if "spm_model_binary" in self.__dict__:
# if the model was saved as binary and it is not found on disk, write to appropriate path
if not os.path.exists(self.cache_dir / state["lang"]):
os.makedirs(self.cache_dir / state["lang"])
self.model_file = self.cache_dir / model_file
with open(self.model_file, "wb") as out:
out.write(self.__dict__["spm_model_binary"])
else:
# otherwise, use the normal process and potentially trigger another download
self.model_file = self._load_file(model_file)

# once the model is there, load it with SentencePiece
state["spm"] = sentencepiece_load(self.model_file)