Spark SQL: Functions in Scala




This article is about using and implementing Spark SQL functions in Scala. The examples were run in the spark-shell, so you can follow along there yourself. Let’s start by creating some simple data to illustrate the concepts.

scala> import spark.implicits._
import spark.implicits._

scala> val clusteredPoints = Seq((1, 3, 4), (1, 5, 12), (2, 8, 15)).toDF("cluster_id", "x", "y")
clusteredPoints: org.apache.spark.sql.DataFrame = [cluster_id: int, x: int ... 1 more field]

scala> clusteredPoints.show
+----------+---+---+
|cluster_id|  x|  y|
+----------+---+---+
|         1|  3|  4|
|         1|  5| 12|
|         2|  8| 15|
+----------+---+---+

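clusteredPoints is a DataFrame, created by calling the toDF method. Seq itself has no toDF method. It is added by an implicit conversion in spark.implicits._, which is why we import that first. Spark SQL, like standard relational SQL, has many built-in functions for manipulating data, all defined in the org.apache.spark.sql.functions object (in a compiled application, bring them into scope with import org.apache.spark.sql.functions._): https://spark.apache.org/docs/2.4.3/api/scala/index.html#org.apache.spark.sql.functions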

Spark SQL functions fall into two broad kinds. Row operation functions compute one output value from the values in a single row. Aggregate functions combine the values from many rows, for example all the rows in a group, into a single value.
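The following plain-Scala analogy uses ordinary collections instead of Spark to make the distinction concrete. The Point case class and the RowVsAggregate object are hypothetical names for this sketch. Here, map plays the role of a row operation, and groupBy followed by a per-group reduction plays the role of an aggregate.

```scala
object RowVsAggregate {
  // Hypothetical stand-in for one row of clusteredPoints.
  final case class Point(clusterId: Int, x: Int, y: Int)

  val points: List[Point] = List(Point(1, 3, 4), Point(1, 5, 12), Point(2, 8, 15))

  // Row operation: each input row produces exactly one output value.
  def rowSums: List[Int] = points.map(p => p.x + p.y)

  // Aggregate: the rows of each group collapse into a single value per group.
  def rowsPerCluster: Map[Int, Int] =
    points.groupBy(_.clusterId).map { case (id, group) => id -> group.size }
}
```

The row operation takes three rows and returns three values. The aggregate takes the same three rows and returns one value per cluster.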

Using an existing row operation function

Let’s look at how we might use an existing Spark SQL function to find the distance of our (x,y) points from the origin:

scala> clusteredPoints.withColumn("hypotenuse", hypot("x", "y")).show
+----------+---+---+----------+
|cluster_id|  x|  y|hypotenuse|
+----------+---+---+----------+
|         1|  3|  4|       5.0|
|         1|  5| 12|      13.0|
|         2|  8| 15|      17.0|
+----------+---+---+----------+

The code above calls the withColumn method on clusteredPoints. withColumn returns a new DataFrame that has an extra column derived from the existing ones. It takes two parameters: the name of the new column and a Column object that defines its value. Here we use the existing SQL function hypot to build that Column. As the name suggests, hypot computes the distance of each (x, y) point from the origin using the Pythagorean theorem: distance = sqrt(x^2 + y^2)
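As a worked check, applying the formula to each row of clusteredPoints reproduces the hypotenuse column. All three points happen to be Pythagorean triples, so the results are whole numbers:

```latex
\begin{aligned}
\sqrt{3^2 + 4^2}  &= \sqrt{25}  = 5  \\
\sqrt{5^2 + 12^2} &= \sqrt{169} = 13 \\
\sqrt{8^2 + 15^2} &= \sqrt{289} = 17
\end{aligned}
```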

Using a custom row operation function

Let’s see how we can define our own hypotenuse function:

// Widen to Double before squaring so that large Int inputs cannot overflow.
val hypotenuse = (a: Int, b: Int) => math.sqrt(a.toDouble * a + b.toDouble * b)
val myHypot = udf(hypotenuse)

scala> val myHypot = udf(hypotenuse)
myHypot: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function2>,DoubleType,Some(List(IntegerType, IntegerType)))

First we define our hypotenuse function, then we wrap it with the udf method. As the Spark output shows, this creates a UserDefinedFunction object that wraps our function. The myHypot UDF takes two Column parameters. These Columns are expected to be of IntegerType, meaning they hold Int values. Just like the standard hypot function, myHypot returns a Column of type DoubleType.
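The function passed to udf is ordinary Scala, so you can check it without Spark. The hypothetical HypotCheck object below is a sketch that compares two ways of writing the hypotenuse: squaring in Int arithmetic, and widening to Double before squaring. Any Int larger than 46340 in magnitude has a square bigger than Int.MaxValue, so the Int version overflows silently.

```scala
object HypotCheck {
  // Squares computed in Int arithmetic: a * a overflows once |a| > 46340.
  def intArithmetic(a: Int, b: Int): Double = math.sqrt(a * a + b * b)

  // Widening to Double before multiplying avoids the overflow.
  def widened(a: Int, b: Int): Double = math.sqrt(a.toDouble * a + b.toDouble * b)
}
```

With (3, 4), both versions return 5.0. With (50000, 0), the Int square wraps around to a negative number, and math.sqrt of a negative number is NaN. The widened version returns 50000.0.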

scala> clusteredPoints.withColumn("hypotenuse", myHypot($"x", $"y")).show
+----------+---+---+----------+
|cluster_id|  x|  y|hypotenuse|
+----------+---+---+----------+
|         1|  3|  4|       5.0|
|         1|  5| 12|      13.0|
|         2|  8| 15|      17.0|
+----------+---+---+----------+

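We call our myHypot UDF much as we called the hypot SQL function. The difference is that a UDF takes Column arguments rather than column names, so we must refer to “the column x” and “the column y” explicitly. That is what the $ prefix does: $"x", provided by spark.implicits._, builds a Column for the column named x, just like col("x"). We could equally have written myHypot(col("x"), col("y")).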

Using existing aggregate functions in Spark SQL

Just as we can group and apply aggregate functions in SQL, so we can in Spark SQL. Let’s use an existing aggregate function to get the average x and y for our data, grouped by cluster_id.

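scala> clusteredPoints.groupBy("cluster_id").agg(avg("x"), avg("y")).show
+----------+------+------+
|cluster_id|avg(x)|avg(y)|
+----------+------+------+
|         1|   4.0|   8.0|
|         2|   8.0|  15.0|
+----------+------+------+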

First, groupBy gathers all rows with the same cluster_id together, and then avg is applied to each group. Our toy data is very simple, with only two-dimensional points, but it is easy to imagine data with many more dimensions. In that case we would rather not list every column by hand. The code below builds the list of aggregate columns dynamically, so it also works for higher-dimensional data:

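val groupCol = "cluster_id"
// One avg(...) column for every column except the grouping column, in the DataFrame's column order.
val aggCols = clusteredPoints.columns.filter(_ != groupCol).map(colName => avg(colName).as(colName + "_avg")).toList

// agg takes one Column plus varargs, hence the head/tail split.
clusteredPoints.groupBy(groupCol).agg(aggCols.head, aggCols.tail: _*)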
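The same idea can be sketched with plain Scala collections: select every column except the grouping column, then average each one. This only illustrates the logic. It is not how Spark executes the query. The DynamicAverages object is hypothetical, and rows are modelled as column-name-to-value maps for simplicity.

```scala
object DynamicAverages {
  // Sketch assumption: a row is a column-name -> value map, not a Spark Row.
  type Record = Map[String, Int]

  // Average every column except groupCol, keeping the original column order.
  def averagesBy(columns: Seq[String], rows: Seq[Record], groupCol: String): Map[Int, Map[String, Double]] = {
    val aggCols = columns.filter(_ != groupCol)
    rows.groupBy(_(groupCol)).map { case (key, group) =>
      key -> aggCols.map(c => (c + "_avg") -> group.map(_(c)).sum.toDouble / group.size).toMap
    }
  }

  // The article's clusteredPoints data, re-typed by hand.
  val columns: Seq[String] = Seq("cluster_id", "x", "y")
  val rows: Seq[Record] = Seq(
    Map("cluster_id" -> 1, "x" -> 3, "y" -> 4),
    Map("cluster_id" -> 1, "x" -> 5, "y" -> 12),
    Map("cluster_id" -> 2, "x" -> 8, "y" -> 15)
  )
}
```

Adding a column to the input requires no change to averagesBy. That is the point of building the aggregate list dynamically.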

Of course, we can also define our own aggregate functions. There are two ways to do this: extend Aggregator, or extend UserDefinedAggregateFunction. Note that from Spark 3.0 onward, UserDefinedAggregateFunction is deprecated in favor of an Aggregator registered with functions.udaf.

Using a User Defined Aggregate Function by extending Aggregator

Let’s see how we can implement our own avg function by extending the abstract org.apache.spark.sql.expressions.Aggregator class.

import org.apache.spark.sql.{Encoder, Encoders, Row}
import org.apache.spark.sql.expressions.Aggregator

// Intermediate state: running sum and count of the values seen so far.
case class AvgBuffer(sum: Int, count: Int)

case class AvgAggregator(colName: String) extends Aggregator[Row, AvgBuffer, Double] {
  // Empty buffer: nothing summed, nothing counted.
  override def zero: AvgBuffer = AvgBuffer(0, 0)

  // Fold one input row into the buffer.
  override def reduce(b: AvgBuffer, a: Row): AvgBuffer = AvgBuffer(b.sum + a.getAs[Int](colName), b.count + 1)

  // Combine two partial buffers, e.g. from different partitions.
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = AvgBuffer(b1.sum + b2.sum, b1.count + b2.count)

  // Convert to Double before dividing; Int division would truncate (7 / 2 == 3).
  override def finish(reduction: AvgBuffer): Double = reduction.sum.toDouble / reduction.count

  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]

  override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

First we define AvgBuffer, a buffer class that holds the intermediate data for our average calculation. It simply holds two running totals: the sum of the values and the count of the values.

Next, we define the AvgAggregator.

The zero method initializes the buffer.

The reduce method updates the buffer with values from the next row of input.

The merge method combines two buffers into one. Spark uses it to combine partial results that were computed separately, for example on different partitions.

The finish method does the final average calculation: avg = sum/count

The bufferEncoder method supplies the Encoder that Spark uses to serialize and deserialize the buffer. Spark needs it because intermediate buffers are moved between executors over the network.

The outputEncoder method supplies the Encoder for the output value, which Spark also needs when it moves results around.
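The sketch below simulates in plain Scala, without Spark, how these methods might work together for one group whose rows are split across partitions. First, reduce folds each partition's rows into a buffer that starts from zero. Then merge combines the partial buffers, and finish turns the result into the average. The AggregatorContract object and its average helper are hypothetical. Spark's actual scheduling of these calls is more involved.

```scala
object AggregatorContract {
  // Same shape as AvgBuffer, redefined so this sketch compiles on its own.
  final case class Buffer(sum: Int, count: Int)

  def zero: Buffer = Buffer(0, 0)
  def reduce(b: Buffer, value: Int): Buffer = Buffer(b.sum + value, b.count + 1)
  def merge(b1: Buffer, b2: Buffer): Buffer = Buffer(b1.sum + b2.sum, b1.count + b2.count)
  def finish(b: Buffer): Double = b.sum.toDouble / b.count

  // Reduce within each simulated partition, merge the partial buffers, then finish.
  def average(partitions: Seq[Seq[Int]]): Double = {
    val partials = partitions.map(_.foldLeft(zero)(reduce))
    finish(partials.foldLeft(zero)(merge))
  }
}
```

Splitting the values 1, 2 and 4 as Seq(Seq(1, 2), Seq(4)) or as Seq(Seq(1), Seq(2, 4)) gives the same result, 7/3. Averaging the partial averages (1.5 and 4.0) would instead give 2.75. This is why the buffer carries a sum and a count rather than a running average.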

Having defined our own avg logic, it is straightforward to use as shown below.

val avgX = AvgAggregator("x")
val avgY = AvgAggregator("y")
clusteredPoints.groupBy("cluster_id").agg(avgX.toColumn.as("x_avg"), avgY.toColumn.as("y_avg")).show

scala> val avgX = AvgAggregator("x")
avgX: AvgAggregator = AvgAggregator(x)

scala> val avgY = AvgAggregator("y")
avgY: AvgAggregator = AvgAggregator(y)

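scala> clusteredPoints.groupBy("cluster_id").agg(avgX.toColumn.as("x_avg"), avgY.toColumn.as("y_avg")).show
+----------+-----+-----+
|cluster_id|x_avg|y_avg|
+----------+-----+-----+
|         1|  4.0|  8.0|
|         2|  8.0| 15.0|
+----------+-----+-----+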

We define one AvgAggregator for each column we want to average, here x and y. We then group our data by cluster_id and apply the avgX and avgY aggregators. Notice that we call Aggregator.toColumn to turn each aggregator into a Column, and then call Column.as to give the new columns their names. At this point, you may be wondering how to extend this to points with an arbitrary number of dimensions.

As in the earlier example with the standard avg function, we create one aggregator column for every column other than the grouping column, “cluster_id”:

// One AvgAggregator column per non-grouping column, in the DataFrame's column order.
val aggregatorCols = clusteredPoints.columns.filter(_ != groupCol).map(colName => AvgAggregator(colName).toColumn.as(colName + "_avg")).toList
clusteredPoints.groupBy(groupCol).agg(aggregatorCols.head, aggregatorCols.tail: _*)


Using a User Defined Aggregate Function by extending UserDefinedAggregateFunction

Now let’s look at how we can extend org.apache.spark.sql.expressions.UserDefinedAggregateFunction to create an average aggregator.

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructType}

class AvgUdaf extends UserDefinedAggregateFunction {
  // A single Int input column.
  override def inputSchema: StructType = new StructType().add("val", IntegerType)

  // Buffer slot 0 holds the running sum, slot 1 the running count.
  override def bufferSchema: StructType = new StructType().add("sum", IntegerType).add("count", IntegerType)

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0)
    buffer.update(1, 0)
  }

  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getInt(0) + input.getInt(0))
    buffer.update(1, buffer.getInt(1) + 1)
  }

  // Fold buffer2's partial sum and count into buffer1.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0))
    buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1))
  }

  override def evaluate(buffer: Row): Any = buffer.getInt(0).toDouble / buffer.getInt(1).toDouble
}

The inputSchema method defines the structure and types of the input. We didn’t need this for the AvgAggregator class above because its input type is Row, so it receives the whole row of the clusteredPoints DataFrame.

The bufferSchema method defines the structure and types of the buffer data. It plays the same role as the AvgBuffer class above.

The dataType method defines the type of the output.

The deterministic method returns true if the aggregation function always returns the same output for the same input. This is true for almost all aggregations, including this one.

The initialize method initializes the buffer.

The update method updates the buffer with new input data.

The merge method merges two buffers into the first buffer.

The evaluate method calculates the final result. In this case average = sum/count
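For comparison, the following sketch mimics the UDAF contract in plain Scala. The UdafContract object is hypothetical. An Array[Int] stands in for MutableAggregationBuffer and uses the same layout as bufferSchema: slot 0 holds the sum and slot 1 the count. The key difference from AvgAggregator is that initialize, update and merge mutate the buffer in place and return Unit. AvgAggregator's methods return a new buffer instead.

```scala
object UdafContract {
  // Sketch assumption: a plain Array[Int] replaces MutableAggregationBuffer.
  type Buffer = Array[Int]

  def initialize(buffer: Buffer): Unit = { buffer(0) = 0; buffer(1) = 0 }

  def update(buffer: Buffer, input: Int): Unit = { buffer(0) += input; buffer(1) += 1 }

  // Mutates buffer1 in place; buffer2 is only read.
  def merge(buffer1: Buffer, buffer2: Buffer): Unit = { buffer1(0) += buffer2(0); buffer1(1) += buffer2(1) }

  def evaluate(buffer: Buffer): Double = buffer(0).toDouble / buffer(1)

  // Update one buffer per simulated partition, merge them into a fresh buffer, then evaluate.
  def average(partitions: Seq[Seq[Int]]): Double = {
    val partials = partitions.map { values =>
      val b = new Array[Int](2)
      initialize(b)
      values.foreach(v => update(b, v))
      b
    }
    val result = new Array[Int](2)
    initialize(result)
    partials.foreach(p => merge(result, p))
    evaluate(result)
  }
}
```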

Comparing the UDAF implementation with the Aggregator implementation reveals some key differences. AvgAggregator refers to its buffer fields by name (b.sum, b.count), whereas AvgUdaf addresses them by position (buffer.getInt(0)). AvgAggregator’s buffer is also a typed case class, so type mismatches in the buffer logic are caught at compile time. This is a point in AvgAggregator’s favor. Its input, however, is an untyped Row read with getAs, so a wrong column name or type still fails only at runtime. Because AvgAggregator receives the whole row from the original data, a UDAF may be the better fit when you have many columns but your aggregation function only operates on one or a few. If your aggregation function needs all the columns, consider an Aggregator implementation.


We use the AvgUdaf class as shown below:


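val avgUdaf = new AvgUdaf()

scala> clusteredPoints.groupBy("cluster_id").agg(avgUdaf($"x").as("x_avg"), avgUdaf($"y").as("y_avg")).show()
+----------+-----+-----+
|cluster_id|x_avg|y_avg|
+----------+-----+-----+
|         1|  4.0|  8.0|
|         2|  8.0| 15.0|
+----------+-----+-----+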


This article looked at functions in Spark SQL. For functions that operate at the row level, we used the standard hypot mathematical function and our own implementation, myHypot. For aggregate functions, we used the standard avg function and our own custom implementations, AvgAggregator and AvgUdaf, which extend Aggregator and UserDefinedAggregateFunction respectively.

We compared the Aggregator and UserDefinedAggregateFunction implementations, so you should be able to decide when to use each. We also looked at a technique for dynamically applying our functions across arbitrarily many columns of multi-dimensional data.

We reimplemented simple standard functions so that the details of the custom implementations would be easy to follow. In real code, prefer the standard functions whenever they are available: Spark’s Catalyst optimizer can optimize built-in functions, whereas a user defined function is opaque to it. Be sure to familiarize yourself with the available functions before deciding to write your own.