Pyspark DataFrame Operations - Basics | Pyspark DataFrames

November 20, 2018

In this post, we will discuss how to work with dataframes in pyspark and perform common spark dataframe operations such as aggregations, ordering, joins and other similar data manipulations on a spark dataframe.

Introduction

The Spark Dataframe API enables the user to perform parallel and distributed structured data processing on the input data. A Spark dataframe is a dataset organized into a named set of columns.

By the end of this post, you should be familiar with performing the most frequently used data manipulations on a spark dataframe.

Table of Contents

  • What is a Spark Dataframe?
  • Spark Dataframe Features
  • Spark DataFrame Operations
    • Create Spark DataFrame
    • Spark DataFrame Schema
    • Count of a Spark DataFrame
    • Display DataFrame Data
    • Remove Duplicate rows from a DataFrame
    • Distinct Column Values
    • Spark Filter Data
    • Sorting/Ordering Data in Spark
    • Grouping & Performing Aggregations
    • Join DataFrames in Spark
    • Limit data from a dataframe
    • Union Dataframes Spark
    • How to rename spark dataframe columns
    • DataType Casting
    • Spark Cache DataFrame
    • Unpersist Dataframe
    • Replace Nulls in Spark
    • Partition Data in Spark
    • Spark DataFrame Write
    • Create Temporary View in Spark
    • Spark SQL

What is a Spark Dataframe?

A pyspark dataframe (spark dataframe) is a distributed collection of data organized into a named set of columns. It is similar to a table in a relational database and has a similar look and feel. A dataframe can be derived from many kinds of datasets, such as delimited text files (CSVs), Parquet and ORC files, RDBMS tables, Hive tables and RDDs, as well as from semi-structured formats such as JSON and XML. The dataframe API is very powerful, bundled with extensive features and rich optimizations.

Spark Dataframe Features

Below are some of the features of a pyspark dataframe:

  • Unified Data Access
  • Ability to handle structured and semi-structured data
  • Supports a wide variety of Data Sources
  • Profuse Features for Data Manipulations and Aggregations
  • Supports multiple languages such as Python, Java, R & Scala

Spark DataFrame Operations

Some of the basic and most frequently used spark dataframe operations are discussed below.

Before we start, let’s create our SparkSession and SparkContext. (Note: these objects are created automatically, as spark and sc, if you’re accessing spark via the pyspark shell.)

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('Spark Training').getOrCreate()
sc = spark.sparkContext

Create Spark DataFrame

Below are some of the methods to create a pyspark dataframe.

  • Creating a Spark dataframe from a CSV file using the spark.read API.

For this example, a country-wise population-by-year dataset is chosen. The dataset can be downloaded here: population_dataset

df = spark.read.format('csv').options(delimiter=',', header=True).load('/Path-to-file/population.csv')
  • Convert RDD to Dataframe.

The method below shows how to create a dataframe from an RDD. The toDF() method converts the RDD to a dataframe; the columns get default names (_1, _2, ...) unless names are supplied, as shown in the sketch after this list.

rdd = sc.parallelize([(1,2),(3,4),(5,6)])
rdf = rdd.toDF()
  • Using spark.createDataFrame method.

Creating a DataFrame from a list of values. Schema is inferred dynamically, if not specified.

tdf = spark.createDataFrame([('Alice',24),('David',43)],['name','age'])
tdf.printSchema()
root
 |-- name: string (nullable = true)
 |-- age: long (nullable = true)

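In the rdd.toDF() example above the resulting columns are named _1 and _2 by default. Column names can also be passed directly to toDF(); a small sketch (the names x and y are made up for illustration):

# Provide column names while converting the RDD to a dataframe
named_rdf = sc.parallelize([(1, 2), (3, 4), (5, 6)]).toDF(['x', 'y'])
named_rdf.printSchema()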


Spark DataFrame Schema

The df.printSchema() method can be used to display the schema of a spark dataframe.

df.printSchema()
root
 |-- Country Name: string (nullable = true)
 |-- Country Code: string (nullable = true)
 |-- Year: string (nullable = true)
 |-- Value: string (nullable = true)

To obtain the raw schema of a dataframe, the df.schema attribute can be used.

df.schema
StructType(List(StructField(Country Name,StringType,true),StructField(Country Code,StringType,true),StructField(Year,LongType,true),StructField(Value,LongType,true)))

To list the columns of a dataframe, the df.columns attribute can be used. It returns a list of the column names.

df.columns
['Country Name', 'Country Code', 'Year', 'Value']



Count of a Spark DataFrame

The df.count() method returns the number of rows in the dataframe.

df.count()
14885



Display DataFrame Data

Display 5 rows. Passing truncate=False displays the full column contents on your terminal instead of truncating long values.

df.show(5, truncate=False)
+------------+------------+----+---------+
|Country Name|Country Code|Year|Value    |
+------------+------------+----+---------+
|Arab World  |ARB         |1960|92490932 |
|Arab World  |ARB         |1961|95044497 |
|Arab World  |ARB         |1962|97682294 |
|Arab World  |ARB         |1963|100411076|
|Arab World  |ARB         |1964|103239902|
+------------+------------+----+---------+
only showing top 5 rows



Remove Duplicate rows from a DataFrame

df.dropDuplicates() can be used to remove duplicates from a spark dataframe.

df.dropDuplicates()
DataFrame[Country Name: string, Country Code: string, Year: bigint, Value: bigint]



Distinct Column Values

To display the distinct rows of a dataframe, df.distinct() can be used. For our example, let’s select the distinct country codes from the dataset.

df.select('Country Code').distinct().show()
+------------+
|Country Code|
+------------+
|         HTI|
|         PSE|
|         LTE|
|         BRB|
|         LVA|
|         POL|
|         ECS|
|         TEA|
|         JAM|
|         ZMB|
|         MIC|
|         BRA|
|         ARM|
|         IDA|
|         MOZ|
|         CUB|
|         JOR|
|         OSS|
|         ABW|
|         FRA|
+------------+
only showing top 20 rows

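If only the number of distinct codes is needed, the countDistinct function avoids pulling the rows back; a small sketch (not part of the original walkthrough):

from pyspark.sql.functions import countDistinct

# Count the distinct country codes instead of listing them
df.select(countDistinct('Country Code').alias('distinct_codes')).show()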


Spark Filter Data

The df.filter() method can be used to filter rows in pyspark. In our example, let’s display the population of India for the years 2015 and 2016.

  • & - and
  • | - or

Note: Remember to wrap each condition in parentheses when ‘&’ or ‘|’ is used.

from pyspark.sql.functions import col
df.filter((col('Country Name') == 'India') & (col('Year').isin('2015','2016'))).show()
+------------+------------+----+----------+
|Country Name|Country Code|Year|     Value|
+------------+------------+----+----------+
|       India|         IND|2015|1309053980|
|       India|         IND|2016|1324171354|
+------------+------------+----+----------+

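filter() also accepts a SQL expression string, which can be convenient for simple conditions; a sketch equivalent to the filter above (backticks handle the space in the column name):

df.filter("`Country Name` = 'India' AND Year IN ('2015', '2016')").show()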


Sorting/Ordering Data in Spark

The example below illustrates how the country names can be displayed in descending order. Note that distinct() is applied before orderBy(), since applying distinct() after the sort is not guaranteed to preserve the ordering.

df.select('Country Name').distinct().orderBy('Country Name', ascending=False).show(truncate=False)
+------------------------+
|Country Name            |
+------------------------+
|Zimbabwe                |
|Zambia                  |
|Yemen, Rep.             |
|World                   |
|West Bank and Gaza      |
|Virgin Islands (U.S.)   |
|Vietnam                 |
|Venezuela, RB           |
|Vanuatu                 |
|Uzbekistan              |
|Uruguay                 |
|Upper middle income     |
|United States           |
|United Kingdom          |
|United Arab Emirates    |
|Ukraine                 |
|Uganda                  |
|Tuvalu                  |
|Turks and Caicos Islands|
|Turkmenistan            |
+------------------------+
only showing top 20 rows

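Sorting can also be expressed with the desc() column function, which makes it easy to mix directions across several columns; a small sketch:

from pyspark.sql.functions import desc

# Country names descending, years ascending within each country
df.orderBy(desc('Country Name'), 'Year').show(5)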


Grouping & Performing Aggregations in a Spark Dataframe

  • Obtain the total count of distinct years present in the entire dataset
    # Count of Years
    df.select('Year').distinct().groupBy().count().show()
    
    +-----+
    |count|
    +-----+
    |   57|
    +-----+
    



  • Calculate the sum of a column – compute the sum of the Value column for the year 1990 (note that the dataset also contains regional and income-group aggregate rows, so the total exceeds the actual world population)
    df.filter(col('Year') == '1990').agg({'Value':'sum'}).show(truncate=False)
    
    +---------------+
    |sum(Value)     |
    +---------------+
    |5.4935613753E10|
    +---------------+
    


  • Calculate the average of a column – compute the average population of India for the year 2005
    df.filter((col('Year') == '2005') & (col('Country Name') == 'India')).agg({'Value':'avg'}).show(truncate=False)
    
    +-------------+
    |avg(Value)   |
    +-------------+
    |1.144118674E9|
    +-------------+
    


  • Computing the minimum of a column – display the smallest population value for the year 2007
    df.filter(df.Year == '2007').agg({'Value':'min'}).show()
    
    +----------+
    |min(Value)|
    +----------+
    |     10075|
    +----------+
    


  • Computing the maximum of a column – display the row with the largest population value for the year 2016
    df.filter(df.Value == df.filter(df.Year == '2016').agg({'Value':'max'}).collect()[0][0]).show()
    
    +------------+------------+----+----------+
    |Country Name|Country Code|Year|     Value|
    +------------+------------+----+----------+
    |       World|         WLD|2016|7442135578|
    +------------+------------+----+----------+
    


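All of the aggregations above are computed over the entire dataframe (groupBy() with no columns); grouping by a column works the same way. A minimal sketch that totals the Value column per country, not shown in the original post:

# Sum of the Value column per country across all years
df.groupBy('Country Name').agg({'Value': 'sum'}).withColumnRenamed('sum(Value)', 'total_value').show(5)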

    Spark Join DataFrames

    A pyspark dataframe can be joined with another using the df.join method, which takes three arguments: join(other, on=None, how=None)

  • other - the dataframe to be joined with
  • on - the join condition
  • how - the type of join; an inner join is performed by default if not specified
  • Other join types which can be specified are: inner, cross, outer, full, full_outer, left, left_outer, right, right_outer, left_semi and left_anti

    Below is an example illustrating an inner join in pyspark. Let’s construct two dataframes: one with only the distinct values of country name and country code, and the other with country code, value and year. Country code is the join condition here.

    df1 = df.select('Country Name', 'Country Code').distinct()
    df1.count()
    
    263
    
    df2 = df.select(col('Country Code').alias('ctry_cd'), 'Value', 'Year').distinct()
    df2.count()
    
    14885
    


    Now let’s join both dataframes on the country code and display the data

    from pyspark.sql.functions import col
    df1.join(df2, col('Country Code') == col('ctry_cd')).show(5)
    
    +--------------------+------------+-------+----------+----+
    |        Country Name|Country Code|ctry_cd|     Value|Year|
    +--------------------+------------+-------+----------+----+
    |East Asia & Pacif...|         EAP|    EAP|1878255588|2004|
    |Europe & Central ...|         ECA|    ECA| 396886165|2001|
    |           IDA blend|         IDB|    IDB| 135810058|1964|
    |           IDA blend|         IDB|    IDB| 403526930|2005|
    |            IDA only|         IDX|    IDX| 984961696|2013|
    +--------------------+------------+-------+----------+----+
    only showing top 5 rows
    


    Having both Country Code and ctry_cd is redundant here, so the duplicate column can be removed with the drop method while displaying.

    df1.join(df2, col('Country Code') == col('ctry_cd')).drop(col('ctry_cd')).show(5,False)
    
    +---------------------------------------------+------------+----------+----+
    |Country Name                                 |Country Code|Value     |Year|
    +---------------------------------------------+------------+----------+----+
    |East Asia & Pacific (excluding high income)  |EAP         |1878255588|2004|
    |Europe & Central Asia (excluding high income)|ECA         |396886165 |2001|
    |IDA blend                                    |IDB         |135810058 |1964|
    |IDA blend                                    |IDB         |403526930 |2005|
    |IDA only                                     |IDX         |984961696 |2013|
    +---------------------------------------------+------------+----------+----+
    only showing top 5 rows
    


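    The examples above use the default inner join; the how argument switches the join type. A quick sketch of a left outer join with the same dataframes:

    # Left outer join: keeps every row of df1 even when there is no match in df2
    left_joined = df1.join(df2, col('Country Code') == col('ctry_cd'), how='left')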

    Limit data from a spark dataframe

    The df.limit method returns a new pyspark dataframe restricted to the first n rows.

    df2 = df.limit(10)
    df2.count()
    
    10
    



    Union Spark Dataframes

    The example below combines two dataframes holding the first and last ten rows (ordered by country code) respectively. Note that union() does not remove duplicates; it behaves like SQL’s UNION ALL.

    # Combine 2 Dataframes
    df1 = df.orderBy('Country Code').limit(10)
    df2 = df.orderBy('Country Code', ascending=False).limit(10)
    df1.union(df2).show()
    
    +------------+------------+----+-------+
    |Country Name|Country Code|Year|  Value|
    +------------+------------+----+-------+
    |       Aruba|         ABW|1960|  54211|
    |       Aruba|         ABW|1969|  58726|
    |       Aruba|         ABW|1961|  55438|
    |       Aruba|         ABW|1962|  56225|
    |       Aruba|         ABW|1963|  56695|
    |       Aruba|         ABW|1964|  57032|
    |       Aruba|         ABW|1965|  57360|
    |       Aruba|         ABW|1966|  57715|
    |       Aruba|         ABW|1967|  58055|
    |       Aruba|         ABW|1968|  58386|
    |    Zimbabwe|         ZWE|1960|3747369|
    |    Zimbabwe|         ZWE|1969|5009514|
    |    Zimbabwe|         ZWE|1961|3870756|
    |    Zimbabwe|         ZWE|1962|3999419|
    |    Zimbabwe|         ZWE|1963|4132756|
    |    Zimbabwe|         ZWE|1964|4269863|
    |    Zimbabwe|         ZWE|1965|4410212|
    |    Zimbabwe|         ZWE|1966|4553433|
    |    Zimbabwe|         ZWE|1967|4700041|
    |    Zimbabwe|         ZWE|1968|4851431|
    +------------+------------+----+-------+
    


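    union() matches columns by position; if the two dataframes have the same columns in a different order, unionByName() (available from Spark 2.3) matches them by name instead. A small sketch:

    # Column order differs between the two selections, so match by name
    dfa = df.select('Country Name', 'Value')
    dfb = df.select('Value', 'Country Name')
    combined = dfa.unionByName(dfb)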

    How to rename spark dataframe columns

  • The method below shows how a simple alias function can be used to rename a column in a spark dataframe
    # Rename Column
    # Method 1
    df.select(col('Country Name').alias('country'), col('Country Code').alias('code')).show(2)
    
    +----------+----+
    |   country|code|
    +----------+----+
    |Arab World| ARB|
    |Arab World| ARB|
    +----------+----+
    only showing top 2 rows
    


  • Another option is the df.withColumnRenamed method, which renames an existing column to a name of your choice
    # Rename Column
    # Method 2
    df.select('Year').distinct().groupBy().count().withColumnRenamed('count', 'year_count').show()
    
    +----------+
    |year_count|
    +----------+
    |        57|
    +----------+
    


  • A handy way to rename multiple columns at once is toDF
    # Rename Multiple Columns
    # Method 3
    new_col_names = ['country','code','yr','val']
    df.toDF(*new_col_names).columns
    
    ['country', 'code', 'yr', 'val']
    



    DataType Casting

    A simple cast method can be used to explicitly cast a column from one datatype to another in a dataframe. The example below shows how to convert the Value column from string to bigint.

    df.select(df['Value'].cast('bigint')).printSchema()
    
    root
     |-- Value: long (nullable = true)
    
    Wait! But what if you’d like to cast multiple columns in one shot?

    There are several ways to achieve this. I would like to discuss two easy ways which aren’t very tedious. One way is to keep a list of the column names and their target datatypes and iterate over them, casting the columns in a single loop.

    Another, simpler way is to use Spark SQL to frame a query that casts the columns (a sketch of this is shown after the loop example below).

    The example below depicts a concise way to cast multiple columns using a single for loop, without having to repetitively use the cast function in the code.

    from pyspark.sql.functions import col
    cols = df.columns #> ['Country Name', 'Country Code', 'Year', 'Value']
    datatypes = ['string', 'string', 'bigint', 'bigint']
    for i in range(len(cols)):
        df = df.withColumn(cols[i], col(cols[i]).cast(datatypes[i]))
    df = df.select(*cols)
    df.printSchema()
    
    root
     |-- Country Name: string (nullable = true)
     |-- Country Code: string (nullable = true)
     |-- Year: long (nullable = true)
     |-- Value: long (nullable = true)
    
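    For the Spark SQL route mentioned above, the casts can also be written as SQL expressions with selectExpr; a small sketch, assuming the columns are still the original string types (backticks handle the spaces in the column names):

    df_casted = df.selectExpr('`Country Name`', '`Country Code`',
                              'cast(Year as bigint) as Year',
                              'cast(Value as bigint) as Value')
    df_casted.printSchema()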



    Spark Cache DataFrame

    Persisting is useful when a dataframe is referenced multiple times. The data is cached (in memory, spilling to disk if needed with the default storage level), so subsequent retrievals have lower latency.

    df.persist()
    
    DataFrame[Country Name: string, Country Code: string, Year: string, Value: string]
    


    Whether a dataframe is present in the cache can be verified using the df.storageLevel property.

    True values for the disk and memory flags indicate the dataframe is cached.

    df.storageLevel
    
    StorageLevel(True, True, False, False, 1)
    

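    The is_cached attribute offers a quicker boolean check; a one-line sketch:

    df.is_cached   # True once the dataframe has been persisted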


    Unpersist Dataframe

    The df.unpersist() method can be used to unpersist a pyspark dataframe, removing it from the cache.

    df.unpersist()
    
    DataFrame[Country Name: string, Country Code: string, Year: string, Value: string]
    


    StorageLevel after uncaching the dataframe

    df.storageLevel
    
    StorageLevel(False, False, False, False, 1)
    



    Replace Nulls in Spark

    The fillna method replaces nulls in a dataframe with a user-defined value. Note that a single replacement value is only applied to columns of a matching type (a string value fills only string columns, a number fills only numeric columns).

    df3 = df.fillna('-99')
    

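    To target specific columns, a dict of column name to replacement value can be passed instead; a small sketch (assuming Value has been cast to a numeric type):

    # Replace nulls per column: -99 for the numeric Value column, 'unknown' for Country Name
    df_filled = df.fillna({'Value': -99, 'Country Name': 'unknown'})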


    Partition Data in Spark

    The repartition method can be used to partition the data by one or more columns or into a specified number of partitions. Repartitioning performs a full data shuffle, creating roughly equal-sized chunks of data among the partitions, and the resulting dataframe is hash partitioned.

    Repartition can be done in 2 ways,

  • Passing an int value to repartition method can help create partitions based on the integer argument
  • Passing a column name, would create the partitions based on the distinct column values
  • Caution: repartition performs a full shuffle on the data. Providing an unsuitable value might result in very large files getting created or may sometimes result in an out-of-memory error

    In our example, we will partition the data by country name and compute the total number of partitions. There are 263 country names in the dataset, yet only 200 partitions are created (and hence at most 200 files if this dataframe is saved). This is because repartition defaults to the value of spark.sql.shuffle.partitions when an integer value is not explicitly provided.

    df3 = df.repartition('Country Name')
    df3.rdd.getNumPartitions()
    
    200
    


    As shown below, the default value of this property is 200. It can be changed using the spark.conf.set method.

    spark.conf.get("spark.sql.shuffle.partitions")
    
    '200'
    
    spark.conf.set("spark.sql.shuffle.partitions", "300")
    df3 = df.repartition('Country Name')
    df3.rdd.getNumPartitions()
    
    300
    


    Partition Data according to an integer value

    df4 = df.repartition(50)
    df4.rdd.getNumPartitions()
    
    50
    

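    To merely reduce the number of partitions, coalesce() is usually cheaper than repartition() since it avoids a full shuffle; a brief sketch:

    # Shrink to 10 partitions without a full shuffle
    df5 = df4.coalesce(10)
    df5.rdd.getNumPartitions()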


    Spark DataFrame Write

    A dataframe can be saved in multiple formats such as Parquet, ORC and even plain delimited text files. The example below illustrates how to write a pyspark dataframe to a CSV file.

    df.write.format('csv').option('delimiter','|').save('Path-to_file')
    

    A Dataframe can be saved in multiple modes, such as,

  • append - appends to existing data in the path
  • overwrite - Overwrites existing data with the dataframe being saved
  • ignore - Does nothing if data exists
  • error (or) errorifexists - raises an exception if data is already present (Default)
    The example below illustrates how the above save can be performed in overwrite mode

    df.write.format('csv').option('delimiter','|').mode('overwrite').save('Path-to_file')
    

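    Output can also be partitioned into subdirectories by column values with partitionBy; a sketch that writes one folder per year (the path is a placeholder):

    # Writes files under .../Year=1960/, .../Year=1961/, and so on
    df.write.partitionBy('Year').format('csv').option('delimiter','|').mode('overwrite').save('Path-to_file')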


    Create Temporary View in Spark

    The dataframe can be saved as a temporary view, which remains available as long as that spark session is active.

    # Save Dataframe as Temp View
    df.createOrReplaceTempView('population')
    # Above view can be used to perform Spark SQL queries 
    



    Spark SQL

  • Display Data using Spark SQL
    spark.sql("select * from population limit 5").show()
    
    +------------+------------+----+---------+
    |Country Name|Country Code|Year|    Value|
    +------------+------------+----+---------+
    |  Arab World|         ARB|1960| 92490932|
    |  Arab World|         ARB|1961| 95044497|
    |  Arab World|         ARB|1962| 97682294|
    |  Arab World|         ARB|1963|100411076|
    |  Arab World|         ARB|1964|103239902|
    +------------+------------+----+---------+
    


  • Get Max Year from the dataset
    spark.sql("select max(year) as max_year from population").show()
    
    +--------+
    |max_year|
    +--------+
    |    2016|
    +--------+
    


  • Display the population of Japan & India for the years between 1990 and 1995
    spark.sql("""select `Country Name`, year,  value as population from population where `Country Name` in ('India','Japan')
              and cast(year as bigint) between 1990 and 1995""").show()
    
    +------------+----+----------+
    |Country Name|year|population|
    +------------+----+----------+
    |       India|1990| 870133480|
    |       India|1991| 888054875|
    |       India|1992| 906021106|
    |       India|1993| 924057817|
    |       India|1994| 942204249|
    |       India|1995| 960482795|
    |       Japan|1990| 123537000|
    |       Japan|1991| 123921000|
    |       Japan|1992| 124229000|
    |       Japan|1993| 124536000|
    |       Japan|1994| 124961000|
    |       Japan|1995| 125439000|
    +------------+----+----------+
    
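    Aggregations work the same way through Spark SQL; for example, counting the number of year entries per country using the same temp view (a small sketch, not from the original post):

    spark.sql("""select `Country Name`, count(*) as num_years
                 from population
                 group by `Country Name`""").show(5)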

    By now, you should be familiar with performing basic operations on a Spark dataframe. I strongly recommend that you pick a random dataset and practice the above operations to get the hang of them.

    Feedback and comments are welcome and can be posted in the comment section below. Hope this post was helpful. Cheers!
