Contributor

@tomasatdatabricks tomasatdatabricks commented Dec 2, 2017

Use Spark 2.3's ImageSchema as image interface.

  • The biggest change is the opposite ordering of color channels (BGR instead of RGB), which requires extra reordering in a couple of places.
    - Preserved the ability to read and resize images in Python using PIL to match Keras
    (resize gives a different result, and reading JPEGs also produced images that were off by 1 on some green pixels).
  • Needed a few tweaks to run with Spark 2.3; notably, UDFs are now referenced by SQL identifier and cannot have a dash as part of the name.
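
To make the RGB/BGR reordering concrete, here is a minimal pure-Python sketch that reverses the channel order of an interleaved pixel buffer. The name `reverse_channels` is illustrative only; the PR's own `imageIO._reverseChannels` is applied to NumPy arrays, and this stand-in is an assumption about the idea, not that code.

```python
def reverse_channels(data, n_channels=3):
    """Reverse the channel order of an interleaved pixel buffer (RGB <-> BGR)."""
    if len(data) % n_channels != 0:
        raise ValueError(
            "buffer length {} is not a multiple of {} channels".format(
                len(data), n_channels))
    out = bytearray(len(data))
    for start in range(0, len(data), n_channels):
        pixel = data[start:start + n_channels]
        # Reversing one pixel turns (R, G, B) into (B, G, R) and vice versa.
        out[start:start + n_channels] = pixel[::-1]
    return bytes(out)


rgb = bytes([255, 0, 0, 0, 255, 0])  # one red pixel, one green pixel
bgr = reverse_channels(rgb)
```

Applying the function twice returns the original buffer, which is why the same helper can serve both directions.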

[TODO] - In order to run on spark < 2.3, the image schema files have been copied here and need to be removed in the future.

Collaborator

@smurching smurching left a comment

This looks pretty good, just had a few small comments + some thoughts regarding error messages.

'TFImageTransformer', 'TFInputGraph', 'TFTransformer',
'DeepImagePredictor', 'DeepImageFeaturizer', 'KerasImageFileTransformer', 'KerasTransformer',
'imageInputPlaceholder']

Collaborator

Nit: remove the extra newlines here

@@ -1,4 +1,5 @@
// You may use this file to add plugin dependencies for sbt.
resolvers += "Local Spark repo" at "file:///Users/tomas/.m2/repository"
Collaborator

TODO: Eventually remove this

image_float = tf.decode_raw(image_buffer, tf.float32, name="decode_raw")

else:
raise ValueError('unsupported image data type "%s"' % img_dtype)
Collaborator

+1 for this. How hard would it be to print a list of supported image data types here (and potentially in other places where we validate data types)? Can we just use supportedOcvTypes from imageIO.py?

Contributor Author

Yeah, good idea. I'll update the error message. I can't use the OcvTypes though: currently this code is independent and is hardcoded to handle only uint8 and float32.
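
A sketch of what the updated error message could look like, assuming the two hardcoded cases mentioned above are the only supported ones. `SUPPORTED_IMAGE_DTYPES` and `check_image_dtype` are hypothetical names, not identifiers from the PR.

```python
# Assumption: only the two dtypes named in the discussion are handled.
SUPPORTED_IMAGE_DTYPES = ("uint8", "float32")


def check_image_dtype(img_dtype):
    """Return img_dtype unchanged, or raise with the list of supported values."""
    if img_dtype not in SUPPORTED_IMAGE_DTYPES:
        raise ValueError(
            'unsupported image data type "{}", expected one of {}'.format(
                img_dtype, list(SUPPORTED_IMAGE_DTYPES)))
    return img_dtype
```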

"""
return udf(_resizeFunction(size), imageSchema)
if len(size) != 2:
raise ValueError("New image size should have for [height, width] but got {}".format(size))
Collaborator

"for" -> "format"?

Contributor Author

yup yup, thanks for noticing!
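
For reference, the size check with the corrected wording might read as follows. This is a standalone sketch: `validate_size` is a hypothetical helper, and the real check lives inside the resize UDF factory shown in the hunk.

```python
def validate_size(size):
    """Check that size is a [height, width] pair and return it as a tuple of ints."""
    if len(size) != 2:
        raise ValueError(
            "New image size should have format [height, width] but got {}".format(size))
    height, width = size
    return int(height), int(width)
```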

:param imageDirectory: str, file path.
:param numPartition: int, number or partitions to use for reading files.
:return: DataFrame, with columns: (filepath: str, image: imageSchema).
def readImagesWithCustomLib(path, decode_f,numPartition = None):
Collaborator

Nit (style): Space after commas

* @return Row image in spark.ml.image format with 3 channels in BGR order.
*/
private[sparkdl] def spImageFromBufferedImage(image: BufferedImage): Row = {
private[sparkdl] def spImageFromBufferedImage(image: BufferedImage, origin:String = null): Row = {
Collaborator

Nit (style): Space after ":" (i.e. origin: String)

outputCol="prediction",)
fullPredict = transformer.transform(self.imageDF).collect()
fileOrder = self.fileOrder
fullPredict = sorted(transformer.transform(self.imageDF).collect(),key=lambda x:fileOrder[x['image']['origin'].split('/')[-1]])
Collaborator

Suggestion: Put this sorting into a helper function since it appears multiple times and/or add a comment explaining what it does?

Contributor Author

Yup, both suggestions make sense.
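
One possible shape for such a helper, sketched with plain dictionaries standing in for Spark `Row` objects. `sort_by_file_order` is a hypothetical name; the key logic mirrors the lambda in the hunk above.

```python
def sort_by_file_order(rows, file_order):
    """Sort collected rows by the expected position of their image file.

    rows: iterable of objects where row['image']['origin'] is the file path
    file_order: dict mapping file name -> expected index
    """
    def position(row):
        # The origin is a full path or URI; only the file name is used as key.
        file_name = row["image"]["origin"].split("/")[-1]
        return file_order[file_name]

    return sorted(rows, key=position)


file_order = {"a.jpg": 0, "b.jpg": 1}
rows = [
    {"image": {"origin": "file:/tmp/b.jpg"}},
    {"image": {"origin": "file:/tmp/a.jpg"}},
]
ordered = sort_by_file_order(rows, file_order)
```

Sorting is needed because `collect()` on a DataFrame gives no ordering guarantee relative to the list of files the test was built from.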

def builder(children: Seq[Expression]) = udf.apply(children.map(cx => new Column(cx)) : _*).expr
val registry = sqlCtx.sessionState.functionRegistry
registry.registerFunction(name, builder)
registry.registerFunction(sqlCtx.sessionState.sqlParser.parseFunctionIdentifier(name), builder)
Collaborator

This line is a mouthful :P but it does seem like the right way to parse UDF names based on apache/spark#17518

Contributor Author

Thanks for referencing the PR!

Technically you could pass the identifier straight through without calling the SQL parser. However, I hit an issue with that before: using an identifier with embedded dashes would fail silently, and later you get a weird error that the UDF is not found, so I think it's better/safer to parse it like this.
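
On the Python side, a similar fail-early check could be sketched as below. The regular expression is a deliberately conservative assumption, not Spark's actual identifier grammar, and `check_udf_name` is a hypothetical helper.

```python
import re

# Assumption: a conservative pattern for an unquoted identifier. Spark's SQL
# parser is the real authority and may accept more than this.
_IDENTIFIER = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def check_udf_name(name):
    """Raise early if a UDF name would not work as a plain SQL identifier."""
    if not _IDENTIFIER.match(name):
        raise ValueError(
            "UDF name {!r} is not a plain SQL identifier; "
            "use letters, digits and underscores only".format(name))
    return name
```

Failing at registration time avoids the later, harder to diagnose "UDF not found" error described above.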

.drop(tfs_output_col)
)
mode = imageIO.imageTypeByName('CV_32FC%d' % orig_image.nChannels)
return Row(origin="",mode=mode.ord, height=height, width=width, nChannels=orig_image.nChannels, data=bytearray(np.array(numeric_data).astype(np.float32).tobytes()))
Collaborator

Nit (style): Try to keep line length < 100 chars if possible

Collaborator

We have a Scala style guide (https://github.com/databricks/scala-style-guide) but I can't find one for Python -- AFAIK for Python we just follow PEP8 conventions.

Contributor Author

kk, makes sense :) I'll try to follow that in the future.


def imageTypeByOrdinal(ord):
if not ord in __ocvTypesByOrdinal:
raise KeyError("unsupported image type with ordinal " + ord)
Collaborator

Suggestion: mention that we're working with OpenCV image types, e.g. "received unsupported OpenCV image type with ordinal " + ord

Collaborator

This might be another place where it's useful to include a list of supported values (ordinals) in the error message.

Contributor Author

good point
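
A sketch combining both suggestions (naming OpenCV and listing the supported ordinals). It also formats the ordinal instead of concatenating it, since `"..." + ord` as quoted in the hunk would raise a `TypeError` for an integer ordinal. The set of supported types below is an assumption for illustration; the ordinals themselves follow OpenCV's type encoding.

```python
from collections import namedtuple

OcvType = namedtuple("OcvType", ["name", "ord", "nChannels", "dtype"])

# Ordinals follow OpenCV's encoding: depth + ((channels - 1) << 3),
# with depth 0 for 8-bit unsigned and depth 5 for 32-bit float.
# Assumption: which types count as "supported" is illustrative only.
_SUPPORTED = [
    OcvType("CV_8UC1", 0, 1, "uint8"),
    OcvType("CV_8UC3", 16, 3, "uint8"),
    OcvType("CV_8UC4", 24, 4, "uint8"),
    OcvType("CV_32FC1", 5, 1, "float32"),
    OcvType("CV_32FC3", 21, 3, "float32"),
    OcvType("CV_32FC4", 29, 4, "float32"),
]
_BY_ORDINAL = {t.ord: t for t in _SUPPORTED}


def image_type_by_ordinal(ordinal):
    """Look up an OpenCV image type, listing valid ordinals on failure."""
    if ordinal not in _BY_ORDINAL:
        raise KeyError(
            "received unsupported OpenCV image type with ordinal {}, "
            "supported ordinals are {}".format(ordinal, sorted(_BY_ORDINAL)))
    return _BY_ORDINAL[ordinal]
```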

img = Image.open(BytesIO(imageData))
except IOError:
return None
def resizeImage_jvm(size):
Collaborator

Discussed offline: Let's go ahead and remove this method, since it's not currently being used.

Collaborator

@sueann sueann left a comment

Some rough comments for you - I haven't looked at everything in detail. I believe @smurching has left you some comments as well so I'll send you these for now. In particular, next time I will take a closer look at the logic changes in the transformers to make sure there's nothing we've missed. Thanks!

Note: Let's make a release (0.3.0) before this gets merged in and mark it as the last one backwards-compatible with Spark 2.2. We should also send an email to the mailing-list (now that we have one) about the breaking changes coming in 1.0.0.


from .graph.input import TFInputGraph
from .image.imageIO import imageSchema, imageType, readImages
from pyspark.ml.image import ImageSchema
Collaborator

move above this group to its own group





Collaborator

?

build.sbt Outdated
// http://www.scala-sbt.org/0.13/docs/index.html

val sparkVer = sys.props.getOrElse("spark.version", "2.1.1")

Collaborator

I'm going to ignore the build file for now while you work on the dependency problem. Let me know if you'd like me to look at it.

README.md Outdated
```

The resulting DataFrame contains a string column named "filePath" containing the path to each image file, and a image struct ("`SpImage`") column named "image" containing the decoded image data.
or alternatively, using PIL in python
Collaborator

Does this way return the images in the same schema (BGR)? If not, I don't think we should recommend this (i.e. remove from the README).

Contributor Author

yup it does, it produces the same image schema with the same channel ordering. This is the one used in tests.

Collaborator

As a user, when should I use the Spark version vs the PIL version? It seems confusing to have both here.

Contributor Author

OK, I can remove it from the README. I put it there because we actually did need to use the Python read to match Keras, so I thought users should know about the option.

elif img_dtype == 'float32':
image_float = tf.decode_raw(image_buffer, tf.float32, name="decode_raw")

else:
Collaborator

in the new schema, are there legitimate types that have float64 (or any other dtypes) as img_dtype?

Contributor Author

AFAIK the schema does not specify types. It only specifies a field holding the OpenCV type number. There are OpenCV types that use float64. Technically the schema includes an openCvTypes map with only a subset of types; however, we already need types outside of this subset (TF-produced images are stored as float32).

Collaborator

So does ImageSchema support OpenCV types that have float64? If so, should we support them here?

Contributor Author

Currently, as far as I know, there is no way to get a float64 image.

ImageSchema as a data format supports it, in that it has a mode field which is supposed to hold an OpenCV type, and there are OpenCV types with float64. However, float64 is not in the list of OpenCV types in their Scala code (and neither are any float32 types, which we need), and as it stands now, readImages can only ever produce images stored in unsigned bytes (both the Scala and PIL versions), so one of the CV_8U* formats. We also need the float32 formats, since that's what we return when returning images from TF, so I added those to our Python side.

The Python code from ImageSchema can only handle unsigned byte images; that's why I use our own version in imageIO (imageArrayToStruct and imageStructToArray).

Collaborator

From offline discussion: The ImageSchema utilities in Spark only support uint8 types. Ideally float32 types would also be supported natively in Spark so we don't have to have special logic in this package to handle it. We'll create a Jira in Spark for that and try to address it there.
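
To illustrate what storing float32 pixels in the binary `data` field involves, here is a pure-Python sketch using the standard `array` module. The helper names are hypothetical, and the package itself does this conversion with NumPy in `imageArrayToStruct` and `imageStructToArray`.

```python
from array import array


def floats_to_bytes(values):
    """Pack a flat sequence of floats as native 32-bit floats."""
    return array("f", values).tobytes()


def bytes_to_floats(data):
    """Unpack native 32-bit floats back into a list of Python floats."""
    out = array("f")
    out.frombytes(data)
    return list(out)


pixels = [0.0, 0.5, 0.25, 1.0]  # values chosen to be exact in float32
data = floats_to_bytes(pixels)
```

Each float32 value takes four bytes, which is why the tests in this PR compare the length of the data field against `array.size * 4`.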

Collaborator

Do we have a Jira for this already? If so, could you link from here?

Contributor Author

yes we do. https://issues.apache.org/jira/browse/SPARK-22730

You mean you want it in the code? That would probably go to imageIO, I'll put it there

class ImageUtilsSuite extends FunSuite {
// We want to make sure to test ImageUtils in headless mode to ensure it'll work on all systems.
assert(System.getProperty("java.awt.headless") === "true")
// assert(System.getProperty("java.awt.headless") === "true")
Collaborator

?

graphic.dispose()

spImageFromBufferedImage(tgtImg)
spImageFromBufferedImage(tgtImg,origin=ImageSchema.getOrigin(spImage))
Collaborator

nit: space between arguments

}

def getResizeImageUDF(h:Int,w:Int): UserDefinedFunction = udf( (x:Row) => {
resizeImage(h,w,3 /** hardcoded for now, currently resize code only accepts 3 channels **/,x);
Collaborator

nit: space between arguments

from sparkdl.utils import jvmapi as JVMAPI
from sparkdl.image.imageIO import imageSchema, imageArrayToStruct
from sparkdl.image.imageIO import imageArrayToStruct
from pyspark.ml.image import ImageSchema
Collaborator

move to the pyspark section above

assert isinstance(img_arr_reloaded, np.ndarray), \
"expect preprocessor to return a numpy array"
img_arr_reloaded = img_arr_reloaded.astype(np.uint8)
img_arr_reloaded = img_arr_reloaded.astype(np.uint8)[...,::-1]
Collaborator

i think we don't use ImageSchema in the context of the udfs, do we? i forget if the user ever gets access to the column with the image loaded throughout the udf using process. i'm wondering if we need to do this BGR - RGB back-and-forth here, especially given that a typical user will probably define a keras_load function using the RGB format.

Contributor Author

The UDF returns an image schema struct, so the channel order has to be reversed.

@tomasatdatabricks tomasatdatabricks force-pushed the tomas/ML_3037 branch 4 times, most recently from 13f0e1e to ec1a2f7 Compare December 5, 2017 17:24
@tomasatdatabricks tomasatdatabricks force-pushed the tomas/ML_3037 branch 16 times, most recently from 83cb7f7 to b6d0a27 Compare December 6, 2017 07:53

codecov-io commented Dec 6, 2017

Codecov Report

Merging #85 into master will decrease coverage by 3.5%.
The diff coverage is 63.96%.

@@            Coverage Diff             @@
##           master      #85      +/-   ##
==========================================
- Coverage   85.99%   82.49%   -3.51%     
==========================================
  Files          30       33       +3     
  Lines        1692     1879     +187     
  Branches       17       35      +18     
==========================================
+ Hits         1455     1550      +95     
- Misses        237      329      +92
Impacted Files Coverage Δ
python/sparkdl/utils/keras_model.py 66.66% <ø> (ø) ⬆️
python/sparkdl/graph/input.py 98.05% <ø> (ø) ⬆️
python/sparkdl/param/shared_params.py 83.48% <ø> (ø) ⬆️
python/sparkdl/utils/jvmapi.py 96.77% <ø> (ø) ⬆️
python/sparkdl/transformers/keras_tensor.py 100% <ø> (ø) ⬆️
python/sparkdl/graph/tensorframes_udf.py 90% <ø> (ø) ⬆️
python/sparkdl/transformers/utils.py 100% <ø> (ø) ⬆️
...a/com/databricks/sparkdl/DeepImageFeaturizer.scala 95.38% <ø> (ø) ⬆️
python/sparkdl/__init__.py 100% <ø> (ø) ⬆️
...n/sparkdl/estimators/keras_image_file_estimator.py 74.35% <ø> (ø) ⬆️
... and 20 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 94452d6...5ef9a6b.

… spark-deep-learning, so we can use it against current production version of spark. Should be removed once Spark2.3 is released.
…ameter to TFImage to specify current channel order. Spark deep learning will automatically fix given graph to match the image schema input.
…rk.ml.image to sparkdl.image to avoid patchin the __path__ variable of the pyspark.ml.image module
@tomasatdatabricks
Contributor Author

@sueann
I made the changes we talked about earlier, so that:

  1. I added channelOrder argument to TFImage and to buildSpImageConverter (code duplication, I'll consolidate these in subsequent PR, most likely will remove GraphFunction)
  2. When taking Keras model in udf we assume it's RGB and so is the output of the preprocessor.
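
The effect of a `channelOrder` argument can be sketched as follows: the image schema always delivers BGR, and the converter flips only when the graph expects RGB. The function name is hypothetical and the sketch works on a single pixel tuple, whereas the real converter operates inside a TensorFlow graph.

```python
def schema_pixel_to_graph_order(bgr_pixel, channel_order):
    """Reorder a BGR pixel from the image schema into the order a graph expects."""
    if channel_order == "BGR":
        return tuple(bgr_pixel)            # already in schema order
    if channel_order == "RGB":
        return tuple(reversed(bgr_pixel))  # flip for RGB-based models, e.g. Keras
    raise ValueError(
        "Unsupported channel order {!r}, expected 'RGB' or 'BGR'".format(channel_order))


blue_in_bgr = (255, 0, 0)  # pure blue as stored in the image schema
```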

Collaborator

@sueann sueann left a comment

moar questions & comments for you.

for i, img in enumerate(images):
assert img is not None and img.mode == "RGB"
imageArray[i] = np.array(img.resize(shape))
imageArray[i] = imageIO._reverseChannels(np.array(img.resize(shape)))
Collaborator

Let's have getSampleImageList() return images in BGR. It's only used here, and there is nothing about that function that suggests it should return RGB (e.g. not keras-related or anything).

Contributor Author

It returns PIL images, not image schema, that's what dictates the RGB order. I think it's fine.

Collaborator

From this test's perspective it doesn't need an image to come in as RGB and the logic here is not relevant to it. Is there any reason why getSampleImageList() shouldn't just return imageArray itself?

resized_images = tf.image.resize_images(image_arr, InceptionV3Constants.INPUT_SHAPE)
processed_images = preprocess_input(resized_images)
# keras expects array in RGB order, we get it from image schema in BGR => need to flip
processed_images = preprocess_input(imageIO._reverseChannels(resized_images))
Collaborator

TODO(sueann); check why we need this

Contributor Author

In our previous discussion we decided that all Keras code is RGB, so the preprocessor returns RGB and we need to flip it to return a valid image schema.

Collaborator

ok i think that's fine

InceptionV3Constants.INPUT_SHAPE)
preprocessed = preprocess_input(resized_images)
# keras expects array in RGB order, we get it from image schema in BGR => need to flip
preprocessed = preprocess_input(imageIO._reverseChannels(resized_images))
Collaborator

TODO(sueann): check why we need this and whether we can put it somewhere else

preppedImage = cls.appModel._testPreprocess(imageArray.astype('float32'))
cls.kerasPredict = cls.appModel._testKerasModel(include_top=True).predict(preppedImage)
cls.preppedImage = preppedImage
cls.kerasPredict = cls.appModel._testKerasModel(include_top=True).predict(preppedImage,batch_size=1)
Collaborator

isn't preppedImage BGR? if so, wouldn't keras_model.predict() on it be wrong? (i'm confused why these tests are passing)

"""

def buildSpImageConverter(img_dtype):
def buildSpImageConverter(channelOrder,img_dtype):
Collaborator

space after comma


xcpt_model = Xception(weights="imagenet")
stages = [('spimage', gfac.buildSpImageConverter(SparkMode.RGB_FLOAT32)),
stages = [('spimage', gfac.buildSpImageConverter('BGR','float32')),
Collaborator

space after ,

def test_resize(self):
imgAsRow = imageIO.imageArrayToStruct(array)
smaller = imageIO._resizeFunction([4, 5])
imgAsPIL = PIL.Image.fromarray(obj=imageIO._reverseChannels(array)).resize((5,4))
Collaborator

Clearer reordering of lines (for readability):

        self.assertRaises(ValueError, imageIO.createResizeImageUDF, [1, 2, 3])

        make_smaller = imageIO.createResizeImageUDF([4, 5]).func
        imgAsRow = imageIO.imageArrayToStruct(array)
        smallerImg = make_smaller(imgAsRow)
        self.assertEqual(smallerImg.height, 4)
        self.assertEqual(smallerImg.width, 5)

        # Compare to PIL resizing
        imgAsPIL = PIL.Image.fromarray(obj=imageIO._reverseChannels(array)).resize((5,4))
        smallerAry = imageIO._reverseChannels(np.asarray(imgAsPIL))
        np.testing.assert_array_equal(smallerAry, imageIO.imageStructToArray(smallerImg))

        # I'm not sure what the following is testing so I don't know where they should go yet
        for n in ImageSchema.imageSchema['image'].dataType.names:
            smallerImg[n]
        sameImage = imageIO.createResizeImageUDF((imgAsRow.height, imgAsRow.width)).func(imgAsRow)
        self.assertEqual(imgAsRow, sameImage)

Contributor Author

Sure, I'll change the order. I believe the last part of the test checks that the image has the correct schema, though this might not be the best way to test it. I left it in since it was there before.

self.assertEqual(imgAsStruct.width, width)
self.assertEqual(imgAsStruct.data, array.tobytes())
imgReconstructed = imageIO.imageStructToArray(imgAsStruct)
np.testing.assert_array_equal(array,imgReconstructed)
Collaborator

space after ,

self.assertEqual(imgAsStruct.width, width)
self.assertEqual(len(imgAsStruct.data), array.size * 4)

# Check channel mismatch
Collaborator

do we want to keep these assertRaises tests?

Contributor Author

The assertRaises checks are not needed anymore. The data type is inferred from the array type instead of being passed as an argument, so there is no invalid conversion anymore.

_test(np.random.random_sample((10,11,nChannels)).astype('float32'))

def test_image_round_trip(self):
# Test round trip: array -> png -> sparkImg -> array
Collaborator

I think we may not need this test anymore but let me think about it (or we can chat offline)

Contributor Author

I think we need it until it gets moved out to Spark's ImageSchema.


__ocvTypesByName = {m.name:m for m in supportedOcvTypes}
__ocvTypesByOrdinal = {m.ord:m for m in supportedOcvTypes}
__ocvTypesByName = {m.name:m for m in _supportedOcvTypes}
Collaborator

double underscore is only for python constructs?
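
One practical reason to prefer a single leading underscore for module-private names: Python mangles `__name` identifiers that appear inside a class body, including references to module-level variables. A small sketch with hypothetical names (`__types_by_name`, `lookup`, `Reader`):

```python
__types_by_name = {"CV_8UC3": 16}


def lookup(name):
    # Outside a class body no mangling happens, so this works.
    return __types_by_name[name]


class Reader:
    def lookup(self, name):
        # Inside a class body the name is rewritten to _Reader__types_by_name,
        # which does not exist, so calling this raises NameError.
        return __types_by_name[name]
```

Calling `lookup("CV_8UC3")` succeeds, while `Reader().lookup("CV_8UC3")` fails with a `NameError`. A single underscore avoids this surprise.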

raise TypeError(err_msg.format(type(value), value))

return value
@staticmethod
Collaborator

newline between methods

Collaborator

^this

return value
@staticmethod
def toChannelOrder(value):
if not value in ('RGB','BGR'):
Collaborator

'L' (or whatever the one-channel one is) is also supported no?

Collaborator

^this

# Test that resize with the same size is a no-op
sameImage = imageIO.createResizeImageUDF((imgAsRow.height, imgAsRow.width)).func(imgAsRow)
self.assertEqual(imgAsRow, sameImage)
# Test that we have a valid image schema (all fields are in)
Collaborator

thanks for the comments!

_test(np.random.randint(0, 256, (10, 11, nChannels), 'uint8'))
_test(np.random.random_sample((10,11,nChannels)).astype('float32'))

def test_image_round_trip(self):
Collaborator

let's remove this one since the tests directly above and below test what this is testing more clearly.

shape = cls.appModel.inputShape()

imgFiles, images = getSampleImageList()
imageArray = np.empty((len(images), shape[0], shape[1], 3), 'uint8')
Collaborator

let's move getSampleImagelist() as a "private" method in this test class and have it take care of resizing & reverseChannels as well. so a reader doesn't have to worry about getSampleImageList's return image type.

self.assertEqual(processed_images.shape[2], InceptionV3Constants.INPUT_SHAPE[1])

transformer = TFImageTransformer(inputCol="image", outputCol=outputCol, graph=g,
transformer = TFImageTransformer(channelOrder='BGR',inputCol="image", outputCol=outputCol, graph=g,
Collaborator

should we also test the RGB version separately

Collaborator

oh i see you have it above

self.assertTrue( (processed == keras_processed).all() )


# TODO: I believe this is already tested in named_image_test, should we remove it?
Collaborator

let's keep this test and the one above as simple test cases for TFTransformer, one for BGR and one for RGB. what do you think?

Collaborator

@sueann sueann left a comment

There are a couple of comments that require your attention - could you take a look? I have re-marked them (there are two) - the rest of the new ones are all about the space after commas 😂. Otherwise it looks good to me. Thanks!

raise TypeError(err_msg.format(type(value), value))

return value
@staticmethod
Collaborator

^this

return value
@staticmethod
def toChannelOrder(value):
if not value in ('RGB','BGR'):
Collaborator

^this

graph=keras_graph)
image_df = self.loadImagesInternal(dataset, self.getInputCol())
transformer = TFImageTransformer(inputCol=self._loadedImageCol(),
transformer = TFImageTransformer(channelOrder='RGB',inputCol=self._loadedImageCol(),
Collaborator

space

outputTensor=modelGraphSpec["outputTensorName"],
outputMode=modelGraphSpec["outputMode"])
resizeUdf = resizeImage(modelGraphSpec["inputTensorSize"])
tfTransformer = TFImageTransformer(
Collaborator

the previous styling was correct

Contributor Author

it's still under 100 characters though. there are definitely longer lines in the file.


output_col = "transformed_image"
transformer = TFImageTransformer(inputCol="image", outputCol=output_col, graph=g,
transformer = TFImageTransformer(channelOrder='RGB',inputCol="image", outputCol=output_col, graph=g,
Collaborator

space

# dtype - data type of the image's array, sorted as a numpy compatible string.
#
# NOTE: likely to be migrated to Spark ImageSchema code in the near future.
_OcvType = namedtuple("OcvType",["name","ord","nChannels","dtype"])
Collaborator

spaces....

"""
return sparkModeLookup[imageRow.mode]

return Row(origin=origin,mode=imageType.ord, height=height, width=width, nChannels=nChannels, data=data)
Collaborator

space

if imgType.nChannels != 1:
ary = _reverseChannels(ary)
if imgType.nChannels == 1:
return Image.fromarray(obj=ary,mode='L')
Collaborator

space here and below

imgAsPil = imageStructToPIL(imgAsRow).resize(sz)
# PIL is RGB based while image schema is BGR based => we need to flip the channels
imgAsArray = _reverseChannels(np.asarray(imgAsPil))
return imageArrayToStruct(imgAsArray,origin=imgAsRow.origin)
Collaborator

space

return None
decodeImage = udf(_decode, ImageSchema.imageSchema['image'].dataType)
imageData = filesToDF(sc, path, numPartitions=numPartition)
return imageData.select(decodeImage("filePath","fileData").alias("image"))
Collaborator

space

Collaborator

@sueann sueann left a comment

Just one tiny thing to fix. We can merge after. Thanks!

schema = StructType([StructField("filePath", StringType(), False),
StructField("fileData", BinaryType(), False)])
rdd = sc.binaryFiles(path, minPartitions=numPartitions).repartition(numPartitions)
rdd = sc.binaryFiles(
Collaborator

lol i think autopep8's spacing for function calls doesn't match pep8... but oh well, good enough for me.

raise ValueError("Unsupported channel order. Expected one of ('RGB','BGR') but got '%s'") % value
if not value in ('L', 'RGB', 'BGR'):
raise ValueError(
"Unsupported channel order. Expected one of ('RGB','BGR') but got '%s'") % value
Collaborator

Add 'L' here
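
A sketch of the check with 'L' included in the message. Note that in the hunk as quoted, the `% value` sits outside the `ValueError(...)` call, so the formatting would be applied to the exception object and raise a `TypeError` instead. The version below formats the message first. `SUPPORTED_CHANNEL_ORDERS` and `to_channel_order` are hypothetical names, not the project's identifiers.

```python
SUPPORTED_CHANNEL_ORDERS = ("L", "RGB", "BGR")


def to_channel_order(value):
    """Validate a channel order string and return it unchanged."""
    if value not in SUPPORTED_CHANNEL_ORDERS:
        # Build the message from the same tuple used for the check, so the
        # two cannot drift apart when a new order is added.
        raise ValueError(
            "Unsupported channel order. Expected one of {} but got '{}'".format(
                SUPPORTED_CHANNEL_ORDERS, value))
    return value
```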

@sueann sueann merged commit aeff9c9 into databricks:master Dec 19, 2017
Collaborator

sueann commented Dec 19, 2017

@tomasatdatabricks do you mind sending an email to the sparkdl mailing list (google groups linked from the readme) about this change:

  • this commit changes the assumptions about the image channel ordering + background info on the ImageSchema being added to Spark 2.3
  • we'll release a version of DL Pipelines up to the point just before this commit just in case this breaks anyone's workflow. that'll contain newer features like DeepImageFeaturizer in Scala, and distributed prediction on general keras & tensorflow models (TFTransformer, KerasTransformer).
