PySpark UDF | Spark UDF

June 28, 2020

PySpark UDFs let you write custom, user-defined functions on the fly. However, you need to consider performance and which type of UDF to use. This post covers PySpark UDFs in detail, along with how to use Scala UDFs and Pandas UDFs from PySpark.

Spark and PySpark let users write custom functions for logic that the built-in functions do not provide.

This post explains the internals of a PySpark UDF in detail, with code examples. It also introduces Pandas UDFs in PySpark, shows how a Scala UDF can be called from PySpark, and ends with a performance benchmark comparing the three.

Table of Contents

  • What is a UDF in Spark
  • Internals of Pyspark UDF
  • Register UDF in Pyspark
    • Register UDF in Spark SQL
  • PySpark UDF Examples
    • UDFs with Primitive DataTypes
    • UDFs with Complex DataTypes
      • Pyspark UDF StructType
      • Pyspark UDF ArrayType
  • Scala UDF in PySpark
  • Pandas UDF in PySpark
  • Performance Benchmark
    • Pyspark UDF Performance
    • Scala UDF Performance
    • Pandas UDF Performance
  • Conclusion

What is a UDF in Spark?

PySpark UDF or Spark UDF or User Defined Functions in Spark help us define custom functions or transformations based on our requirements. This helps us create functions which are not present as part of the built-in functions provided by Spark.

Spark UDFs are especially powerful because you can write them in the programming language of your choice, such as Scala, Java, Python or R. A regular UDF in PySpark or Spark is executed row by row.

Note: Spark UDFs, including PySpark UDFs, do not benefit from Spark's built-in optimizations such as the Catalyst optimizer, which treats a UDF as a black box. It is therefore recommended to use them only when really required, that is, when the logic cannot be expressed with built-in functions.

Internals of PySpark UDF

When a Spark UDF written in Python is executed, 4 steps are performed:

1. The function is serialized and sent to the workers.
2. Spark starts a separate Python process on each worker node, and the data is serialized and sent to it.
3. The function is executed row by row.
4. Once the computations are done in Python, the results are serialized and sent back to Spark.

Note: Starting Python processes on the worker nodes and serializing data back and forth is quite expensive. Python processes also consume memory outside the JVM, which can cause a worker/executor to fail when resources run short.
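The cost is easier to picture with a rough simulation. The sketch below is plain Python, not Spark: it uses the standard `pickle` module as a stand-in for Spark's serializers and keeps everything in one process, and `simulate_python_udf` is a hypothetical helper. It only illustrates the serialize, execute row by row, serialize back pattern described in the steps; shipping the function itself (step 1) is not simulated.

```python
import pickle


def cube_python(x):
    return x ** 3


def simulate_python_udf(func, rows):
    """Rough, single-process imitation of a row-at-a-time Python UDF.

    Step 1 (shipping the function itself) is not simulated here.
    """
    # Step 2: serialize the input rows and "send" them to the Python side.
    to_python = pickle.dumps(rows)
    worker_rows = pickle.loads(to_python)
    # Step 3: call the function once per row.
    results = [func(value) for value in worker_rows]
    # Step 4: serialize the results and "send" them back.
    to_jvm = pickle.dumps(results)
    return pickle.loads(to_jvm), len(to_python) + len(to_jvm)


values, bytes_moved = simulate_python_udf(cube_python, [1, 2, 3, 4, 5])
print(values)  # [1, 8, 27, 64, 125]
```

Every serialized byte and every per-row function call in this loop is overhead that a built-in Spark function does not pay.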

Register UDF in Pyspark

This section covers the methods to create a Python function and use it as a PySpark UDF.

Before we explore the different methods to create a UDF in Pyspark, let’s first create a sample dataframe consisting of integers.

Note: The versions used in this post are Spark 2.3.0 and Python 3.6.

df = spark.range(1, 101)  # Create a DataFrame with the 100 integers from 1 to 100
df.printSchema()

root
 |-- id: long (nullable = false)

Let’s create a simple UDF that computes the cube of an integer. The udf function must be imported from pyspark.sql.functions.

It takes 2 arguments, the function and its return type, as shown below:

pyspark.sql.functions.udf(f=None, returnType=StringType)

As shown above, StringType is the default return type of a Spark UDF.

Method 1 :

Create a PySpark UDF by calling udf and passing the function and its return type as parameters.

from pyspark.sql.functions import udf
from pyspark.sql.types import LongType

def cube_python(x):
    return x**3

# cube_python = lambda x: x**3   # Alternative to the function definition above

spark_cube = udf(cube_python, LongType())  # Wrap the Python function as a UDF returning a long

df.select(spark_cube('id')).show(5)
|cube_python(id)|
|              1|
|              8|
|             27|
|             64|
|            125|
only showing top 5 rows

Method 2 :

Use Python's decorator syntax. This is the simplest and recommended way to create a UDF in PySpark.

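from pyspark.sql.functions import udf
from pyspark.sql.types import LongType

@udf(returnType=LongType())  # The decorator turns cube_python into a UDF
def cube_python(x):
    return x**3

df.select(cube_python('id')).show(5)
|cube_python(id)|
|              1|
|              8|
|             27|
|             64|
|            125|
only showing top 5 rows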

Below is the physical plan of the above statement. The BatchEvalPython step is where the Python UDF is evaluated:

df.select(cube_python('id')).explain()
== Physical Plan ==
*(2) Project [pythonUDF0#628L AS cube_python(id)#626L]
+- BatchEvalPython [cube_python(id#616L)], [id#616L, pythonUDF0#628L]
   +- *(1) Range (1, 101, step=1, splits=8)
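The udf function can be called with the Python function or used as a decorator, with or without a return type. A toy, Spark-free analogue of that calling convention is sketched below; `toy_udf` and `HypotheticalUDF` are made-up names, and this is not PySpark's implementation. It only shows how one callable can act as a plain wrapper, a bare decorator, or a decorator that takes a return type, with a string default mirroring StringType.

```python
import functools


class HypotheticalUDF:
    """Toy wrapper that remembers the declared return type (not PySpark's class)."""

    def __init__(self, func, return_type):
        self.func = func
        self.return_type = return_type
        functools.update_wrapper(self, func)  # copy __name__, __doc__, ...

    def __call__(self, *args):
        return self.func(*args)


def toy_udf(f=None, returnType="string"):
    """Accepts toy_udf(f, t), @toy_udf and @toy_udf(returnType=t)."""
    if f is None:
        # Called with only keyword arguments: return a decorator.
        return lambda func: HypotheticalUDF(func, returnType)
    # Called directly with the function (plain call or bare decorator).
    return HypotheticalUDF(f, returnType)


@toy_udf(returnType="long")
def cube_python(x):
    return x ** 3


print(cube_python.__name__, cube_python.return_type, cube_python(3))  # cube_python long 27
```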

Register UDF in Spark SQL

To use a custom udf in Spark SQL, the user has to further register the UDF as a Spark SQL function.

To register a udf in pyspark, use the spark.udf.register method.

spark.udf.register("UDF Name", function, returnType=None)

There are 2 ways in which a Spark UDF can be registered,

Method 1:

Here, cube_python is already a UDF (it was created with the udf decorator above), so its return type is already known and returnType is not specified.

spark.udf.register("spark_cube", cube_python)

df.createOrReplaceTempView("sample")  # Expose df (ids 1 to 100) to Spark SQL as "sample"
spark.sql("select id, spark_cube(id) from sample").show(5)
| id|spark_cube(id)|
|  1|             1|
|  2|             8|
|  3|            27|
|  4|            64|
|  5|           125|
only showing top 5 rows

Method 2:

When a plain Python function is registered directly, its return type should be specified, because it otherwise defaults to StringType.

from pyspark.sql.types import IntegerType

spark.udf.register("spark_square", lambda x: x**2, returnType=IntegerType())

spark.sql("select id, spark_square(id) from sample").show(5)
| id|spark_square(id)|
|  1|               1|
|  2|               4|
|  3|               9|
|  4|              16|
|  5|              25|
only showing top 5 rows

PySpark UDF Examples

Now that we’ve seen how to create and register a UDF in PySpark, let’s dive in with some additional examples.

Before we get started, note that a PySpark UDF should only return Python's native data types; data types from other libraries such as pandas and NumPy are not supported (for example, a numpy.int64 must be converted to a Python int before it is returned).

You can read more about the list of supported data types in the official docs.

UDFs with Primitive DataTypes

Example 1: Let’s create a PySpark UDF that takes a timestamp and its time zone as arguments, converts the timestamp to UTC, and returns it as a string.

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import datetime
import pytz

@udf(returnType=StringType())
def convert_to_utc(time_col: datetime.datetime, time_zone: str):
    if time_col is not None and time_zone in pytz.all_timezones_set:
        # Use localize() rather than replace(tzinfo=...): with pytz, replace() picks up
        # the zone's historical LMT offset (e.g. +05:53 for Asia/Kolkata) instead of
        # the offset in effect on that date.
        local_time = pytz.timezone(time_zone).localize(time_col)
        return local_time.astimezone(pytz.utc).strftime('%Y-%m-%d %H:%M:%S')

# Create a sample DataFrame
df = spark.createDataFrame([[1, datetime.datetime(2020, 6, 29, 1, 30, 5), 'Asia/Kolkata'],
                            [2, datetime.datetime(2020, 3, 20, 17, 40, 0), 'America/Los_Angeles'],
                            [3, datetime.datetime(2020, 4, 16, 3, 42, 0), 'Asia/Kolkata'],
                            [4, datetime.datetime(2020, 5, 26, 7, 13, 8), 'Europe/Berlin']],
                            ['id', 'time', 'time_zone'])

# Add a column using the UDF
df.withColumn('utc_time', convert_to_utc('time', 'time_zone')).show()
| id|               time|          time_zone|           utc_time|
|  1|2020-06-29 01:30:05|       Asia/Kolkata|2020-06-28 20:00:05|
|  2|2020-03-20 17:40:00|America/Los_Angeles|2020-03-21 00:40:00|
|  3|2020-04-16 03:42:00|       Asia/Kolkata|2020-04-15 22:12:00|
|  4|2020-05-26 07:13:08|      Europe/Berlin|2020-05-26 05:13:08|
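It helps to sanity-check expected values in plain Python before wrapping a function in a UDF. Asia/Kolkata is UTC+05:30 all year round (no daylight saving time), so the sketch below can use a fixed offset from the standard library instead of pytz; the helper `to_utc_string` is hypothetical. Zones with daylight saving time, such as America/Los_Angeles, need a real time-zone database, and with pytz the zone should be attached with `localize()`, because `replace(tzinfo=...)` uses the zone's historical local mean time offset.

```python
from datetime import datetime, timedelta, timezone

# Fixed offset for Asia/Kolkata (UTC+05:30, no daylight saving time).
IST = timezone(timedelta(hours=5, minutes=30))


def to_utc_string(local_time: datetime, tz: timezone) -> str:
    """Interpret a naive local time in tz and format it in UTC."""
    aware = local_time.replace(tzinfo=tz)  # fine for fixed-offset tzinfo objects
    return aware.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")


print(to_utc_string(datetime(2020, 6, 29, 1, 30, 5), IST))  # 2020-06-28 20:00:05
```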

Example 2: For this example, we will take in 2 numbers (a & b) and return a/b.

Let’s first create a dataframe and then invoke our UDF.

df = spark.createDataFrame([[2, 3], [7, 4], [23, 18], [40, 26], [27, 98], [20, None]], ['a', 'b'])

from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

@udf(returnType=FloatType())
def python_div(a, b):
    # Return null (None) when either input is null or the divisor is zero
    if a is not None and b:
        return a/b

df.select('a', 'b', python_div('a', 'b')).show(7)
|  a|   b|python_div(a, b)|
|  2|   3|       0.6666667|
|  7|   4|            1.75|
| 23|  18|       1.2777778|
| 40|  26|       1.5384616|
| 27|  98|       0.2755102|
| 20|null|            null|
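Because a UDF body is ordinary Python, its null handling can be checked without a cluster. The sketch below compares a truthiness check such as `if a and b` with explicit `None` checks (both function names are hypothetical). Both return null for a missing value or a zero divisor, but the truthiness version also turns a legitimate `a == 0` into null.

```python
def div_truthy(a, b):
    # Truthiness check: 0 is treated like a missing value.
    if a and b:
        return a / b
    return None


def div_explicit(a, b):
    # Only a null input or a zero divisor produces null.
    if a is None or not b:
        return None
    return a / b


for a, b in [(7, 4), (20, None), (0, 5), (3, 0)]:
    print(a, b, div_truthy(a, b), div_explicit(a, b))
```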

UDFs with Complex DataTypes

Complex data types in Spark include ArrayType, StructType and MapType.

Pyspark UDF with StructType

Let’s create a PySpark UDF that takes an ArrayType column and returns a StructType.

Let’s go ahead and create a sample DataFrame with an array column.

df = spark.createDataFrame([[1, [3, 4, 5, 6]], [2, [4, 2]], [3, [9, 10, 11, 15, 17]], [4, None], [5, [98]]], ['id', 'arr'])

from pyspark.sql.functions import udf
from pyspark.sql.types import StructType, StructField, IntegerType, ArrayType

# Schema of the struct returned by the UDF
struct_schema = StructType([
    StructField('size', IntegerType(), nullable=True),
    StructField('elements', ArrayType(IntegerType()), nullable=True)
])

@udf(returnType=struct_schema)
def calc_len(arr):
    # The dict keys must match the StructField names
    if arr:
        return {'size': len(arr), 'elements': arr}

df.select('id', 'arr', calc_len('arr')).show(5, truncate=False)
|id |arr                |calc_len(arr)           |
|1  |[3, 4, 5, 6]       |[4, [3, 4, 5, 6]]       |
|2  |[4, 2]             |[2, [4, 2]]             |
|3  |[9, 10, 11, 15, 17]|[5, [9, 10, 11, 15, 17]]|
|4  |null               |null                    |
|5  |[98]               |[1, [98]]               |
+---+-------------------+------------------------+

df.select('id', 'arr', calc_len('arr')).printSchema()
root
 |-- id: long (nullable = true)
 |-- arr: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- calc_len(arr): struct (nullable = true)
 |    |-- size: integer (nullable = true)
 |    |-- elements: array (nullable = true)
 |    |    |-- element: integer (containsNull = true)
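Here, each struct is built from the dict the function returns, so the dict keys need to match the StructField names. A plain-Python check of the function is sketched below; `STRUCT_FIELDS` is a hypothetical constant that mirrors `struct_schema`. It also makes one behavior explicit: because of `if arr:`, an empty array produces null rather than a struct with size 0.

```python
STRUCT_FIELDS = ("size", "elements")  # hypothetical mirror of struct_schema's field names


def calc_len(arr):
    """Return a dict shaped like struct_schema, or None for a null/empty array."""
    if arr:
        return {"size": len(arr), "elements": arr}
    return None


for value in ([3, 4, 5, 6], [98], [], None):
    print(value, calc_len(value))
```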
Pyspark UDF with ArrayType

For our next example, let’s create a PySpark UDF with ArrayType as its return type, which takes an array of integers and returns an array of floating-point numbers.

from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType, ArrayType

@udf(returnType=ArrayType(FloatType()))
def float_arr(arr):
    # Halve each element; a null (or empty) array returns null
    if arr:
        return [i/2 for i in arr]

df.withColumn('float_arr', float_arr('arr')).show(5, truncate=False)
|id |arr                |float_arr                |
|1  |[3, 4, 5, 6]       |[1.5, 2.0, 2.5, 3.0]     |
|2  |[4, 2]             |[2.0, 1.0]               |
|3  |[9, 10, 11, 15, 17]|[4.5, 5.0, 5.5, 7.5, 8.5]|
|4  |null               |null                     |
|5  |[98]               |[49.0]                   |
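Declaring `ArrayType(FloatType())` works here because Python's `/` always returns a float, even for two ints. With `//` the elements would stay ints, and a Python UDF whose values do not match the declared return type typically produces nulls rather than an error, so the element types are worth checking in plain Python. Both functions below are hypothetical local copies for that check.

```python
def float_arr(arr):
    """Halve each element with true division (always yields floats)."""
    if arr:
        return [i / 2 for i in arr]
    return None


def floor_arr(arr):
    """Floor division keeps ints, which would not match FloatType."""
    if arr:
        return [i // 2 for i in arr]
    return None


print(float_arr([4, 2]), floor_arr([4, 2]))  # [2.0, 1.0] [2, 1]
```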

Scala UDF in PySpark

Python UDFs in Spark are slow to execute and can degrade performance, because data has to be serialized to and from Python and Python processes have to be launched on each executor node.

An alternative is to write a Scala UDF and call it from PySpark. A benchmark comparing a PySpark UDF and a Scala UDF is also shown in the coming sections.

Let’s create a UDF in Scala that takes an integer as argument and determines whether it is odd or even.

Here, the UDF1 interface is imported because the UDF takes a single argument. A UDF that takes 2 arguments must implement UDF2 (as IntProduct does below), one with 3 arguments UDF3, and so on up to UDF22.

Package the code into a JAR using a build tool such as sbt.

package com.custom.scala.udf

import org.apache.spark.sql.api.java.{UDF1, UDF2}

// UDF to check if the number is odd or even
class OddOrEven extends UDF1[Long, String] {

  override def call(x: Long): String = {
    if (x % 2 == 0) "even" else "odd"
  }
}

// UDF to compute the product of 2 numbers
class IntProduct extends UDF2[Long, Long, Long] {

  override def call(a: Long, b: Long): Long = {
    a * b
  }
}

Registering Scala UDF in Pyspark

To register a Scala UDF in PySpark, follow the steps below:

Step 1: Add the JAR to the Spark session

Note: If the Spark job runs in cluster mode, the JAR must be placed either on the local file system of every node or on a distributed file system that all nodes can access.

spark.sql("add jar /<path-to-jar>/scala_udf.jar")
DataFrame[result: int]

Step 2: Register the function as a Java Function in PySpark using the spark.udf.registerJavaFunction method.

from pyspark.sql.types import LongType, StringType

spark.udf.registerJavaFunction("int_product", "com.custom.scala.udf.IntProduct", returnType=LongType())
spark.udf.registerJavaFunction("odd_or_even", "com.custom.scala.udf.OddOrEven", returnType=StringType())
from pyspark.sql.functions import expr

df = spark.createDataFrame([[2, 3], [7, 4], [23, 18], [40, 26], [27, 98], [20, None]], ['a', 'b'])

# The registered Scala UDFs are called through SQL expressions
df.select('a', 'b', expr("int_product(a, b)").alias('product'), expr("odd_or_even(a)")).show()
|  a|   b|product|UDF:odd_or_even(a)|
|  2|   3|      6|              even|
|  7|   4|     28|               odd|
| 23|  18|    414|               odd|
| 40|  26|   1040|              even|
| 27|  98|   2646|               odd|
| 20|null|      0|              even|

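Since no data has to be converted to or from Python here, the Scala UDF is significantly faster than a Python UDF in PySpark. Note the last row: the null in column b became 0 in the product, because a null passed to a Scala Long parameter is unboxed to 0. If that matters, handle nulls explicitly, for example by declaring java.lang.Long parameters and checking for null.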

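Below is the physical plan of the above execution. The Scala UDFs are evaluated inside the JVM Project step, and there is no Python evaluation step in the plan:

df.select('a', 'b', expr("int_product(a, b)").alias('product'), expr("odd_or_even(a)")).explain()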
== Physical Plan ==
*(1) Project [a#586L, b#587L, UDF:int_product(a#586L, b#587L) AS product#610L, UDF:odd_or_even(a#586L) AS UDF:odd_or_even(a)#611]
+- Scan ExistingRDD[a#586L,b#587L]

Pandas UDF in PySpark

A Pandas UDF, also known as a vectorized UDF, is a user-defined function in Spark that uses Apache Arrow to transfer data to and from pandas and is executed in a vectorized way.

Apache Arrow is an in-memory columnar data format. Spark uses it to transfer data efficiently between the JVM and the Python worker processes, where it is converted to and from pandas objects.

Traditional Python UDFs in Spark execute row by row, whereas Pandas UDFs in PySpark take in a batch of rows, process them together, and return the results as a batch. A Pandas UDF is therefore invoked once per batch of rows instead of once per row.
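The difference in invocation count can be sketched in plain Python. The two hypothetical runners below call a function once per value (like a regular Python UDF) or once per batch (like a Pandas UDF). Plain lists stand in for pandas Series, and the batch size of 1,000 is an arbitrary assumption; in Spark, the Arrow batch size is governed by `spark.sql.execution.arrow.maxRecordsPerBatch` (10,000 by default).

```python
def run_row_at_a_time(values, func):
    """Call func once per value, like a regular Python UDF."""
    calls = 0
    results = []
    for value in values:
        results.append(func(value))
        calls += 1
    return results, calls


def run_in_batches(values, batch_func, batch_size):
    """Call batch_func once per slice of values, like a Pandas (vectorized) UDF."""
    calls = 0
    results = []
    for start in range(0, len(values), batch_size):
        results.extend(batch_func(values[start:start + batch_size]))
        calls += 1
    return results, calls


def odd_or_even(x):
    return "even" if x % 2 == 0 else "odd"


def odd_or_even_batch(batch):
    # A real Pandas UDF would use a vectorized Series operation here.
    return [odd_or_even(x) for x in batch]


values = list(range(1, 10_001))
row_results, row_calls = run_row_at_a_time(values, odd_or_even)
batch_results, batch_calls = run_in_batches(values, odd_or_even_batch, batch_size=1_000)
print(row_calls, batch_calls)  # 10000 10
```

Fewer, larger calls amortize the per-call overhead, which is the main reason the vectorized variant is faster.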

Here is an example of a scalar Pandas UDF, which takes a pandas Series as input and returns a pandas Series as output.

Note: Pandas UDFs are available in Spark from version 2.3 onwards.

Prerequisites for Pandas UDFs to work in Spark:

  • Spark 2.3.0 or later
  • pyarrow (pip install pyarrow)
  • The versions used here are Spark 2.3.0, Python 3.6 and PyArrow 0.14.1

For our example, we will take a pandas Series of integers as input and return a pandas Series of strings indicating whether each number is odd or even.

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StringType

@pandas_udf("string", functionType=PandasUDFType.SCALAR)
def pandas_odd_or_even(x):
    # x is a pandas Series holding a batch of values from column a
    return (x % 2).map({0: "even", 1: "odd"})

df = spark.createDataFrame([[2, 3], [7, 4], [23, 18], [40, 26], [27, 98], [20, None]], ['a', 'b'])
df.select('a', pandas_odd_or_even('a')).show()
|  a|pandas_odd_or_even(a)|
|  2|                 even|
|  7|                  odd|
| 23|                  odd|
| 40|                 even|
| 27|                  odd|
| 20|                 even|

The physical plan of the above statement is shown below. The ArrowEvalPython step shows that the Python UDF is evaluated through Arrow:

df.select('a', pandas_odd_or_even('a')).explain()
== Physical Plan ==
*(2) Project [a#113L, pythonUDF0#192 AS pandas_odd_or_even(a)#189]
+- ArrowEvalPython [pandas_odd_or_even(a#113L)], [a#113L, pythonUDF0#192]
   +- *(1) Project [a#113L]
      +- Scan ExistingRDD[a#113L,b#114L]

Performance Benchmark

Let’s benchmark a Python UDF, a Scala UDF and a Pandas UDF in PySpark. We will implement the same function in each variant and compare the execution times.

You can also run the benchmark in your own setup or cluster to get a better idea of the execution times.

Let’s first create a DataFrame with 50 million integers and register it as the temporary view data, which the SQL queries below read from:

df = spark.range(1, 50000001)
df.createOrReplaceTempView("data")  # Queried as "data" in the benchmarks below

Next, we will create a UDF that determines whether a given integer is odd or even, note each UDF’s execution time, and finally compare the results.

Python UDF or PySpark UDF Performance

First, let’s create a Python function that checks whether a number is odd or even, then register it in Spark to use as a PySpark UDF.

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

def odd_or_even(x):
    return "odd" if x % 2 != 0 else "even"

# No returnType given, so the registered UDF returns StringType (the default)
spark.udf.register("odd_or_even_py", odd_or_even)

import time

start_time = time.time()
# Note: spark.sql() only builds a lazy DataFrame. For the timing to cover the UDF itself,
# an action that consumes df3 must run before the end time is taken.
df3 = spark.sql("select odd_or_even_py(id) as odd_or_even_py from data")
python_time = time.time() - start_time
print(f"Python UDF Time Taken: {python_time}")
Python UDF Time Taken: 18.055819034576416
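Wall-clock timings like these are easy to skew. A small, hypothetical helper `best_of` is sketched below: it uses `time.perf_counter`, a high-resolution clock intended for measuring short durations, and keeps the fastest of several runs to make comparisons steadier. With Spark, the timed callable must end in an action, because building a DataFrame with `spark.sql()` is lazy and does not run the UDF. The stand-in workload here is a plain Python loop.

```python
import time


def best_of(func, repeats=3):
    """Call func `repeats` times; return (fastest duration in seconds, last result)."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = func()
        best = min(best, time.perf_counter() - start)
    return best, result


elapsed, odd_count = best_of(lambda: sum(1 for x in range(100_000) if x % 2))
print(f"Counted {odd_count} odd numbers in {elapsed:.4f}s")
```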

Scala UDF Performance

Here, the same function is written in Scala and executed from PySpark.

package com.custom.scala.udf

import org.apache.spark.sql.api.java.UDF1

// UDF to check if the number is odd or even
class OddOrEven extends UDF1[Long, String] {
  override def call(x: Long): String = {
    if (x % 2 == 0) "even" else "odd"
  }
}

Once the code is packaged into a JAR, the JAR is added to the Spark session and the function is registered using spark.udf.registerJavaFunction:

from pyspark.sql.types import StringType

spark.sql("add jar /<path-to-jar>/scala_udf.jar")
spark.udf.registerJavaFunction("odd_or_even_scala", "com.custom.scala.udf.OddOrEven", returnType=StringType())

import time

start_time = time.time()
df2 = spark.sql("select odd_or_even_scala(id) as odd_or_even_sc from data")
scala_time = time.time() - start_time
print(f"Scala UDF Time Taken: {scala_time}")
Scala UDF Time Taken: 3.469304084777832

Pandas UDF Performance

Finally, we will create a Pandas UDF in PySpark and compare its execution time.

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StringType

@pandas_udf("string", functionType=PandasUDFType.SCALAR)
def pandas_odd_or_even(x):
    # Vectorized: x is a pandas Series holding a batch of ids
    return (x % 2).map({0: "even", 1: "odd"})

spark.udf.register("odd_or_even_pandas", pandas_odd_or_even)

import time

start_time = time.time()
df3 = spark.sql("select odd_or_even_pandas(id) as odd_or_even_pandas from data")
pandas_time = time.time() - start_time
print(f"Pandas UDF Time Taken: {pandas_time}")
Pandas UDF Time Taken: 9.623993158340454

The results above show that, in this run, the Scala UDF was the fastest, followed by the Pandas UDF, while the Python UDF was the slowest. The results can also be visualized as shown below.

import matplotlib.pyplot as plt
import numpy as np

def visualize(plot_objects, time_taken):
    # Horizontal bar chart of the execution time of each UDF type
    y_pos = np.arange(len(plot_objects))
    plt.barh(y_pos, time_taken, align='center')
    plt.yticks(y_pos, plot_objects)
    plt.xlabel('Time Taken in seconds')
    plt.title('Spark UDF Benchmarking')
    plt.show()

visualize(['Scala UDF', 'Pandas UDF', 'Python UDF'], [scala_time, pandas_time, python_time])



Conclusion

By now, you should be familiar with the internals of a UDF in PySpark and how to create and use one. PySpark UDFs with both primitive and complex data types were also discussed.

You should also see why a Scala UDF or a Pandas UDF is generally preferred over a traditional PySpark UDF. I hope you can use these concepts to create neat and powerful UDFs in Spark. Comments and feedback are welcome. Please do leave a like if you’ve found this post useful.
