feat(c++,spark): support json payload file format #518

Merged · 16 commits · Jun 25, 2024
2 changes: 1 addition & 1 deletion cpp/include/graphar/fwd.h
@@ -73,7 +73,7 @@ using IdType = int64_t;
enum class Type;
class DataType;
/** Type of file format */
enum FileType { CSV = 0, PARQUET = 1, ORC = 2 };
enum FileType { CSV = 0, PARQUET = 1, ORC = 2, JSON = 3 };
enum class AdjListType : uint8_t;

template <typename T>
2 changes: 2 additions & 0 deletions cpp/include/graphar/util/file_type.h
@@ -31,6 +31,7 @@ namespace graphar {
static inline FileType StringToFileType(const std::string& str) {
static const std::map<std::string, FileType> str2file_type{
{"csv", FileType::CSV},
{"json", FileType::JSON},
{"parquet", FileType::PARQUET},
{"orc", FileType::ORC}};
try {
@@ -43,6 +44,7 @@ static inline FileType StringToFileType(const std::string& str) {
static inline const char* FileTypeToString(FileType file_type) {
static const std::map<FileType, const char*> file_type2string{
{FileType::CSV, "csv"},
{FileType::JSON, "json"},
{FileType::PARQUET, "parquet"},
{FileType::ORC, "orc"}};
return file_type2string.at(file_type);
2 changes: 2 additions & 0 deletions cpp/src/filesystem.cc
@@ -91,6 +91,8 @@ std::shared_ptr<ds::FileFormat> FileSystem::GetFileFormat(
return std::make_shared<ds::ParquetFileFormat>();
case ORC:
return std::make_shared<ds::OrcFileFormat>();
case JSON:
return std::make_shared<ds::JsonFileFormat>();
default:
return nullptr;
}
64 changes: 63 additions & 1 deletion cpp/test/test_arrow_chunk_reader.cc
@@ -291,7 +291,6 @@ TEST_CASE_METHOD(GlobalFixture, "ArrowChunkReader") {
auto maybe_reader = AdjListArrowChunkReader::Make(
graph_info, src_label, edge_label, dst_label,
AdjListType::ordered_by_source);
REQUIRE(maybe_reader.status().ok());
auto reader = maybe_reader.value();
// check reader start from vertex chunk 0
auto result = reader->GetChunk();
@@ -463,4 +462,67 @@ TEST_CASE_METHOD(GlobalFixture, "ArrowChunkReader") {
REQUIRE(reader->seek(1024).IsIndexError());
}
}

TEST_CASE_METHOD(GlobalFixture, "JSON_TEST") {
// read file and construct graph info
std::string path = test_data_dir + "/ldbc_sample/json/LdbcSample.graph.yml";
std::string src_label = "Person", edge_label = "Knows", dst_label = "Person";
std::string vertex_property_name = "id";
std::string edge_property_name = "creationDate";
auto maybe_graph_info = GraphInfo::Load(path);
REQUIRE(maybe_graph_info.status().ok());
auto graph_info = maybe_graph_info.value();
auto vertex_info = graph_info->GetVertexInfo(src_label);
REQUIRE(vertex_info != nullptr);
auto v_pg = vertex_info->GetPropertyGroup(vertex_property_name);
REQUIRE(v_pg != nullptr);
auto edge_info = graph_info->GetEdgeInfo(src_label, edge_label, dst_label);
REQUIRE(edge_info != nullptr);
auto e_pg = edge_info->GetPropertyGroup(edge_property_name);
REQUIRE(e_pg != nullptr);

SECTION("VertexPropertyArrowChunkReader") {
auto maybe_reader = VertexPropertyArrowChunkReader::Make(
graph_info, src_label, vertex_property_name);
REQUIRE(maybe_reader.status().ok());
auto reader = maybe_reader.value();
REQUIRE(reader->GetChunkNum() == 10);

SECTION("Basics") {
auto result = reader->GetChunk();
REQUIRE(!result.has_error());
auto table = result.value();
REQUIRE(table->num_rows() == 100);
REQUIRE(table->GetColumnByName(GeneralParams::kVertexIndexCol) !=
nullptr);

// seek
REQUIRE(reader->seek(100).ok());
result = reader->GetChunk();
REQUIRE(!result.has_error());
table = result.value();
REQUIRE(table->num_rows() == 100);
REQUIRE(table->GetColumnByName(GeneralParams::kVertexIndexCol) !=
nullptr);
REQUIRE(reader->next_chunk().ok());
result = reader->GetChunk();
REQUIRE(!result.has_error());
table = result.value();
REQUIRE(table->num_rows() == 100);
REQUIRE(table->GetColumnByName(GeneralParams::kVertexIndexCol) !=
nullptr);
REQUIRE(reader->seek(900).ok());
result = reader->GetChunk();
REQUIRE(!result.has_error());
table = result.value();
REQUIRE(table->num_rows() == 3);
REQUIRE(table->GetColumnByName(GeneralParams::kVertexIndexCol) !=
nullptr);
REQUIRE(reader->GetChunkNum() == 10);
REQUIRE(reader->next_chunk().IsIndexError());

REQUIRE(reader->seek(1024).IsIndexError());
}
}
}
} // namespace graphar
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.graphar.GarTable
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType
@@ -174,6 +175,7 @@ class GarDataSource extends TableProvider with DataSourceRegister {
case "csv" => classOf[CSVFileFormat]
case "orc" => classOf[OrcFileFormat]
case "parquet" => classOf[ParquetFileFormat]
case "json" => classOf[JsonFileFormat]
case _ => throw new IllegalArgumentException
}
}
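
With JsonFileFormat registered as the fallback format, a JSON payload can be scanned through the GraphAr data source like the other supported formats. The sketch below is illustrative only and not part of this PR: the "gar" short name, the "fileFormat" option key, and the chunk path are assumptions modeled on the existing csv/orc/parquet read path; substitute the actual registration name used by GarDataSource if it differs.

// Illustrative sketch, not part of this PR: reading a JSON vertex chunk through
// the GraphAr data source. The "gar" short name and the "fileFormat" option key
// are assumptions mirroring the existing csv/orc/parquet code paths.
import org.apache.spark.sql.SparkSession

object ReadJsonChunkExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("graphar-json-read")
      .master("local[*]")
      .getOrCreate()

    // The DSv2 path builds the scan via GarScan.createJSONReaderFactory (below);
    // getFallbackFileFormat above maps "json" to Spark's built-in JsonFileFormat.
    val chunk = spark.read
      .format("gar")                // assumed DataSourceRegister short name
      .option("fileFormat", "json") // assumed option key selecting the payload format
      .load("/path/to/vertex/Person/id/chunk0") // placeholder chunk path

    chunk.show()
    spark.stop()
  }
}
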
@@ -24,7 +24,9 @@ import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetInputFormat
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.csv.CSVOptions
import org.apache.spark.sql.catalyst.json.JSONOptionsInRead
import org.apache.spark.sql.catalyst.expressions.{ExprUtils, Expression}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.PartitionedFileUtil
import org.apache.spark.sql.execution.datasources.parquet.{
@@ -34,6 +36,7 @@ import org.apache.spark.sql.execution.datasources.parquet.{
}
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.execution.datasources.v2.csv.CSVPartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.json.JsonPartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.orc.OrcPartitionReaderFactory
import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetPartitionReaderFactory
import org.apache.spark.sql.execution.datasources.{
@@ -74,6 +77,7 @@ case class GarScan(
case "csv" => createCSVReaderFactory()
case "orc" => createOrcReaderFactory()
case "parquet" => createParquetReaderFactory()
case "json" => createJSONReaderFactory()
case _ =>
throw new IllegalArgumentException("Invalid format name: " + formatName)
}
@@ -193,6 +197,46 @@
)
}

// Create the reader factory for the JSON format.
private def createJSONReaderFactory(): PartitionReaderFactory = {
val parsedOptions = new JSONOptionsInRead(
CaseInsensitiveMap(options.asScala.toMap),
sparkSession.sessionState.conf.sessionLocalTimeZone,
sparkSession.sessionState.conf.columnNameOfCorruptRecord
)

// Check the field requirement for corrupt records here, so the exception is thrown on the driver side
ExprUtils.verifyColumnNameOfCorruptRecord(
dataSchema,
parsedOptions.columnNameOfCorruptRecord
)
// Don't push down any filter that refers to the "virtual" corrupt-record column,
// which is not present in the input. Such filters are applied later, at the upper layer.
val actualFilters =
pushedFilters.filterNot(
_.references.contains(parsedOptions.columnNameOfCorruptRecord)
)

val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
val hadoopConf =
sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
val broadcastedConf = sparkSession.sparkContext.broadcast(
new SerializableConfiguration(hadoopConf)
)
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
JsonPartitionReaderFactory(
sparkSession.sessionState.conf,
broadcastedConf,
dataSchema,
readDataSchema,
readPartitionSchema,
parsedOptions,
actualFilters
)
}

/**
* Override "partitions" of
* org.apache.spark.sql.execution.datasources.v2.FileScan to disable splitting
@@ -272,6 +316,7 @@ case class GarScan(
case "csv" => super.hashCode()
case "orc" => getClass.hashCode()
case "parquet" => getClass.hashCode()
case "json" => super.hashCode()
case _ =>
throw new IllegalArgumentException("Invalid format name: " + formatName)
}
@@ -56,6 +56,7 @@ case class GarScanBuilder(

override def pushedFilters(): Array[Filter] = formatName match {
case "csv" => Array.empty[Filter]
case "json" => Array.empty[Filter]
case "orc" => pushedOrcFilters
case "parquet" => pushedParquetFilters
case _ =>
@@ -87,8 +88,9 @@
// Check if the file format supports nested schema pruning.
override protected val supportsNestedSchemaPruning: Boolean =
formatName match {
case "csv" => false
case "orc" => sparkSession.sessionState.conf.nestedSchemaPruningEnabled
case "csv" => false
case "json" => false
case "orc" => sparkSession.sessionState.conf.nestedSchemaPruningEnabled
case "parquet" =>
sparkSession.sessionState.conf.nestedSchemaPruningEnabled
case _ =>
@@ -31,8 +31,11 @@ import org.apache.spark.sql.execution.datasources.v2.FileTable
import org.apache.spark.sql.graphar.csv.CSVWriteBuilder
import org.apache.spark.sql.graphar.orc.OrcWriteBuilder
import org.apache.spark.sql.graphar.parquet.ParquetWriteBuilder
import org.apache.spark.sql.graphar.json.JSONWriteBuilder
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.execution.datasources.json.JsonDataSource
import org.apache.spark.sql.catalyst.json.JSONOptions

import scala.collection.JavaConverters._

@@ -82,8 +85,21 @@ case class GarTable(
OrcUtils.inferSchema(sparkSession, files, options.asScala.toMap)
case "parquet" =>
ParquetUtils.inferSchema(sparkSession, options.asScala.toMap, files)
case "json" => {
val parsedOptions = new JSONOptions(
options.asScala.toMap,
sparkSession.sessionState.conf.sessionLocalTimeZone
)

JsonDataSource(parsedOptions).inferSchema(
sparkSession,
files,
parsedOptions
)
}
case _ =>
throw new IllegalArgumentException("Invalid format name: " + formatName)

}

/** Construct a new write builder according to the actual file format. */
@@ -95,6 +111,8 @@
new OrcWriteBuilder(paths, formatName, supportsDataType, info)
case "parquet" =>
new ParquetWriteBuilder(paths, formatName, supportsDataType, info)
case "json" =>
new JSONWriteBuilder(paths, formatName, supportsDataType, info)
case _ =>
throw new IllegalArgumentException("Invalid format name: " + formatName)
}
@@ -0,0 +1,74 @@
/* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Derived from Apache Spark 3.5.1
// https://github.com/apache/spark/blob/1d550c4/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonWriteBuilder.scala

package org.apache.spark.sql.graphar.json
import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
import org.apache.spark.sql.catalyst.util.CompressionCodecs
import org.apache.spark.sql.connector.write.LogicalWriteInfo
import org.apache.spark.sql.execution.datasources.json.JsonOutputWriter
import org.apache.spark.sql.execution.datasources.{
CodecStreams,
OutputWriter,
OutputWriterFactory
}

import org.apache.spark.sql.catalyst.json.JSONOptions
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{StructType, DataType}

import org.apache.spark.sql.graphar.GarWriteBuilder

class JSONWriteBuilder(
paths: Seq[String],
formatName: String,
supportsDataType: DataType => Boolean,
info: LogicalWriteInfo
) extends GarWriteBuilder(paths, formatName, supportsDataType, info) {
override def prepareWrite(
sqlConf: SQLConf,
job: Job,
options: Map[String, String],
dataSchema: StructType
): OutputWriterFactory = {
val conf = job.getConfiguration
val parsedOptions = new JSONOptions(
options,
sqlConf.sessionLocalTimeZone,
sqlConf.columnNameOfCorruptRecord
)
parsedOptions.compressionCodec.foreach { codec =>
CompressionCodecs.setCodecConfiguration(conf, codec)
}

new OutputWriterFactory {
override def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext
): OutputWriter = {
new JsonOutputWriter(path, parsedOptions, dataSchema, context)
}

override def getFileExtension(context: TaskAttemptContext): String = {
".json" + CodecStreams.getCompressionExtension(context)
}
}
}
}
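
The new JSONWriteBuilder plugs into GarTable.newWriteBuilder (see the "json" case above), so writing JSON chunks follows the same pattern as the other formats. A minimal sketch, again with the short name, option key, and output path assumed rather than taken from this PR:

// Illustrative sketch, not part of this PR: writing a DataFrame as JSON chunks
// through the GraphAr data source, which exercises JSONWriteBuilder.prepareWrite.
// The "gar" short name, "fileFormat" option key, and output path are assumptions.
import org.apache.spark.sql.{SaveMode, SparkSession}

object WriteJsonChunkExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("graphar-json-write")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    val people = Seq((0L, "Alice"), (1L, "Bob")).toDF("id", "name")

    people.write
      .mode(SaveMode.Overwrite)
      .format("gar")                // assumed DataSourceRegister short name
      .option("fileFormat", "json") // routed to JSONWriteBuilder by GarTable.newWriteBuilder above
      .save("/tmp/graphar/vertex/Person/chunks") // placeholder output path

    spark.stop()
  }
}
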
@@ -28,6 +28,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.sql.sources.DataSourceRegister
@@ -173,6 +174,7 @@ class GarDataSource extends TableProvider with DataSourceRegister {
case "csv" => classOf[CSVFileFormat]
case "orc" => classOf[OrcFileFormat]
case "parquet" => classOf[ParquetFileFormat]
case "json" => classOf[JsonFileFormat]
case _ => throw new IllegalArgumentException
}
}