Skip to content

Commit 218ce4c

Browse files
committed
address comments
1 parent 3c8863c commit 218ce4c

File tree

5 files changed

+23
-19
lines changed

5 files changed

+23
-19
lines changed

mllib/src/main/scala/org/apache/spark/ml/source/image/ImageDataSource.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,33 +18,32 @@
1818
package org.apache.spark.ml.source.image
1919

2020
/**
21-
* `image` package implements Spark SQL data source API for loading IMAGE data as `DataFrame`.
21+
* `image` package implements Spark SQL data source API for loading image data as `DataFrame`.
2222
* The loaded `DataFrame` has one `StructType` column: `image`.
2323
* The schema of the `image` column is:
24-
* - origin: String (represents the origin of the image.
25-
* If loaded from files, then it is the file path)
24+
* - origin: String (represents the file path of the image)
2625
* - height: Int (height of the image)
2726
* - width: Int (width of the image)
2827
* - nChannels: Int (number of the image channels)
2928
* - mode: Int (OpenCV-compatible type)
3029
* - data: BinaryType (Image bytes in OpenCV-compatible order: row-wise BGR in most cases)
3130
*
32-
* To use IMAGE data source, you need to set "image" as the format in `DataFrameReader` and
31+
* To use image data source, you need to set "image" as the format in `DataFrameReader` and
3332
* optionally specify the data source options, for example:
3433
* {{{
3534
* // Scala
3635
* val df = spark.read.format("image")
37-
* .option("dropImageFailures", true)
36+
* .option("dropInvalid", true)
3837
* .load("data/mllib/images/partitioned")
3938
*
4039
* // Java
4140
* Dataset<Row> df = spark.read().format("image")
42-
* .option("dropImageFailures", true)
41+
* .option("dropInvalid", true)
4342
* .load("data/mllib/images/partitioned");
4443
* }}}
4544
*
46-
* IMAGE data source supports the following options:
47-
* - "dropImageFailures": Whether to drop the files that are not valid images from the result.
45+
* Image data source supports the following options:
46+
* - "dropInvalid": Whether to drop the files that are not valid images from the result.
4847
*
4948
 * @note This image data source does not support saving images to files.
5049
*

mllib/src/main/scala/org/apache/spark/ml/source/image/ImageFileFormat.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ private[image] class ImageFileFormat extends FileFormat with DataSourceRegister
6969

7070
(file: PartitionedFile) => {
7171
val emptyUnsafeRow = new UnsafeRow(0)
72-
if (!imageSourceOptions.dropImageFailures && requiredSchema.isEmpty) {
72+
if (!imageSourceOptions.dropInvalid && requiredSchema.isEmpty) {
7373
Iterator(emptyUnsafeRow)
7474
} else {
7575
val origin = file.filePath
@@ -82,7 +82,7 @@ private[image] class ImageFileFormat extends FileFormat with DataSourceRegister
8282
Closeables.close(stream, true)
8383
}
8484
val resultOpt = ImageSchema.decode(origin, bytes)
85-
val filteredResult = if (imageSourceOptions.dropImageFailures) {
85+
val filteredResult = if (imageSourceOptions.dropInvalid) {
8686
resultOpt.toIterator
8787
} else {
8888
Iterator(resultOpt.getOrElse(ImageSchema.invalidImageRow(origin)))

mllib/src/main/scala/org/apache/spark/ml/source/image/ImageOptions.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,9 @@ private[image] class ImageOptions(
2424

2525
def this(parameters: Map[String, String]) = this(CaseInsensitiveMap(parameters))
2626

27-
val dropImageFailures = parameters.getOrElse("dropImageFailures", "false").toBoolean
27+
/**
28+
* Whether to drop invalid images. If true, invalid images will be removed, otherwise
29+
 * invalid images will be returned with empty data and all other fields filled with `-1`.
30+
*/
31+
val dropInvalid = parameters.getOrElse("dropInvalid", "false").toBoolean
2832
}

mllib/src/test/scala/org/apache/spark/ml/source/image/ImageFileFormatSuite.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext {
3434
val df1 = spark.read.format("image").load(imagePath)
3535
assert(df1.count === 9)
3636

37-
val df2 = spark.read.format("image").option("dropImageFailures", true).load(imagePath)
37+
val df2 = spark.read.format("image").option("dropInvalid", true).load(imagePath)
3838
assert(df2.count === 8)
3939
}
4040

@@ -50,11 +50,11 @@ class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext {
5050

5151
test("image datasource test: read non image") {
5252
val filePath = imagePath + "/cls=kittens/date=2018-01/not-image.txt"
53-
val df = spark.read.format("image").option("dropImageFailures", true)
53+
val df = spark.read.format("image").option("dropInvalid", true)
5454
.load(filePath)
5555
assert(df.count() === 0)
5656

57-
val df2 = spark.read.format("image").option("dropImageFailures", false)
57+
val df2 = spark.read.format("image").option("dropInvalid", false)
5858
.load(filePath)
5959
assert(df2.count() === 1)
6060
val result = df2.head()
@@ -64,7 +64,7 @@ class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext {
6464

6565
test("image datasource partition test") {
6666
val result = spark.read.format("image")
67-
.option("dropImageFailures", true).load(imagePath)
67+
.option("dropInvalid", true).load(imagePath)
6868
.select(substring_index(col("image.origin"), "/", -1).as("origin"), col("cls"), col("date"))
6969
.collect()
7070

@@ -82,15 +82,16 @@ class ImageFileFormatSuite extends SparkFunSuite with MLlibTestSparkContext {
8282

8383
// Images with the different number of channels
8484
test("readImages pixel values test") {
85-
val images = spark.read.format("image").option("dropImageFailures", true)
85+
val images = spark.read.format("image").option("dropInvalid", true)
8686
.load(imagePath + "/cls=multichannel/").collect()
8787

8888
val firstBytes20Set = images.map { rrow =>
8989
val row = rrow.getAs[Row]("image")
9090
val filename = Paths.get(getOrigin(row)).getFileName().toString()
9191
val mode = getMode(row)
9292
val bytes20 = getData(row).slice(0, 20).toList
93-
filename -> Tuple2(mode, bytes20)
93+
filename -> Tuple2(mode, bytes20) // Cannot remove `Tuple2`, otherwise `->` operator
94+
// will match 2 arguments
9495
}.toSet
9596

9697
assert(firstBytes20Set === expectedFirstBytes20Set)

python/pyspark/ml/tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2186,7 +2186,7 @@ def tearDown(self):
21862186
class ImageReaderTest(SparkSessionTestCase):
21872187

21882188
def test_read_images(self):
2189-
data_path = 'data/mllib/images/kittens'
2189+
data_path = 'data/mllib/images/origin/kittens'
21902190
df = ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
21912191
self.assertEqual(df.count(), 4)
21922192
first_row = df.take(1)[0][0]
@@ -2253,7 +2253,7 @@ def tearDownClass(cls):
22532253
def test_read_images_multiple_times(self):
22542254
# This test case is to check if `ImageSchema.readImages` tries to
22552255
# initiate Hive client multiple times. See SPARK-22651.
2256-
data_path = 'data/mllib/images/kittens'
2256+
data_path = 'data/mllib/images/origin/kittens'
22572257
ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
22582258
ImageSchema.readImages(data_path, recursive=True, dropImageFailures=True)
22592259

0 commit comments

Comments
 (0)