This repository has been archived by the owner on Nov 16, 2019. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 357
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #176 from yahoo/lstm_inference
LSTM support in CaffeOnSpark
- Loading branch information
Showing
32 changed files
with
2,854 additions
and
19 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Package initializer: surface everything the tools module exports.
from tools import *

__all__ = ["tools"]
92 changes: 92 additions & 0 deletions
92
caffe-grid/src/main/python/com/yahoo/ml/caffe/tools/DFConversions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
''' | ||
Copyright 2016 Yahoo Inc. | ||
Licensed under the terms of the Apache 2.0 license. | ||
Please see LICENSE file in the project root for terms. | ||
''' | ||
from PIL import Image | ||
from io import BytesIO | ||
from IPython.display import HTML | ||
import numpy as np | ||
from base64 import b64encode | ||
from google.protobuf import text_format | ||
import array | ||
from com.yahoo.ml.caffe.ConversionUtil import wrapClass, getScalaSingleton, toPython | ||
from com.yahoo.ml.caffe.RegisterContext import registerContext | ||
from pyspark.sql import DataFrame,SQLContext | ||
|
||
class DFConversions:
    """Python facade over the Scala ``Conversions`` singleton, exposing
    dataset/embedding conversions as PySpark DataFrames.

    :ivar SparkContext: The spark context of the current spark session
    """

    def __init__(self, sc):
        # Register the Spark context so Scala-side objects are reachable,
        # then wrap the Conversions companion object for Python use.
        registerContext(sc)
        wrapClass("com.yahoo.ml.caffe.tools.Conversions$")
        # Plain attribute assignment is equivalent to the original
        # self.__dict__[...] writes: no custom __setattr__ is defined.
        self.conversions = toPython(getScalaSingleton("com.yahoo.ml.caffe.tools.Conversions"))
        self.sqlContext = SQLContext(sc)

    def Coco2ImageCaptionFile(self, src, clusterSize):
        """Convert a COCO dataset to an image/caption DataFrame.

        :param src: the source for coco dataset i.e the caption file
        :param clusterSize: No. of executors
        """
        jdf = self.conversions.Coco2ImageCaptionFile(self.sqlContext, src, clusterSize)
        return DataFrame(jdf, self.sqlContext)

    def Image2Embedding(self, imageRootFolder, imageCaptionDF):
        """Get the embedding for the image as a dataframe.

        :param imageRootFolder: the src folder of the images
        :param imageCaptionDF: the dataframe with the image file and image attributes
        """
        jdf = self.conversions.Image2Embedding(imageRootFolder, imageCaptionDF._jdf)
        return DataFrame(jdf, self.sqlContext)

    def ImageCaption2Embedding(self, imageRootFolder, imageCaptionDF, vocab, captionLength):
        """Get the embedding for the images as well as the caption as a dataframe.

        :param imageRootFolder: the src folder of the images
        :param imageCaptionDF: the dataframe with the images as well as captions
        :param vocab: the vocab object
        :param captionLength: length of the embedding to generate for the caption
        """
        jdf = self.conversions.ImageCaption2Embedding(
            imageRootFolder, imageCaptionDF._jdf, vocab.vocabObject, captionLength)
        return DataFrame(jdf, self.sqlContext)

    def Embedding2Caption(self, embeddingDF, vocab, embeddingColumn, captionColumn):
        """Get the captions back from caption embeddings.

        :param embeddingDF: the dataframe which contains the embedding
        :param vocab: the vocab object
        :param embeddingColumn: the embedding column name in embeddingDF which contains the caption embedding
        :param captionColumn: the output column for the decoded caption
        """
        jdf = self.conversions.Embedding2Caption(
            embeddingDF._jdf, vocab.vocabObject, embeddingColumn, captionColumn)
        return DataFrame(jdf, self.sqlContext)
def get_image(image):
    """Render raw image bytes as an inline HTML ``<img>`` tag.

    :param image: a bytes-like object or an iterable of signed byte values
        (e.g. a JVM byte[] surfaced through Spark as negative-capable ints)
    :return: an HTML string embedding the image as a base64 data URI
    """
    # 'b' (signed char) accepts the negative values a JVM byte[] produces,
    # and array supports the buffer protocol, so b64encode reads it directly.
    # Renamed from `bytes` to avoid shadowing the builtin.
    raw = array.array('b', image)
    # .decode('ascii') keeps this working on Python 3, where b64encode
    # returns bytes and concatenating with str would raise TypeError.
    return "<img src='data:image/png;base64," + b64encode(raw).decode('ascii') + "' />"
def show_captions(df, nrows=10):
    """Displays a table of captions (both original as well as predictions)
    with their images, inline in html.

    :param DataFrame df: a python dataframe with ``id``, ``data.image`` and
        ``prediction`` columns
    :param int nrows: first n rows to display from the dataframe
    :return: an IPython ``HTML`` object rendering the table
    """
    data = df.take(nrows)
    # Close the header row: the original string left </tr> unbalanced.
    html = "<table><tr><th>Image Id</th><th>Image</th><th>Prediction</th></tr>"
    # Iterate the rows actually returned: take() may yield fewer than nrows,
    # and indexing data[i] for i in range(nrows) would raise IndexError.
    for row in data:
        html += "<tr>"
        html += "<td>%s</td>" % row.id
        html += "<td>%s</td>" % get_image(row.data.image)
        html += "<td>%s</td>" % row.prediction
        html += "</tr>"
    html += "</table>"
    return HTML(html)
43 changes: 43 additions & 0 deletions
43
caffe-grid/src/main/python/com/yahoo/ml/caffe/tools/Vocab.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
''' | ||
Copyright 2016 Yahoo Inc. | ||
Licensed under the terms of the Apache 2.0 license. | ||
Please see LICENSE file in the project root for terms. | ||
''' | ||
|
||
from com.yahoo.ml.caffe.ConversionUtil import wrapClass | ||
from com.yahoo.ml.caffe.RegisterContext import registerContext | ||
from pyspark.sql import DataFrame,SQLContext | ||
|
||
class Vocab:
    """Python wrapper around the Scala-side vocabulary builder.

    :ivar SparkContext: The spark context of the current spark session
    """

    def __init__(self, sc):
        # Register the Spark context, then wrap and instantiate the
        # Scala Vocab class against the SQL context.
        registerContext(sc)
        self.vocab = wrapClass("com.yahoo.ml.caffe.tools.Vocab")
        self.sqlContext = SQLContext(sc)
        self.vocabObject = self.vocab(self.sqlContext)

    def genFromData(self, dataset, columnName, vocabSize):
        """Generate the vocabulary from a dataset.

        :param dataset: dataframe containing the captions
        :param columnName: column in the dataset which has the caption
        :param vocabSize: size of the vocabulary to generate (with vocab in descending order)
        """
        self.vocabObject.genFromData(dataset._jdf, columnName, vocabSize)

    def save(self, vocabFilePath):
        """Save the generated vocabulary.

        :param vocabFilePath: the name of the file to save the vocabulary to
        """
        self.vocabObject.save(vocabFilePath)

    def load(self, vocabFilePath):
        """Load the vocabulary from a file.

        :param vocabFilePath: the name of the file to load the vocabulary from
        """
        self.vocabObject.load(vocabFilePath)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright 2016 Yahoo Inc. | ||
# Licensed under the terms of the Apache 2.0 license. | ||
# Please see LICENSE file in the project root for terms. | ||
import caffe | ||
from examples.coco.retrieval_experiment import * | ||
from pyspark.sql import SQLContext | ||
from pyspark import SparkConf,SparkContext | ||
from pyspark.sql.types import * | ||
from itertools import izip_longest | ||
import json | ||
import argparse | ||
|
||
def predict_caption(list_of_images, model, imagenet, lstmnet, vocab):
    """Caption every image in a partition with a single CaptionExperiment.

    :param list_of_images: iterable of image records for this partition
    :param model: path to the trained caffe model
    :param imagenet: path to the image-encoder net definition
    :param lstmnet: path to the LSTM net definition
    :param vocab: path to the vocabulary file
    :return: an iterator of caption results, one per input image
    """
    # One experiment per partition amortizes model loading across images.
    experiment = CaptionExperiment(str(model), str(imagenet), str(lstmnet), str(vocab))
    return iter([experiment.getCaption(img) for img in list_of_images])
def get_predictions(sqlContext, images, model, imagenet, lstmnet, vocab):
    """Run caption prediction over an image dataframe.

    :param sqlContext: active SQLContext used to build the result dataframe
    :param images: dataframe of image records to caption
    :param model: path to the trained caffe model
    :param imagenet: path to the image-encoder net definition
    :param lstmnet: path to the LSTM net definition
    :param vocab: path to the vocabulary file
    :return: dataframe with ``id`` and ``prediction`` columns
    """
    captioned = images.mapPartitions(
        lambda part: predict_caption(part, model, imagenet, lstmnet, vocab))
    inner = StructType([
        StructField("id", StringType(), True),
        StructField("prediction", StringType(), True),
    ])
    schema = StructType([StructField("result", inner, True)])
    # Flatten the nested result struct into top-level columns.
    return sqlContext.createDataFrame(captioned, schema).select("result.id", "result.prediction")
def main():
    """Script entry point: read images from parquet, predict captions,
    and write the results out as JSON.

    Arguments are passed through the ``spark.pythonargs`` Spark conf
    entry rather than sys.argv.
    """
    conf = SparkConf()
    sc = SparkContext(conf=conf)
    sqlContext = SQLContext(sc)
    cmdargs = conf.get('spark.pythonargs')

    parser = argparse.ArgumentParser(description="Image to Caption Util")
    parser.add_argument('-input', action="store", dest="input")
    parser.add_argument('-model', action="store", dest="model")
    parser.add_argument('-imagenet', action="store", dest="imagenet")
    parser.add_argument('-lstmnet', action="store", dest="lstmnet")
    parser.add_argument('-vocab', action="store", dest="vocab")
    parser.add_argument('-output', action="store", dest="output")
    args = parser.parse_args(cmdargs.split(" "))

    df_input = sqlContext.read.parquet(str(args.input))
    images = df_input.select("data.image", "data.height", "data.width", "id")
    predictions = get_predictions(
        sqlContext, images, str(args.model), str(args.imagenet), str(args.lstmnet), str(args.vocab))
    predictions.write.json(str(args.output))


if __name__ == "__main__":
    main()
Oops, something went wrong.