import sys

from absl import app
from absl import flags
from absl import logging

from lit_nlp import dev_server
from lit_nlp import server_flags
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import model as lit_model
from lit_nlp.api import types as lit_types
from lit_nlp.lib import utils

import pandas as pd
import torch
import transformers
from transformers import BertTokenizer, BertForSequenceClassification
# Load examples from a local spreadsheet instead of TFDS. The file name,
# sheet name, and column names ('UTTERANCE', 'label', 'train') are specific
# to this data file.
df = pd.read_excel("data.xlsx", sheet_name="master_data")
print(df.shape)
df = df[df["train"] == 1]  # keep only rows flagged as training data
df = df.head(100)  # small sample for a quick demo
df = df[["UTTERANCE", "label"]]
df["label"] = df["label"].astype(int)
print(df.head(2))
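# Optional sanity check (illustrative; assumes the column names above): the
# dataset wrapper below expects exactly these two columns and binary labels.
assert {"UTTERANCE", "label"}.issubset(df.columns)
assert df["label"].isin([0, 1]).all(), "labels must be 0/1 for this binary demo"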
def load_examples(frame):
  """Materialize (sentence, label) rows from the DataFrame as a list.

  The original recipe loaded GLUE/SST-2 from TFDS, e.g.:
    ds = tfds.load('glue/sst2', split='train', shuffle_files=True)
  Here we read the pandas rows directly, so no sorting by 'idx' is needed
  to recover the original order.
  """
  return frame.values.tolist()
class SST2Data(lit_dataset.Dataset):
  """Binary sentiment data in the style of SST-2.

  See https://www.tensorflow.org/datasets/catalog/glue#gluesst2 for the
  original Stanford Sentiment Treebank task this mirrors.
  """

  LABELS = ['0', '1']

  def __init__(self, data):
    self._examples = []
    for ex in load_examples(data):
      self._examples.append({
          'sentence': ex[0],
          'label': self.LABELS[ex[1]],  # map the int label to its string form
      })
    print(self._examples[:2])  # debug: inspect a couple of converted examples

  def spec(self):
    return {
        'sentence': lit_types.TextSegment(),
        'label': lit_types.CategoryLabel(vocab=self.LABELS),
    }
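# Illustrative smoke test (commented out; uses the `df` loaded above):
# dataset = SST2Data(df)
# print(dataset.spec())
# print(len(dataset.examples))  # lit_dataset.Dataset exposes .examples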
FLAGS = flags.FLAGS

FLAGS.set_default("development_demo", True)

flags.DEFINE_string(
    "model_path",
    "https://storage.googleapis.com/what-if-tool-resources/lit-models/sst2_tiny.tar.gz",
    "Path to trained model, in standard transformers format, e.g. as "
    "saved by model.save_pretrained() and tokenizer.save_pretrained()")
def _from_pretrained(cls, *args, **kw):
  """Load a transformers model in PyTorch, with fallback to TF2/Keras weights."""
  try:
    return cls.from_pretrained(*args, **kw)
  except OSError as e:
    logging.warning("Caught OSError loading model: %s", e)
    logging.warning(
        "Re-trying to convert from TensorFlow checkpoint (from_tf=True)")
    return cls.from_pretrained(*args, from_tf=True, **kw)
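# Illustrative use of the fallback (the checkpoint path is hypothetical):
# model = _from_pretrained(
#     BertForSequenceClassification, "/path/to/sst2_checkpoint", num_labels=2)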
class SimpleSentimentModel(lit_model.Model):
  """Simple sentiment analysis model."""

  LABELS = ["0", "1"]  # negative, positive

  def __init__(self, model_name_or_path):
    # Load the tokenizer and model from model_name_or_path (a directory or
    # checkpoint name), so the --model_path flag takes effect.
    self.tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
    # This is just a regular PyTorch model.
    self.model = _from_pretrained(
        BertForSequenceClassification,
        model_name_or_path,
        num_labels=2,
        output_hidden_states=True,
        output_attentions=True)
    self.model.eval()

  ##
  # LIT API implementation
  def max_minibatch_size(self):
    # This tells lit_model.Model.predict() how to batch inputs to
    # predict_minibatch(). Alternatively, you can override predict() and
    # handle batching yourself.
    return 32
  def predict_minibatch(self, inputs):
    # Preprocess to ids and masks, and make the input batch.
    encoded_input = self.tokenizer.batch_encode_plus(
        [ex["sentence"] for ex in inputs],
        return_tensors="pt",
        add_special_tokens=True,
        max_length=256,
        padding="longest",
        truncation="longest_first")

    # Check for a GPU and move the model and inputs to it if available.
    if torch.cuda.is_available():
      self.model.cuda()
      for tensor in encoded_input:
        encoded_input[tensor] = encoded_input[tensor].cuda()

    # Run a forward pass.
    with torch.no_grad():  # remove this if you need gradients
      out: transformers.modeling_outputs.SequenceClassifierOutput = self.model(
          **encoded_input)

    # Post-process outputs.
    batched_outputs = {
        "probas": torch.nn.functional.softmax(out.logits, dim=-1),
        "input_ids": encoded_input["input_ids"],
        "ntok": torch.sum(encoded_input["attention_mask"], dim=1),
        "cls_emb": out.hidden_states[-1][:, 0],  # last layer, [CLS] token
    }
    # Return as NumPy for further processing.
    detached_outputs = {k: v.cpu().numpy() for k, v in batched_outputs.items()}

    # Unbatch outputs so we get one record per input example.
    for output in utils.unbatch_preds(detached_outputs):
      ntok = output.pop("ntok")
      # Strip the [CLS]/[SEP] special tokens and padding.
      output["tokens"] = self.tokenizer.convert_ids_to_tokens(
          output.pop("input_ids")[1:ntok - 1])
      yield output
  def input_spec(self) -> lit_types.Spec:
    return {
        "sentence": lit_types.TextSegment(),
        "label": lit_types.CategoryLabel(vocab=self.LABELS, required=False),
    }

  def output_spec(self) -> lit_types.Spec:
    return {
        "tokens": lit_types.Tokens(),
        "probas": lit_types.MulticlassPreds(parent="label", vocab=self.LABELS,
                                            null_idx=0),
        "cls_emb": lit_types.Embeddings(),
    }
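# Illustrative check (commented out; the sentences are made-up): run a tiny
# batch through lit_model.Model.predict(), which batches inputs according to
# max_minibatch_size() and calls predict_minibatch().
# model = SimpleSentimentModel(model_path)
# preds = list(model.predict([{"sentence": "a great movie"},
#                             {"sentence": "a terrible plot"}]))
# print(preds[0]["probas"])  # softmax over the vocab ['0', '1']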
def get_wsgi_app():
  """Returns a LitApp instance for consumption by gunicorn."""
  FLAGS.set_default("server_type", "external")
  FLAGS.set_default("demo_mode", True)
  # Parse flags without calling app.run(main), to avoid conflict with
  # gunicorn command-line flags.
  unused = flags.FLAGS(sys.argv, known_only=True)
  return main(unused)
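# Illustrative gunicorn launch (the module name is a placeholder for this file):
#   gunicorn --workers=1 -b :5432 "my_lit_demo:get_wsgi_app()"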
def main(_):
  # Normally model_path is a directory; if it's an archive file, download
  # and extract it to the transformers cache.
  model_path = FLAGS.model_path
  if model_path.endswith(".tar.gz"):
    model_path = transformers.file_utils.cached_path(
        model_path, extract_compressed_file=True)

  # Load the model we defined above.
  models = {"sst": SimpleSentimentModel(model_path)}
  # Wrap the Excel-derived DataFrame in the dataset class defined above.
  datasets = {"sst_dev": SST2Data(df)}

  # Start the LIT server. See server_flags.py for server options.
  lit_demo = dev_server.Server(models, datasets, **server_flags.get_flags())
  return lit_demo.serve()


if __name__ == "__main__":
  app.run(main)
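# Illustrative direct launch (uses the model_path flag default defined above):
#   python my_lit_demo.py --port=5432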