diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala
index e586d95..eed6fb1 100644
--- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala
+++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala
@@ -197,6 +197,8 @@ case class NebulaReadConfigEntry(address: String = "",
                                  space: String = "",
                                  labels: List[String] = List(),
                                  weightCols: List[String] = List()) {
+  assert(weightCols.isEmpty || labels.size == weightCols.size,
+         "weightCols must be empty or have the same number of values as labels")
   override def toString: String = {
     s"NebulaReadConfigEntry: " +
       s"{address: $address, space: $space, labels: ${labels.mkString(",")}, " +
diff --git a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala
index c89d5f9..a10ba19 100644
--- a/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala
+++ b/nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/reader/DataReader.scala
@@ -48,11 +48,11 @@ class NebulaReader(spark: SparkSession, configs: Configs, partitionNum: String)
         .withReturnCols(returnCols.toList)
         .withPartitionNum(partition)
         .build()
-      if (dataset == null) {
-        dataset = spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
-      } else {
-        dataset = dataset.union(spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF())
+      var df = spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
+      if (weights.nonEmpty) {
+        df = df.select("_srcId", "_dstId", weights(i))
       }
+      dataset = if (dataset == null) df else dataset.union(df)
     }
     dataset
   }