Spark's aggregateByKey function allows performing multiple aggregations simultaneously

Spark Aggregate By Key

The aggregateByKey function is one of the aggregate functions available in Spark (others are reduceByKey & groupByKey). It is the aggregation function that allows multiple types of aggregation (maximum, minimum, average, sum & count) to be computed at the same time. People often find this function hard to understand initially, so this tutorial will try to explain it in a simpler way.


RDD Partitions: Understanding the concept of RDD partitions is important before understanding the aggregateByKey function.
  1. The content of a file read into an RDD in Spark is split across multiple partitions.
  2. Transformations applied on an RDD are applied to each of its partitions.
  3. Spark launches as many tasks for a transformation as there are partitions in the RDD.
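The snippet below assumes orderItemRDD has already been created from the sample data file used later in this tutorial:
val orderItemRDD = sc.textFile("file:///Users/dbmstutorials/spark/orderItem.txt")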
scala> val orderItemMap = orderItemRDD.map(orderItem => (1,orderItem.split(",")(4).toFloat))
orderItemMap: org.apache.spark.rdd.RDD[(Int, Float)] = MapPartitionsRDD[27] at map at <console>:25
 
scala> orderItemMap.getNumPartitions
res23: Int = 2

Overriding the number of partitions in an RDD: using 'repartition(numOfPartitions)'
scala> val orderItemMap = orderItemRDD.map(orderItem => (1,orderItem.split(",")(4).toFloat)).repartition(1)
orderItemMap: org.apache.spark.rdd.RDD[(Int, Float)] = MapPartitionsRDD[13] at repartition at <console>:25
 
scala> orderItemMap.getNumPartitions
res7: Int = 1

Why is it important to understand partitions in an RDD?
Because aggregateByKey takes 3 main inputs/functions, and one of them (the reduce/merge function) works on the output of each RDD partition's transformation. If there is only a single partition in the RDD, then that function will not be called.


aggregateByKey requires 3 main inputs:
  1. zeroValue: The initial value (mostly zero (0)), chosen so that it does not affect the aggregate being collected. For example, 0 is the initial value to perform a sum or count, while for an operation on strings the initial value would be an empty string.
  2. Combiner function: This function accepts two parameters; the second parameter (the next value) is merged into the first parameter (the accumulator). It combines/merges values within a single partition.
  3. Reduce/Merge function: This function also accepts two parameters; here the two partial results are merged into one across RDD partitions.

Syntax:
pairRDD.aggregateByKey(init_value)(combinerFunc, reduceFunc)
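As a minimal sketch of this signature (using small hypothetical in-memory data rather than the sample file below), the following computes the sum and count of values per key; the first function combines values within each partition and the second merges the partial results across partitions:
val pairs = sc.parallelize(Seq(("a", 1.0f), ("a", 3.0f), ("b", 5.0f)), 2)
val sumAndCount = pairs.aggregateByKey((0.0f, 0))(
  (acc, value) => (acc._1 + value, acc._2 + 1), // combiner: runs within each partition
  (p1, p2) => (p1._1 + p2._1, p1._2 + p2._2)    // reduce/merge: runs across partitions
)
sumAndCount.collect.foreach(println) // e.g. (a,(4.0,2)), (b,(5.0,1))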


Example: Finding the total revenue, total items & average revenue per item. Download the sample data file here
val orderItemRDD = sc.textFile("file:///Users/dbmstutorials/spark/orderItem.txt")
 
// Setting the number of partitions to 1
val orderItemMap = orderItemRDD.map(orderItem => (1,orderItem.split(",")(4).toFloat)).repartition(1)  //5th field in sample file is revenue

orderItemMap.getNumPartitions // Check that the number of RDD partitions is 1
res1: Int = 1
   
   
AggregateByKey using externally defined functions
val init_value = (0.0f, 0.0f, 0.0f) // Initial value for revenue (sum), total items (count) & average revenue (sum/count)
val combinerFunc = (inter:(Float, Float, Float), value:Float) => { (inter._1 + value, inter._2 + 1, (inter._1 + value)/(inter._2 + 1)) } // combines values within a single partition: running sum, running count, running average
val reduceFunc = (p1:(Float, Float, Float), p2:(Float, Float, Float)) => { (p1._1 + p2._1, p1._2 + p2._2, (p1._1 + p2._1)/(p1._2 + p2._2)) } // merges the partial results across partitions

val revenueAndCountAndAvg = orderItemMap.
  aggregateByKey(init_value)(combinerFunc,reduceFunc)


AggregateByKey with the functions defined inline

val revenueAndCountAndAvg = orderItemMap.
  aggregateByKey((0.0f, 0,0.0f))(
   (inter, value) => { 
            (inter._1 + value, inter._2+1, (inter._1+value)/(inter._2+1))}, // 1st element accumulates the sum, 2nd counts the records, 3rd is the running average (sum/count)
   (p1, p2) => { 
            (p1._1 + p2._1, p1._2+p2._2, (p1._1 + p2._1)/(p1._2+p2._2))
   }
 )
revenueAndCountAndAvg.collect.foreach(println)

Deep dive into how the function works with a single partition, for better understanding
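A minimal trace, assuming three hypothetical revenue values 10.0, 20.0 and 30.0 for key 1, all sitting in a single partition:
// The combiner is applied sequentially within the partition, starting from the zeroValue:
// (0.0f,  0, 0.0f)   + 10.0f -> (10.0f, 1, 10.0f)
// (10.0f, 1, 10.0f)  + 20.0f -> (30.0f, 2, 15.0f)
// (30.0f, 2, 15.0f)  + 30.0f -> (60.0f, 3, 20.0f)
// With only one partition there are no partial results to combine, so the
// reduce/merge function is never invoked and (60.0f, 3, 20.0f),
// i.e. (total revenue, item count, average revenue), is emitted for key 1.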


Let's see how the same function performs on an RDD with 2 partitions
val orderItemRDD = sc.textFile("file:///Users/dbmstutorials/spark/orderItem.txt")

val orderItemMap = orderItemRDD.map(orderItem => (1,orderItem.split(",")(4).toFloat))  //5th field in sample file is revenue

scala> orderItemMap.getNumPartitions // Number of RDD partitions is 2
res32: Int = 2


val revenueAndCountAndAvg = orderItemMap.
  aggregateByKey((0.0f, 0,0.0f))(
    (inter, value) => { 
             (inter._1 + value, inter._2+1, (inter._1+value)/(inter._2+1))}, // 1st element accumulates the sum, 2nd counts the records, 3rd is the running average (sum/count)
    (p1, p2) => { 
             (p1._1 + p2._1, p1._2+p2._2, (p1._1 + p2._1)/(p1._2+p2._2))
    }
  )

revenueAndCountAndAvg.take(1).foreach(println)
(1,(3818.46,20,190.923))


Note: The same function is applied, but on an RDD with 2 partitions; this time the reduce/merge function is also invoked to combine the partial results from the two partitions.
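To make the 2-partition behaviour concrete, here is a minimal trace with hypothetical per-partition results for key 1 (any split that adds up to the output above would do), say (2000.0f, 12, 166.67f) from partition 0 and (1818.46f, 8, 227.31f) from partition 1:
// The combiner builds each partial result independently inside its own partition,
// then the reduce/merge function combines them:
// (2000.0f, 12, 166.67f) merged with (1818.46f, 8, 227.31f)
//   -> (2000.0f + 1818.46f, 12 + 8, (2000.0f + 1818.46f)/(12 + 8))
//   -> (3818.46f, 20, 190.923f), which matches the output above.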