-
Notifications
You must be signed in to change notification settings - Fork 695
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add local outlier factor implementation.
- Loading branch information
jameswillis
committed
Oct 11, 2024
1 parent
678da00
commit 08c8515
Showing
6 changed files
with
394 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
"""Algorithms for detecting outliers in spatial datasets.""" |
60 changes: 60 additions & 0 deletions
60
python/sedona/stats/outlier_detection/local_outlier_factor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
"""Functions related to calculating the local outlier factor of a dataset.""" | ||
from typing import Optional | ||
|
||
from pyspark.sql import DataFrame, SparkSession | ||
|
||
ID_COLUMN_NAME = "__id" | ||
CONTENTS_COLUMN_NAME = "__contents" | ||
|
||
|
||
def local_outlier_factor( | ||
dataframe: DataFrame, | ||
k: int = 20, | ||
geometry: Optional[str] = None, | ||
handle_ties: bool = False, | ||
use_spheroid=False, | ||
): | ||
"""Annotates a dataframe with a column containing the local outlier factor for each data record. | ||
The dataframe should contain at least one GeometryType column. Rows must be unique. If one geometry column is | ||
present it will be used automatically. If two are present, the one named 'geometry' will be used. If more than one | ||
are present and neither is named 'geometry', the column name must be provided. | ||
Args: | ||
dataframe: apache sedona idDataframe containing the point geometries | ||
k: number of nearest neighbors that will be considered for the LOF calculation | ||
geometry: name of the geometry column | ||
handle_ties: whether to handle ties in the k-distance calculation. Default is false | ||
use_spheroid: whether to use a cartesian or spheroidal distance calculation. Default is false | ||
Returns: | ||
A PySpark DataFrame containing the lof for each row | ||
""" | ||
sedona = SparkSession.getActiveSession() | ||
|
||
result_df = sedona._jvm.org.apache.sedona.stats.outlierDetection.LocalOutlierFactor.localOutlierFactor( | ||
dataframe._jdf, | ||
k, | ||
geometry, | ||
handle_ties, | ||
use_spheroid, | ||
) | ||
|
||
return DataFrame(result_df, sedona) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
import pytest | ||
import numpy as np | ||
import pyspark.sql.functions as f | ||
from pyspark.sql import DataFrame | ||
from pyspark.sql.types import StructType, StructField, IntegerType, DoubleType | ||
from sedona.sql.st_constructors import ST_MakePoint | ||
from sedona.sql.st_functions import ST_X, ST_Y | ||
from sklearn.neighbors import LocalOutlierFactor | ||
|
||
from tests.test_base import TestBase | ||
from sedona.stats.outlier_detection.local_outlier_factor import local_outlier_factor | ||
|
||
|
||
class TestLOF(TestBase): | ||
def get_small_data(self) -> DataFrame: | ||
schema = StructType( | ||
[ | ||
StructField("id", IntegerType(), True), | ||
StructField("x", DoubleType(), True), | ||
StructField("y", DoubleType(), True), | ||
] | ||
) | ||
return self.spark.createDataFrame( | ||
[ | ||
(1, 1.0, 2.0), | ||
(2, 2.0, 2.0), | ||
(3, 3.0, 3.0), | ||
], | ||
schema, | ||
).select("id", ST_MakePoint("x", "y").alias("geometry")) | ||
|
||
def get_medium_data(self): | ||
np.random.seed(42) | ||
|
||
X_inliers = 0.3 * np.random.randn(100, 2) | ||
X_inliers = np.r_[X_inliers + 2, X_inliers - 2] | ||
X_outliers = np.random.uniform(low=-4, high=4, size=(20, 2)) | ||
return np.r_[X_inliers, X_outliers] | ||
|
||
def get_medium_dataframe(self, data): | ||
schema = StructType( | ||
[StructField("x", DoubleType(), True), StructField("y", DoubleType(), True)] | ||
) | ||
|
||
return ( | ||
self.spark.createDataFrame(data, schema) | ||
.select(ST_MakePoint("x", "y").alias("geometry")) | ||
.withColumn("anotherColumn", f.rand()) | ||
) | ||
|
||
def compare_results(self, actual, expected, k): | ||
assert len(actual) == len(expected) | ||
missing = set(expected.keys()) - set(actual.keys()) | ||
assert len(missing) == 0 | ||
big_diff = { | ||
k: (v, expected[k], abs(1 - v / expected[k])) | ||
for k, v in actual.items() | ||
if abs(1 - v / expected[k]) > 0.0000000001 | ||
} | ||
assert len(big_diff) == 0 | ||
|
||
@pytest.mark.parametrize("k", [5, 21, 3]) | ||
def test_lof_matches_sklearn(self, k): | ||
data = self.get_medium_data() | ||
actual = { | ||
tuple(x[0]): x[1] | ||
for x in local_outlier_factor(self.get_medium_dataframe(data.tolist()), k) | ||
.select(f.array(ST_X("geometry"), ST_Y("geometry")), "lof") | ||
.collect() | ||
} | ||
clf = LocalOutlierFactor(n_neighbors=k, contamination="auto") | ||
clf.fit_predict(data) | ||
expected = dict( | ||
zip( | ||
[tuple(x) for x in data], | ||
[float(-x) for x in clf.negative_outlier_factor_], | ||
) | ||
) | ||
self.compare_results(actual, expected, k) | ||
|
||
# TODO uncomment when KNN join supports empty dfs | ||
# def test_handle_empty_dataframe(self): | ||
# empty_df = self.spark.createDataFrame([], self.get_small_data().schema) | ||
# result_df = local_outlier_factor(empty_df, 2) | ||
# | ||
# assert 0 == result_df.count() | ||
|
||
def test_raise_error_for_invalid_k_value(self): | ||
with pytest.raises(Exception): | ||
local_outlier_factor(self.get_small_data(), -1) |
148 changes: 148 additions & 0 deletions
148
...k/common/src/main/scala/org/apache/sedona/stats/outlierDetection/LocalOutlierFactor.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
package org.apache.sedona.stats.outlierDetection | ||
|
||
import org.apache.sedona.stats.Util.getGeometryColumnName | ||
import org.apache.spark.sql.sedona_sql.expressions.st_functions.{ST_Distance, ST_DistanceSpheroid} | ||
import org.apache.spark.sql.{Column, DataFrame, SparkSession, functions => f} | ||
|
||
object LocalOutlierFactor { | ||
|
||
private val ID_COLUMN_NAME = "__id" | ||
private val CONTENTS_COLUMN_NAME = "__contents" | ||
|
||
/** | ||
* Annotates a dataframe with a column containing the local outlier factor for each data record. | ||
* The dataframe should contain at least one GeometryType column. Rows must be unique. If one | ||
* geometry column is present it will be used automatically. If two are present, the one named | ||
* 'geometry' will be used. If more than one are present and neither is named 'geometry', the | ||
* column name must be provided. | ||
* | ||
* @param dataframe | ||
* apache sedona idDataframe containing the point geometries | ||
* @param k | ||
* number of nearest neighbors that will be considered for the LOF calculation | ||
* @param geometry | ||
* name of the geometry column | ||
* @param handleTies | ||
* whether to handle ties in the k-distance calculation. Default is false | ||
* @param useSpheroid | ||
* whether to use a cartesian or spheroidal distance calculation. Default is false | ||
* | ||
* @return | ||
* A PySpark DataFrame containing the lof for each row | ||
*/ | ||
def localOutlierFactor( | ||
dataframe: DataFrame, | ||
k: Int = 20, | ||
geometry: String = null, | ||
handleTies: Boolean = false, | ||
useSpheroid: Boolean = false): DataFrame = { | ||
|
||
if (k < 1) | ||
throw new IllegalArgumentException("k must be a positive integer") | ||
|
||
val prior: String = if (handleTies) { | ||
val prior = | ||
SparkSession.getActiveSession.get.conf | ||
.get("spark.sedona.join.knn.includeTieBreakers", "false") | ||
SparkSession.getActiveSession.get.conf.set("spark.sedona.join.knn.includeTieBreakers", true) | ||
prior | ||
} else "false" // else case to make compiler happy | ||
|
||
val distanceFunction: (Column, Column) => Column = | ||
if (useSpheroid) ST_DistanceSpheroid else ST_Distance | ||
val useSpheroidString = if (useSpheroid) "True" else "False" // for the SQL expression | ||
|
||
val geometryColumn = if (geometry == null) getGeometryColumnName(dataframe) else geometry | ||
|
||
val KNNFunction = "ST_KNN" | ||
|
||
// Store original contents, prep necessary columns | ||
val formattedDataframe = dataframe | ||
.withColumn(CONTENTS_COLUMN_NAME, f.struct("*")) | ||
.withColumn(ID_COLUMN_NAME, f.sha2(f.to_json(f.col(CONTENTS_COLUMN_NAME)), 256)) | ||
.withColumnRenamed(geometryColumn, "geometry") | ||
|
||
val kDistanceDf = formattedDataframe | ||
.alias("l") | ||
.join( | ||
formattedDataframe.alias("r"), | ||
// k + 1 because we are not counting the row matching to itself | ||
f.expr(f"$KNNFunction(l.geometry, r.geometry, $k + 1, $useSpheroidString)") && f.col( | ||
f"l.$ID_COLUMN_NAME") =!= f.col(f"r.$ID_COLUMN_NAME")) | ||
.groupBy(f"l.$ID_COLUMN_NAME") | ||
.agg( | ||
f.first("l.geometry").alias("geometry"), | ||
f.first(f"l.$CONTENTS_COLUMN_NAME").alias(CONTENTS_COLUMN_NAME), | ||
f.max(distanceFunction(f.col("l.geometry"), f.col("r.geometry"))).alias("k_distance"), | ||
f.collect_list(f"r.$ID_COLUMN_NAME").alias("neighbors")) | ||
.checkpoint() | ||
|
||
val lrdDf = kDistanceDf | ||
.alias("A") | ||
.select( | ||
f.col(ID_COLUMN_NAME).alias("a_id"), | ||
f.col(CONTENTS_COLUMN_NAME), | ||
f.col("geometry").alias("a_geometry"), | ||
f.explode(f.col("neighbors")).alias("n_id")) | ||
.join( | ||
kDistanceDf.select( | ||
f.col(ID_COLUMN_NAME).alias("b_id"), | ||
f.col("geometry").alias("b_geometry"), | ||
f.col("k_distance").alias("b_k_distance")), | ||
f.expr("n_id = b_id")) | ||
.select( | ||
f.col("a_id"), | ||
f.col("b_id"), | ||
f.col(CONTENTS_COLUMN_NAME), | ||
f.array_max( | ||
f.array( | ||
f.col("b_k_distance"), | ||
distanceFunction(f.col("a_geometry"), f.col("b_geometry")))) | ||
.alias("rd")) | ||
.groupBy("a_id") | ||
.agg( | ||
// + 1e-10 to avoid division by zero, matches sklearn impl | ||
(f.lit(1.0) / (f.mean("rd") + 1e-10)).alias("lrd"), | ||
f.collect_list(f.col("b_id")).alias("neighbors"), | ||
f.first(CONTENTS_COLUMN_NAME).alias(CONTENTS_COLUMN_NAME)) | ||
|
||
val ret = lrdDf | ||
.select( | ||
f.col("a_id"), | ||
f.col("lrd").alias("a_lrd"), | ||
f.col(CONTENTS_COLUMN_NAME), | ||
f.explode(f.col("neighbors")).alias("n_id")) | ||
.join( | ||
lrdDf.select(f.col("a_id").alias("b_id"), f.col("lrd").alias("b_lrd")), | ||
f.expr("n_id = b_id")) | ||
.groupBy("a_id") | ||
.agg( | ||
f.first(CONTENTS_COLUMN_NAME).alias(CONTENTS_COLUMN_NAME), | ||
(f.sum("b_lrd") / (f.count("b_lrd") * f.first("a_lrd"))).alias("lof")) | ||
.select(f.col(f"$CONTENTS_COLUMN_NAME.*"), f.col("lof")) | ||
|
||
if (handleTies) | ||
SparkSession.getActiveSession.get.conf | ||
.set("spark.sedona.join.knn.includeTieBreakers", prior) | ||
ret | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.