Skip to content

Commit ac8a5e5

Browse files
committed
supporting different compressions csv
1 parent 843be8e commit ac8a5e5

File tree

13 files changed

+233
-70
lines changed

13 files changed

+233
-70
lines changed

dataframe-csv/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/io/CsvTsvParams.kt

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import io.deephaven.csv.CsvSpecs
44
import org.apache.commons.csv.CSVFormat
55
import org.jetbrains.kotlinx.dataframe.api.ParserOptions
66
import org.jetbrains.kotlinx.dataframe.io.ColType
7+
import org.jetbrains.kotlinx.dataframe.io.CsvCompression
78
import org.jetbrains.kotlinx.dataframe.io.DEFAULT_COL_TYPE
89
import org.jetbrains.kotlinx.dataframe.io.QuoteMode
910

@@ -31,10 +32,11 @@ internal object CsvTsvParams {
3132
val HEADER: List<String> = emptyList()
3233

3334
/**
34-
* @param isCompressed If `true`, the input stream is compressed and will be decompressed before reading.
35-
* The default is `false`.
35+
* @param compression Determines the compression of the CSV file.
36+
* If a ZIP file contains multiple files, an [IllegalArgumentException] is thrown.
37+
* The default is [CsvCompression.None].
3638
*/
37-
const val IS_COMPRESSED: Boolean = false
39+
val COMPRESSION: CsvCompression<*> = CsvCompression.None
3840

3941
/**
4042
* @param colTypes A map of column names to their expected [ColType]s. Can be supplied to force
@@ -70,7 +72,7 @@ internal object CsvTsvParams {
7072
)
7173

7274
/**
73-
* @param ignoreEmptyLines If `true`, empty lines will be skipped.
75+
* @param ignoreEmptyLines If `true`, intermediate empty lines will be skipped.
7476
* The default is `false`.
7577
*/
7678
const val IGNORE_EMPTY_LINES: Boolean = false
@@ -79,9 +81,9 @@ internal object CsvTsvParams {
7981
* @param allowMissingColumns If this is set to `true`, then rows that are too short
8082
* (that have fewer columns than the header row) will be interpreted as if the missing columns contained
8183
* the empty string.
82-
* The default is `false`.
84+
* The default is `true`.
8385
*/
84-
const val ALLOW_MISSING_COLUMNS: Boolean = false
86+
const val ALLOW_MISSING_COLUMNS: Boolean = true
8587

8688
/**
8789
* @param ignoreExcessColumns If this is set to `true`, then rows that are too long
@@ -158,7 +160,7 @@ internal object CsvTsvParams {
158160
* @param recordSeparator The character sequence that separates records in a CSV/TSV file.
159161
* The default is `"\n"`.
160162
*/
161-
const val RECORD_SEPARATOR: Char = '\n'
163+
const val RECORD_SEPARATOR: String = "\n"
162164

163165
/**
164166
* @param headerComments A list of comments to include at the beginning of the CSV/TSV file.

dataframe-csv/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/io/ioUtils.kt

Lines changed: 66 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,36 @@
11
package org.jetbrains.kotlinx.dataframe.impl.io
22

3+
import org.apache.commons.io.input.BOMInputStream
34
import org.jetbrains.kotlinx.dataframe.AnyFrame
45
import org.jetbrains.kotlinx.dataframe.DataFrame
6+
import org.jetbrains.kotlinx.dataframe.io.CsvCompression
7+
import org.jetbrains.kotlinx.dataframe.io.CsvCompression.Custom
8+
import org.jetbrains.kotlinx.dataframe.io.CsvCompression.Gzip
9+
import org.jetbrains.kotlinx.dataframe.io.CsvCompression.None
10+
import org.jetbrains.kotlinx.dataframe.io.CsvCompression.Zip
511
import org.jetbrains.kotlinx.dataframe.io.isURL
612
import org.jetbrains.kotlinx.dataframe.io.readJson
713
import java.io.File
814
import java.io.InputStream
915
import java.net.HttpURLConnection
1016
import java.net.URL
17+
import java.util.zip.ZipInputStream
1118

12-
internal fun isCompressed(fileOrUrl: String) = listOf("gz", "zip").contains(fileOrUrl.split(".").last())
19+
internal fun compressionStateOf(fileOrUrl: String): CsvCompression<*> =
20+
when (fileOrUrl.split(".").last()) {
21+
"gz" -> CsvCompression.Gzip
22+
"zip" -> CsvCompression.Zip
23+
else -> CsvCompression.None
24+
}
1325

14-
internal fun isCompressed(file: File) = listOf("gz", "zip").contains(file.extension)
26+
internal fun compressionStateOf(file: File): CsvCompression<*> =
27+
when (file.extension) {
28+
"gz" -> CsvCompression.Gzip
29+
"zip" -> CsvCompression.Zip
30+
else -> CsvCompression.None
31+
}
1532

16-
internal fun isCompressed(url: URL) = isCompressed(url.path)
33+
internal fun compressionStateOf(url: URL): CsvCompression<*> = compressionStateOf(url.path)
1734

1835
internal fun catchHttpResponse(url: URL, body: (InputStream) -> AnyFrame): AnyFrame {
1936
val connection = url.openConnection()
@@ -42,5 +59,50 @@ public fun asURL(fileOrUrl: String): URL =
4259
if (isURL(fileOrUrl)) {
4360
URL(fileOrUrl).toURI()
4461
} else {
45-
File(fileOrUrl).toURI()
62+
File(fileOrUrl).also {
63+
require(it.exists()) { "File not found: \"$fileOrUrl\"" }
64+
require(it.isFile) { "Not a file: \"$fileOrUrl\"" }
65+
}.toURI()
4666
}.toURL()
67+
68+
/**
69+
* Adjusts the input stream to be safe to use with the given compression algorithm as well
70+
* as any potential BOM characters.
71+
*
72+
* Also closes the stream after the block is executed.
73+
*/
74+
internal inline fun <T> InputStream.useSafely(compression: CsvCompression<*>, block: (InputStream) -> T): T {
75+
var zipInputStream: ZipInputStream? = null
76+
77+
// first wrap the stream in the compression algorithm
78+
val unpackedStream = when (compression) {
79+
None -> this
80+
81+
Zip -> compression(this).also {
82+
it as ZipInputStream
83+
// make sure to call nextEntry once to prepare the stream
84+
if (it.nextEntry == null) error("No entries in zip file")
85+
86+
zipInputStream = it
87+
}
88+
89+
Gzip -> compression(this)
90+
91+
is Custom<*> -> compression(this)
92+
}
93+
94+
val bomSafeStream = BOMInputStream.builder()
95+
.setInputStream(unpackedStream)
96+
.get()
97+
98+
try {
99+
return block(bomSafeStream)
100+
} finally {
101+
close()
102+
// if we were reading from a ZIP, make sure there was only one entry, so that
103+
// the user is warned about potential issues
104+
if (compression == Zip && zipInputStream!!.nextEntry != null) {
105+
throw IllegalArgumentException("Zip file contains more than one entry")
106+
}
107+
}
108+
}

dataframe-csv/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/io/readCsvOrTsv.kt

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import kotlinx.datetime.Instant
2323
import kotlinx.datetime.LocalDate
2424
import kotlinx.datetime.LocalDateTime
2525
import kotlinx.datetime.LocalTime
26-
import org.apache.commons.io.input.BOMInputStream
2726
import org.jetbrains.kotlinx.dataframe.DataColumn
2827
import org.jetbrains.kotlinx.dataframe.DataFrame
2928
import org.jetbrains.kotlinx.dataframe.DataRow
@@ -34,11 +33,11 @@ import org.jetbrains.kotlinx.dataframe.api.tryParse
3433
import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
3534
import org.jetbrains.kotlinx.dataframe.impl.ColumnNameGenerator
3635
import org.jetbrains.kotlinx.dataframe.io.ColType
36+
import org.jetbrains.kotlinx.dataframe.io.CsvCompression
3737
import org.jetbrains.kotlinx.dataframe.io.DEFAULT_COL_TYPE
3838
import java.io.InputStream
3939
import java.math.BigDecimal
4040
import java.net.URL
41-
import java.util.zip.GZIPInputStream
4241
import kotlin.reflect.KType
4342
import kotlin.reflect.full.withNullability
4443
import kotlin.reflect.typeOf
@@ -49,7 +48,7 @@ import kotlin.time.Duration
4948
* @include [CsvTsvParams.INPUT_STREAM]
5049
* @param delimiter The field delimiter character. The default is ',' for CSV, '\t' for TSV.
5150
* @include [CsvTsvParams.HEADER]
52-
* @include [CsvTsvParams.IS_COMPRESSED]
51+
* @include [CsvTsvParams.COMPRESSION]
5352
* @include [CsvTsvParams.COL_TYPES]
5453
* @include [CsvTsvParams.SKIP_LINES]
5554
* @include [CsvTsvParams.READ_LINES]
@@ -67,7 +66,7 @@ internal fun readCsvOrTsvImpl(
6766
inputStream: InputStream,
6867
delimiter: Char,
6968
header: List<String> = CsvTsvParams.HEADER,
70-
isCompressed: Boolean = CsvTsvParams.IS_COMPRESSED,
69+
compression: CsvCompression<*> = CsvTsvParams.COMPRESSION,
7170
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
7271
skipLines: Long = CsvTsvParams.SKIP_LINES,
7372
readLines: Long? = CsvTsvParams.READ_LINES,
@@ -115,32 +114,30 @@ internal fun readCsvOrTsvImpl(
115114
colTypes(colTypes, useDeepHavenLocalDateTime) // this function must be last, so the return value is used
116115
}.build()
117116

118-
val adjustedInputStream = inputStream
119-
.let { if (isCompressed) GZIPInputStream(it) else it }
120-
.let { BOMInputStream.builder().setInputStream(it).get() }
121-
122-
if (adjustedInputStream.available() <= 0) {
123-
return if (header.isEmpty()) {
124-
DataFrame.empty()
125-
} else {
126-
dataFrameOf(
127-
header.map {
128-
DataColumn.createValueColumn(
129-
name = it,
130-
values = emptyList<String>(),
131-
type = typeOf<String>(),
132-
)
133-
},
134-
)
117+
val csvReaderResult = inputStream.useSafely(compression) { safeInputStream ->
118+
if (safeInputStream.available() <= 0) {
119+
return if (header.isEmpty()) {
120+
DataFrame.empty()
121+
} else {
122+
dataFrameOf(
123+
header.map {
124+
DataColumn.createValueColumn(
125+
name = it,
126+
values = emptyList<String>(),
127+
type = typeOf<String>(),
128+
)
129+
},
130+
)
131+
}
135132
}
136-
}
137133

138-
// read the csv
139-
val csvReaderResult = CsvReader.read(
140-
csvSpecs,
141-
adjustedInputStream,
142-
ListSink.SINK_FACTORY,
143-
)
134+
// read the csv
135+
CsvReader.read(
136+
csvSpecs,
137+
safeInputStream,
138+
ListSink.SINK_FACTORY,
139+
)
140+
}
144141

145142
val defaultColType = colTypes[DEFAULT_COL_TYPE]
146143

dataframe-csv/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/io/writeCsvOrTsv.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ internal fun writeCsvOrTsvImpl(
1818
escapeChar: Char? = CsvTsvParams.ESCAPE_CHAR,
1919
commentChar: Char? = CsvTsvParams.COMMENT_CHAR,
2020
headerComments: List<String> = CsvTsvParams.HEADER_COMMENTS,
21-
recordSeparator: Char = CsvTsvParams.RECORD_SEPARATOR,
21+
recordSeparator: String = CsvTsvParams.RECORD_SEPARATOR,
2222
additionalCsvFormat: CSVFormat = CsvTsvParams.ADDITIONAL_CSV_FORMAT,
2323
) {
2424
val format = with(CSVFormat.Builder.create(additionalCsvFormat)) {
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
package org.jetbrains.kotlinx.dataframe.io
2+
3+
import java.io.InputStream
4+
import java.util.zip.GZIPInputStream
5+
import java.util.zip.ZipInputStream
6+
7+
/**
8+
* Compression algorithm to use when reading csv files.
9+
* We support GZIP and ZIP compression out of the box.
10+
*
11+
* Custom compression algorithms can be added by creating an instance of [Custom].
12+
*/
13+
public sealed class CsvCompression<I : InputStream>(public open val wrapStream: (InputStream) -> I) :
14+
(InputStream) -> I by wrapStream {
15+
16+
public data object Gzip : CsvCompression<GZIPInputStream>(::GZIPInputStream)
17+
18+
public data object Zip : CsvCompression<ZipInputStream>(::ZipInputStream)
19+
20+
public data object None : CsvCompression<InputStream>({ it })
21+
22+
public data class Custom<I : InputStream>(override val wrapStream: (InputStream) -> I) :
23+
CsvCompression<I>(wrapStream)
24+
}

dataframe-csv/src/main/kotlin/org/jetbrains/kotlinx/dataframe/io/readCsv.kt

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import org.jetbrains.kotlinx.dataframe.api.ParserOptions
66
import org.jetbrains.kotlinx.dataframe.impl.io.CsvTsvParams
77
import org.jetbrains.kotlinx.dataframe.impl.io.asURL
88
import org.jetbrains.kotlinx.dataframe.impl.io.catchHttpResponse
9-
import org.jetbrains.kotlinx.dataframe.impl.io.isCompressed
9+
import org.jetbrains.kotlinx.dataframe.impl.io.compressionStateOf
1010
import org.jetbrains.kotlinx.dataframe.impl.io.readCsvOrTsvImpl
1111
import java.io.File
1212
import java.io.FileInputStream
@@ -24,6 +24,7 @@ public fun DataFrame.Companion.readCsv(
2424
file: File,
2525
delimiter: Char = CsvTsvParams.CSV_DELIMITER,
2626
header: List<String> = CsvTsvParams.HEADER,
27+
compression: CsvCompression<*> = compressionStateOf(file),
2728
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
2829
skipLines: Long = CsvTsvParams.SKIP_LINES,
2930
readLines: Long? = CsvTsvParams.READ_LINES,
@@ -41,7 +42,7 @@ public fun DataFrame.Companion.readCsv(
4142
inputStream = it,
4243
delimiter = delimiter,
4344
header = header,
44-
isCompressed = isCompressed(file),
45+
compression = compression,
4546
colTypes = colTypes,
4647
skipLines = skipLines,
4748
readLines = readLines,
@@ -61,6 +62,7 @@ public fun DataFrame.Companion.readCsv(
6162
url: URL,
6263
delimiter: Char = CsvTsvParams.CSV_DELIMITER,
6364
header: List<String> = CsvTsvParams.HEADER,
65+
compression: CsvCompression<*> = compressionStateOf(url),
6466
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
6567
skipLines: Long = CsvTsvParams.SKIP_LINES,
6668
readLines: Long? = CsvTsvParams.READ_LINES,
@@ -78,7 +80,7 @@ public fun DataFrame.Companion.readCsv(
7880
inputStream = it,
7981
delimiter = delimiter,
8082
header = header,
81-
isCompressed = isCompressed(url),
83+
compression = compression,
8284
colTypes = colTypes,
8385
skipLines = skipLines,
8486
readLines = readLines,
@@ -98,6 +100,7 @@ public fun DataFrame.Companion.readCsv(
98100
fileOrUrl: String,
99101
delimiter: Char = CsvTsvParams.CSV_DELIMITER,
100102
header: List<String> = CsvTsvParams.HEADER,
103+
compression: CsvCompression<*> = compressionStateOf(fileOrUrl),
101104
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
102105
skipLines: Long = CsvTsvParams.SKIP_LINES,
103106
readLines: Long? = CsvTsvParams.READ_LINES,
@@ -115,7 +118,7 @@ public fun DataFrame.Companion.readCsv(
115118
inputStream = it,
116119
delimiter = delimiter,
117120
header = header,
118-
isCompressed = isCompressed(fileOrUrl),
121+
compression = compression,
119122
colTypes = colTypes,
120123
skipLines = skipLines,
121124
readLines = readLines,
@@ -136,7 +139,7 @@ public fun DataFrame.Companion.readCsv(
136139
inputStream: InputStream,
137140
delimiter: Char = CsvTsvParams.CSV_DELIMITER,
138141
header: List<String> = CsvTsvParams.HEADER,
139-
isCompressed: Boolean = CsvTsvParams.IS_COMPRESSED,
142+
compression: CsvCompression<*> = CsvTsvParams.COMPRESSION,
140143
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
141144
skipLines: Long = CsvTsvParams.SKIP_LINES,
142145
readLines: Long? = CsvTsvParams.READ_LINES,
@@ -154,7 +157,7 @@ public fun DataFrame.Companion.readCsv(
154157
inputStream = inputStream,
155158
delimiter = delimiter,
156159
header = header,
157-
isCompressed = isCompressed,
160+
compression = compression,
158161
colTypes = colTypes,
159162
skipLines = skipLines,
160163
readLines = readLines,
@@ -174,7 +177,7 @@ public fun DataFrame.Companion.readCsvStr(
174177
text: String,
175178
delimiter: Char = CsvTsvParams.CSV_DELIMITER,
176179
header: List<String> = CsvTsvParams.HEADER,
177-
isCompressed: Boolean = CsvTsvParams.IS_COMPRESSED,
180+
compression: CsvCompression<*> = CsvTsvParams.COMPRESSION,
178181
colTypes: Map<String, ColType> = CsvTsvParams.COL_TYPES,
179182
skipLines: Long = CsvTsvParams.SKIP_LINES,
180183
readLines: Long? = CsvTsvParams.READ_LINES,
@@ -191,7 +194,7 @@ public fun DataFrame.Companion.readCsvStr(
191194
inputStream = text.byteInputStream(),
192195
delimiter = delimiter,
193196
header = header,
194-
isCompressed = isCompressed,
197+
compression = compression,
195198
colTypes = colTypes,
196199
skipLines = skipLines,
197200
readLines = readLines,

0 commit comments

Comments
 (0)