UDFs in PySpark for Beginners

In this tutorial we will write two basic UDFs in PySpark.
UDF is an acronym for User Defined Function, which in our context means a Python function that helps us process data much like we would in ordinary Python code.
The objective here is to get a crystal-clear understanding of how to create a UDF without complicating matters too much. The two UDFs we will create are:
- Count the unique elements in an array (in our case, an array of dates), and
- Sum the elements of an array (in our case, an array of amounts spent).
The complete code, which we will deconstruct in this post, is below:
import pyspark.sql.functions as F
import pyspark.sql.types as T

customers = spark.createDataFrame(
    data=[["Alice", "2016-05-01", 50.00],
          ["Alice", "2016-05-03", 45.00],
          ["Alice", "2016-05-04", 55.00],
          ["Alice", "2016-05-04", 75.00],
          ["Bob", "2016-05-01", 19.00],
          ["Bob", "2016-05-01", 25.00],
          ["Bob", "2016-05-04", 29.00],
          ["Bob", "2016-05-06", 27.00]],
    schema=["name", "date", "amountSpent"])

def unique_counter(x):
    return len(set(x))

def sum_total(x):
    return sum(x)

unique_count = F.udf(unique_counter, T.IntegerType())
total = F.udf(sum_total, T.FloatType())

customers.groupBy('name') \
    .agg(F.collect_list('date').alias('date'),
         F.collect_list('amountSpent').alias('amountSpent')) \
    .withColumn('spend_days', unique_count(F.col('date'))) \
    .withColumn('total_spend', total(F.col('amountSpent'))) \
    .show()

#+-----+--------------------+---------------+----------+-----------+
#| name|                date|    amountSpent|spend_days|total_spend|
#+-----+--------------------+---------------+----------+-----------+
#|  Bob|[2016-05-01, 2016...|[19.0, 29.0,...|         3|      100.0|
#|Alice|[2016-05-04, 2016...|[75.0, 50.0,...|         3|      225.0|
#+-----+--------------------+---------------+----------+-----------+
Let's understand the code part by part.
We import the functions and types available in pyspark.sql:
import pyspark.sql.functions as F
import pyspark.sql.types as T
Next we create a small DataFrame to work with. The data has three columns: name, date, and amountSpent.
customers = spark.createDataFrame(
    data=[["Alice", "2016-05-01", 50.00],
          ["Alice", "2016-05-03", 45.00],
          ["Alice", "2016-05-04", 55.00],
          ["Alice", "2016-05-04", 75.00],
          ["Bob", "2016-05-01", 19.00],
          ["Bob", "2016-05-01", 25.00],
          ["Bob", "2016-05-04", 29.00],
          ["Bob", "2016-05-06", 27.00]],
    schema=["name", "date", "amountSpent"])
You can view the DataFrame with:
customers.show()
The output should be:
+-----+----------+-----------+
| name|      date|amountSpent|
+-----+----------+-----------+
|Alice|2016-05-01|       50.0|
|Alice|2016-05-03|       45.0|
|Alice|2016-05-04|       55.0|
|Alice|2016-05-04|       75.0|
|  Bob|2016-05-01|       19.0|
|  Bob|2016-05-01|       25.0|
|  Bob|2016-05-04|       29.0|
|  Bob|2016-05-06|       27.0|
+-----+----------+-----------+
Note that when creating the DataFrame, if the schema is just a list of column names, Spark infers the column types automatically.
You can also inspect the resulting schema of the DataFrame with:
customers.printSchema()
The output should be:
root
 |-- name: string (nullable = true)
 |-- date: string (nullable = true)
 |-- amountSpent: double (nullable = true)
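As an aside (not part of the original walkthrough), if you would rather not rely on type inference, spark.createDataFrame also accepts an explicit schema built from pyspark.sql.types. Here is a minimal sketch, reusing the T alias from the imports above; the variable names customers_typed and explicit_schema, and the two-row sample data, are just for illustration:

# Illustrative only: a separate DataFrame with explicit types rather than inferred ones.
explicit_schema = T.StructType([
    T.StructField("name", T.StringType(), True),
    T.StructField("date", T.StringType(), True),
    T.StructField("amountSpent", T.DoubleType(), True),
])

customers_typed = spark.createDataFrame(
    data=[["Alice", "2016-05-01", 50.00],
          ["Bob", "2016-05-01", 19.00]],
    schema=explicit_schema)
customers_typed.printSchema()  # same column types as above, now declared up front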
Now, naturally, one might be interested in how many days each person spent money on and what their total spend was. The way to do this is to group by name and then perform the necessary operations on the date and amountSpent columns. We will write UDFs to do exactly these operations.
def unique_counter(x):
    return len(set(x))

def sum_total(x):
    return sum(x)

unique_count = F.udf(unique_counter, T.IntegerType())
total = F.udf(sum_total, T.FloatType())
Not surprisingly, these look exactly like normal Python functions, with two additional lines at the end of the code block above that convert the Python functions into Spark UDFs. Let's take one of those lines:
unique_count = F.udf(unique_counter, T.IntegerType())
- unique_counter is the Python function we created, which takes a list and returns the count of unique elements in it.
- T.IntegerType() is the return type declared for the UDF.
- unique_count is the resulting Spark UDF, which we can apply to DataFrame columns.
- Finally, F.udf(...) does the job of wrapping the Python function as a Spark UDF.
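As a side note, F.udf can also be used as a decorator, so the function definition and the return type declaration live together. A minimal, equivalent sketch of the unique-count UDF in that style (this is just an alternative spelling, not what the code above does):

@F.udf(returnType=T.IntegerType())
def unique_count(x):
    # the decorated name is already a Spark UDF, ready to apply to a column
    return len(set(x))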
Now all we need to do is group the DataFrame by the name column and pass the collected lists of date and amountSpent to their UDFs. But let's look at just the collected lists first, so we know we are sending the UDFs exactly what we want.
customers.groupBy('name') \
    .agg(F.collect_list('date').alias('date'),
         F.collect_list('amountSpent').alias('amountSpent')) \
    .show()

+-----+--------------------+--------------------+
| name|                date|         amountSpent|
+-----+--------------------+--------------------+
|  Bob|[2016-05-01, 2016...|[19.0, 29.0, 25.0...|
|Alice|[2016-05-03, 2016...|[45.0, 50.0, 75.0...|
+-----+--------------------+--------------------+
Looks alright! Let's apply the UDFs:
customers.groupBy('name') \
    .agg(F.collect_list('date').alias('date'),
         F.collect_list('amountSpent').alias('amountSpent')) \
    .withColumn('spend_days', unique_count(F.col('date'))) \
    .withColumn('total_spend', total(F.col('amountSpent'))) \
    .show()

+-----+--------------------+---------------+----------+-----------+
| name|                date|    amountSpent|spend_days|total_spend|
+-----+--------------------+---------------+----------+-----------+
|  Bob|[2016-05-01, 2016...|[19.0, 29.0,...|         3|      100.0|
|Alice|[2016-05-04, 2016...|[75.0, 50.0,...|         3|      225.0|
+-----+--------------------+---------------+----------+-----------+
Bob and Alice spent 100 and 225 respectively, each over 3 distinct days.
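One caveat worth adding: for this particular task you do not strictly need UDFs at all. Spark's built-in aggregate functions can compute the same summary without shipping the data through Python, which is usually faster on real datasets. A minimal sketch using F.countDistinct and F.sum (it skips the collected-list columns and produces only the two summary columns):

# Same spend_days and total_spend, computed with built-in aggregates instead of UDFs.
customers.groupBy('name') \
    .agg(F.countDistinct('date').alias('spend_days'),
         F.sum('amountSpent').alias('total_spend')) \
    .show()

UDFs shine when the logic gets more custom than counting or summing, which is exactly why the toy examples above are worth practicing.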
And you learned how to write a UDF in PySpark :)