Filtering outliers in Apache Spark based on calculations of previous values


I'm processing geospatial data using Spark 2.0 Dataframes with the following schema:

 |-- date: timestamp (nullable = true)
 |-- lat: double (nullable = true)
 |-- lon: double (nullable = true)
 |-- accuracy: double (nullable = true)
 |-- track_id: long (nullable = true)

I have seen that the location signal sometimes jumps to a completely different place. The strange thing is that the signal then remains at the remote location for a certain time, say around 25 seconds or 5 samples, and then jumps back to where I stand.

I'd like to remove these outliers by calculating the speed between the current record and the "last valid record". If the speed is above a given threshold, the current record should be dropped and the "last valid record" remains the same. If the speed is below the threshold, the current record is added to the result data frame and becomes the new "last valid record".
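Outside Spark, the "last valid record" strategy you describe can be sketched in plain Scala as a fold that carries the last valid record along. This is only an illustration of the logic, not your implementation: the `Record` fields, the haversine-based speed, and the threshold are all assumptions.

```scala
// Minimal sketch of the "last valid record" filter (plain Scala, no Spark).
// Field names and the haversine helper are assumptions for illustration.
case class Record(epochSec: Long, lat: Double, lon: Double)

// Great-circle distance in metres between two lat/lon points (haversine).
def haversineM(a: Record, b: Record): Double = {
  val R = 6371000.0 // mean Earth radius in metres
  val dLat = math.toRadians(b.lat - a.lat)
  val dLon = math.toRadians(b.lon - a.lon)
  val h = math.pow(math.sin(dLat / 2), 2) +
    math.cos(math.toRadians(a.lat)) * math.cos(math.toRadians(b.lat)) *
      math.pow(math.sin(dLon / 2), 2)
  2 * R * math.asin(math.sqrt(h))
}

// Keep a record only if the speed implied by the distance from the last
// *valid* record is below maxSpeed (m/s); otherwise drop it and keep
// comparing later records against the same last valid record.
def filterBySpeed(track: Seq[Record], maxSpeed: Double): Seq[Record] =
  track.foldLeft(Vector.empty[Record]) { (valid, cur) =>
    valid.lastOption match {
      case None => valid :+ cur // first record is accepted as-is
      case Some(prev) =>
        val dt = (cur.epochSec - prev.epochSec).toDouble
        val speed =
          if (dt > 0) haversineM(prev, cur) / dt else Double.PositiveInfinity
        if (speed < maxSpeed) valid :+ cur else valid
    }
  }
```

Because the comparison point only advances when a record is accepted, a whole run of displaced samples is dropped against the same last valid record. The downside is that this is inherently sequential per track, which is why the windowed approach in the answer below may fit Spark better.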


Any suggestions of how to implement this strategy or any better strategy are highly appreciated. Thanks.

PS: I asked the same question on Stack Overflow, with a concrete implementation. But since I'm not sure whether this is the right approach, and I do not want to bias the answers toward a certain Spark method, I'm asking here for any suggestions.


Posted 2016-12-07T06:20:35.227


How many rows are there per track_id? I assume the track_id indicates a unique object to track? – Jan van der Vegt – 2016-12-07T06:54:56.940

No, the track_id separates unique tracks, or let's call them trips, meaning a collection of such records sampled every 5 seconds. So for a 24-hour track there would be 17280 records per track_id. I intended the track_id for partitioning and parallelizing the algorithm. The full dataset contains just date, lat, lon, and accuracy recorded over 80 days. The track_id was then generated by assigning a new id after each break of more than 10 minutes. No track is longer than 24 hours. – Martin – 2016-12-07T07:08:49.620
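For illustration, the gap-based track_id assignment described in the comment above could be sketched in plain Scala. The function name and the 10-minute cutoff parameter are just assumptions, not the original implementation:

```scala
// Assign a track id to each timestamp (sorted, in epoch seconds),
// starting a new id whenever the gap to the previous sample exceeds
// maxGapSec (default: 10 minutes).
def assignTrackIds(epochSecs: Seq[Long], maxGapSec: Long = 600): Seq[Long] =
  epochSecs match {
    case Nil => Nil
    case head +: tail =>
      // Carry (previousTimestamp, currentId) along; bump the id on a gap.
      tail.scanLeft((head, 0L)) { case ((prev, id), t) =>
        (t, if (t - prev > maxGapSec) id + 1 else id)
      }.map(_._2)
  }
```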



This is actually a general problem with time-series data: you have some logic to implement based on one or more values in the series. You always have two choices:

  1. Feed the time series through some module that calculates as each data point arrives
  2. Use the "spreadsheet method" to calculate a series of columns eventually arriving at the goal

The advantage of the first approach is you can use the same module to process your real-time data. The advantage of the second approach is that it's very fast and usually easier to implement.

Since you're already in a Spark Dataset, here's the strategy:

  1. Calculate a "speed" column: $p_t - p_{t-1}$, where $p$ is the position (with evenly spaced samples, the position difference is proportional to the speed)
  2. Calculate a "jump" column: 1 if the speed is over a certain threshold, -1 if under, 0 otherwise
  3. Calculate a "jumpsum" column: the cumulative sum of the jump column
  4. Bad data will have a jumpsum of 1; filter them out
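To make the four steps concrete before the Spark version, here is a plain-Scala sketch (no Spark) run on the same toy position series shown in the output table further down; it shows how the running jumpsum flags the displaced run:

```scala
// Toy position series: a jump to ~45 and back, as in the table below.
val positions = Seq(1, 1, 1, 1, 1, 2, 1, 1, 46, 45, 48, 45, 1, 2, 1).map(_.toDouble)
val threshold = 10.0

// 1. speed: difference to the previous sample (the first sample has none)
val speeds = positions.zip(positions.tail).map { case (p0, p1) => p1 - p0 }
// 2. jump: +1 if over the threshold, -1 if under -threshold, else 0
val jumps = speeds.map(v => if (v > threshold) 1 else if (v < -threshold) -1 else 0)
// 3. jumpsum: running sum of the jump column
val jumpsums = jumps.scanLeft(0)(_ + _).tail
// 4. keep the first sample plus every sample whose jumpsum is 0
val clean = positions.head +: positions.tail.zip(jumpsums).collect {
  case (p, 0) => p
}
```

The displaced run sits exactly where the jumpsum is 1, so filtering on jumpsum recovers the clean series; the Spark code below computes the same columns with window functions.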

Here's how you do it:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.SparkSession

val ss: SparkSession = SparkSession.builder.getOrCreate()

// note the file must be on each executor in the same directory
val ds = ss.read
   .option("header", "true")
   .option("inferSchema", "true")
   .csv("data.csv")

val w = Window.partitionBy().orderBy("datetime")
val threshold = 10
def jump(v: Double): Int = if (v > threshold) 1 else if (v < -threshold) -1 else 0
val sqlJump = udf(jump _)

val cleanDS = ds
    .withColumn("speed", $"position" - lag($"position", 1).over(w))
    .withColumn("jump", sqlJump($"speed"))
    .withColumn("jumpsum", sum($"jump").over(w.rowsBetween(Long.MinValue, 0)))

Here's what the output Dataset looks like (I didn't remove the bad rows so you can see the calculation):

|datetime|position|speed|jump|jumpsum|
+--------+--------+-----+----+-------+
|       1|       1| null|null|   null|
|       2|       1|    0|   0|      0|
|       3|       1|    0|   0|      0|
|       4|       1|    0|   0|      0|
|       5|       1|    0|   0|      0|
|       6|       2|    1|   0|      0|
|       7|       1|   -1|   0|      0|
|       8|       1|    0|   0|      0|
|       9|      46|   45|   1|      1|
|      10|      45|   -1|   0|      1|
|      11|      48|    3|   0|      1|
|      12|      45|   -3|   0|      1|
|      13|       1|  -44|  -1|      0|
|      14|       2|    1|   0|      0|
|      15|       1|   -1|   0|      0|

The "data.csv" is just the first two columns of that Dataset:

datetime,position
1,1
2,1
3,1
4,1
5,1
6,2
7,1
8,1
9,46
10,45
11,48
12,45
13,1
14,2
15,1

All that's left to do is filter out the rows where jumpsum === 1. Note that the first row's speed and jumpsum are null, so keep it explicitly, e.g. .filter($"jumpsum".isNull || $"jumpsum" =!= 1).



Thank you so much, Pete. This is a brilliant solution. You helped me so much. – Martin – 2016-12-08T19:02:28.207

You're welcome! Looking at your post again, it looks like you'll want to partitionBy($"track_id") in your window construction. I left it blank in the solution, which only works if there is only one time series. – Pete – 2016-12-08T21:11:38.817