Pyspark DataFrame Operations - Basics

November 20, 2018

In this post, we will discuss how to perform different dataframe operations such as aggregations, ordering, joins and other similar data manipulations on a Spark dataframe.

Introduction

Spark provides the Dataframe API, a very powerful API that enables the user to perform parallel and distributed structured data processing on the input data. A Spark dataframe is a dataset organized into a named set of columns.

By the end of this post, you should be familiar with performing the most frequent data manipulations on a Spark dataframe.

Table of Contents

  • What is a Spark Dataframe?
  • Dataframe Features
  • DataFrame Operations
    • Create a DataFrame
    • DataFrame Schema
    • Count of a DataFrame
    • Display DataFrame Data
    • Remove Duplicate rows from a DataFrame
    • Distinct Column Values
    • Filtering Data
    • Sorting/Ordering Data
    • Grouping & Performing Aggregations
    • Join DataFrames
    • Limit data from a dataframe
    • Union 2 Dataframes
    • How to change Dataframe columns
    • DataType Casting
    • Cache a DataFrame
    • Unpersist Dataframe
    • Replace Nulls
    • Partition Data
    • DataFrame Write
    • Create Temp View
    • Spark SQL

What is a Spark Dataframe?

A Dataframe is a distributed collection of data organized into a named set of columns. It is similar to a table in a relational database and has a similar look and feel. A dataframe can be derived from datasets such as delimited text files, Parquet and ORC files, CSVs, RDBMS tables, Hive tables, RDDs and more. In addition, a dataframe can also be constructed from semi-structured formats such as JSON and XML. The dataframe API is a very powerful one, bundled with extensive features and rich optimizations.

Dataframe Features

  • Unified Data Access
  • Ability to handle structured and semi-structured data
  • Supports a wide variety of Data Sources
  • Profuse Features for Data Manipulations and Aggregations
  • Supports multiple languages such as Python, Java, R & Scala

DataFrame Operations

Some of the basic and frequently used dataframe operations are discussed below.

Before we start, let’s create our SparkSession and SparkContext. (Note: these objects are created automatically if you’re accessing Spark via the spark shell.)

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('Spark Training').getOrCreate()
sc = spark.sparkContext

Create DataFrame

For this example, a country-wise population by year dataset is chosen. The dataset can be downloaded here: population_dataset

df = spark.read.format('csv').options(delimiter=',', header=True).load('/Path-to-file/population.csv')
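
Note that, since no schema is supplied, every column is read as a string (as the printSchema output further below shows). If you would rather let Spark guess the column types up front, the inferSchema option can be enabled; a small sketch (this triggers an extra pass over the file, so it is slower on large datasets, and df_inferred is just an illustrative name):

# Optional: let Spark infer column types while reading
df_inferred = spark.read.format('csv').options(delimiter=',', header=True, inferSchema=True).load('/Path-to-file/population.csv')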


Create a DataFrame from an RDD. The toDF() method can be used to convert an RDD to a dataframe

rdd = sc.parallelize([(1,2),(3,4),(5,6)])
rdf = rdd.toDF()
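
If column names are not supplied, toDF() assigns generic names such as _1 and _2. A list of names can be passed instead; a small sketch (the names used here are purely illustrative):

rdf = rdd.toDF(['col_a','col_b'])   # explicit, illustrative column names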


Create a DataFrame from a list of values. The schema is inferred dynamically if not specified

tdf = spark.createDataFrame([('Alice',24),('David',43)],['name','age'])
tdf.printSchema()
root
 |-- name: string (nullable = true)
 |-- age: long (nullable = true)



DataFrame Schema

The printSchema() method can be used to display the schema of the dataframe

df.printSchema()
root
 |-- Country Name: string (nullable = true)
 |-- Country Code: string (nullable = true)
 |-- Year: string (nullable = true)
 |-- Value: string (nullable = true)


Obtain the raw schema of a dataframe

df.schema
StructType(List(StructField(Country Name,StringType,true),StructField(Country Code,StringType,true),StructField(Year,StringType,true),StructField(Value,StringType,true)))


List the columns of the dataframe. A list of the column names is returned

df.columns
['Country Name', 'Country Code', 'Year', 'Value']



Count of a DataFrame

df.count()
14885



Display DataFrame Data

Display 5 rows. truncate=False can be set to display the entire column contents on your terminal

df.show(5, truncate=False)
+------------+------------+----+---------+
|Country Name|Country Code|Year|Value    |
+------------+------------+----+---------+
|Arab World  |ARB         |1960|92490932 |
|Arab World  |ARB         |1961|95044497 |
|Arab World  |ARB         |1962|97682294 |
|Arab World  |ARB         |1963|100411076|
|Arab World  |ARB         |1964|103239902|
+------------+------------+----+---------+
only showing top 5 rows



Remove Duplicate rows from a DataFrame

df.dropDuplicates()
DataFrame[Country Name: string, Country Code: string, Year: string, Value: string]
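
dropDuplicates() returns a new dataframe and can also take an optional list of columns; only those columns are then considered when identifying duplicates. A small sketch keeping a single row per country code (which of the duplicate rows is kept is not guaranteed):

df.dropDuplicates(['Country Code']).count()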



Distinct Column Values

Select distinct country code from the dataset

df.select('Country Code').distinct().show()
+------------+
|Country Code|
+------------+
|         HTI|
|         PSE|
|         LTE|
|         BRB|
|         LVA|
|         POL|
|         ECS|
|         TEA|
|         JAM|
|         ZMB|
|         MIC|
|         BRA|
|         ARM|
|         IDA|
|         MOZ|
|         CUB|
|         JOR|
|         OSS|
|         ABW|
|         FRA|
+------------+
only showing top 20 rows



Filtering Data

Display the population of India for the years 2015 and 2016. Multiple filter conditions can be combined using & (and) and | (or).

Note: Remember to wrap each condition in parentheses when ‘&’ or ‘|’ is used.

from pyspark.sql.functions import col
df.filter((col('Country Name') == 'India') & (col('Year').isin('2015','2016'))).show()
+------------+------------+----+----------+
|Country Name|Country Code|Year|     Value|
+------------+------------+----+----------+
|       India|         IND|2015|1309053980|
|       India|         IND|2016|1324171354|
+------------+------------+----+----------+
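
Filters can also use column methods such as like, startswith and isin, and a condition can be negated with ~. A small sketch listing the distinct country names beginning with 'United':

df.filter(col('Country Name').like('United%')).select('Country Name').distinct().show(truncate=False)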



Sorting/Ordering Data

Below example illustrates how the distinct country names can be displayed in descending order

df.select('Country Name').distinct().orderBy('Country Name', ascending=False).show(truncate=False)
+------------------------+
|Country Name            |
+------------------------+
|Zimbabwe                |
|Zambia                  |
|Yemen, Rep.             |
|World                   |
|West Bank and Gaza      |
|Virgin Islands (U.S.)   |
|Vietnam                 |
|Venezuela, RB           |
|Vanuatu                 |
|Uzbekistan              |
|Uruguay                 |
|Upper middle income     |
|United States           |
|United Kingdom          |
|United Arab Emirates    |
|Ukraine                 |
|Uganda                  |
|Tuvalu                  |
|Turks and Caicos Islands|
|Turkmenistan            |
+------------------------+
only showing top 20 rows
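
orderBy also accepts several columns, and the sort direction can be set per column with the asc() and desc() column methods. A small sketch sorting by country name ascending and, within each country, by year descending:

df.orderBy(col('Country Name').asc(), col('Year').desc()).show(5)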



Grouping & Performing Aggregations

  • Obtain the total count of distinct years present in the entire dataset

    # Count of Years
    df.select('Year').distinct().groupBy().count().show()
    
    +-----+
    |count|
    +-----+
    |   57|
    +-----+
    



  • SUM – Compute the sum of all population values for the year 1990 (note that the dataset also contains regional and income-group aggregates such as 'World' and 'Arab World', so this total is far larger than the actual world population)

    df.filter(col('Year') == '1990').agg({'Value':'sum'}).show(truncate=False)
    
    +---------------+
    |sum(Value)     |
    +---------------+
    |5.4935613753E10|
    +---------------+
    


  • AVG – Compute the average population in India for the year 2005

    df.filter((col('Year') == '2005') & (col('Country Name') == 'India')).agg({'Value':'avg'}).show(truncate=False)
    
    +-------------+
    |avg(Value)   |
    +-------------+
    |1.144118674E9|
    +-------------+
    


  • MIN – Display the smallest population value for the year 2007

    df.filter(df.Year == '2007').agg({'Value':'min'}).show()
    
    +----------+
    |min(Value)|
    +----------+
    |     10075|
    +----------+
    


  • MAX – Display the row with the largest population value for the year 2016

    df.filter(df.Value == df.filter(df.Year == '2016').agg({'Value':'max'}).collect()[0][0]).show()
    
    +------------+------------+----+----------+
    |Country Name|Country Code|Year|     Value|
    +------------+------------+----+----------+
    |       World|         WLD|2016|7442135578|
    +------------+------------+----+----------+
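
  • GROUP BY – The examples above aggregate over the whole dataframe; groupBy can also take one or more grouping columns. A small sketch computing the largest population value recorded for each country (only the first 5 rows are shown):

    df.groupBy('Country Name').agg({'Value':'max'}).show(5)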
    



Join DataFrames

A dataframe can be joined with another using the .join method, which takes three arguments: join(other, on=None, how=None)

  • other - the dataframe to be joined with
  • on - the join condition
  • how - the type of join; an inner join is performed by default if not specified. Other join types which can be specified are inner, cross, outer, full, full_outer, left, left_outer, right, right_outer, left_semi and left_anti

Below is an example illustrating an inner join. Let's construct two dataframes: one with only the distinct values of country name and country code, and the other with country code (aliased as ctry_cd), value and year. Country code will be the join condition here.

    df1 = df.select('Country Name', 'Country Code').distinct()
    df1.count()
    
    263
    
    df2 = df.select(col('Country Code').alias('ctry_cd'), 'Value', 'Year').distinct()
    df2.count()
    
    14885
    


Now let’s join both dataframes on the country code columns and display the data

    from pyspark.sql.functions import col
    df1.join(df2, col('Country Code') == col('ctry_cd')).show(5)
    
    +--------------------+------------+-------+----------+----+
    |        Country Name|Country Code|ctry_cd|     Value|Year|
    +--------------------+------------+-------+----------+----+
    |East Asia & Pacif...|         EAP|    EAP|1878255588|2004|
    |Europe & Central ...|         ECA|    ECA| 396886165|2001|
    |           IDA blend|         IDB|    IDB| 135810058|1964|
    |           IDA blend|         IDB|    IDB| 403526930|2005|
    |            IDA only|         IDX|    IDX| 984961696|2013|
    +--------------------+------------+-------+----------+----+
    only showing top 5 rows
    


The duplicated country code column (ctry_cd) is redundant here, so it can be dropped using the drop method while displaying the data

    df1.join(df2, col('Country Code') == col('ctry_cd')).drop(col('ctry_cd')).show(5,False)
    
    +---------------------------------------------+------------+----------+----+
    |Country Name                                 |Country Code|Value     |Year|
    +---------------------------------------------+------------+----------+----+
    |East Asia & Pacific (excluding high income)  |EAP         |1878255588|2004|
    |Europe & Central Asia (excluding high income)|ECA         |396886165 |2001|
    |IDA blend                                    |IDB         |135810058 |1964|
    |IDA blend                                    |IDB         |403526930 |2005|
    |IDA only                                     |IDX         |984961696 |2013|
    +---------------------------------------------+------------+----------+----+
    only showing top 5 rows
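
When the join key has the same name in both dataframes, the column name (or a list of names) can be passed directly to the on argument; Spark then keeps a single copy of the key column and no drop is needed. A small sketch, assuming df2 had been built without the ctry_cd alias (df2b is just an illustrative name):

    df2b = df.select('Country Code', 'Value', 'Year').distinct()
    df1.join(df2b, on='Country Code', how='inner').show(5)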
    



Limit data from a dataframe

    df2 = df.limit(10)
    df2.count()
    
    10
    



Union 2 Dataframes

Below example combines two dataframes holding the first and last ten rows of the dataset (ordered by country code) respectively

    # Combine 2 Dataframes
    df1 = df.orderBy('Country Code').limit(10)
    df2 = df.orderBy('Country Code', ascending=False).limit(10)
    df1.union(df2).show()
    
    +------------+------------+----+-------+
    |Country Name|Country Code|Year|  Value|
    +------------+------------+----+-------+
    |       Aruba|         ABW|1960|  54211|
    |       Aruba|         ABW|1969|  58726|
    |       Aruba|         ABW|1961|  55438|
    |       Aruba|         ABW|1962|  56225|
    |       Aruba|         ABW|1963|  56695|
    |       Aruba|         ABW|1964|  57032|
    |       Aruba|         ABW|1965|  57360|
    |       Aruba|         ABW|1966|  57715|
    |       Aruba|         ABW|1967|  58055|
    |       Aruba|         ABW|1968|  58386|
    |    Zimbabwe|         ZWE|1960|3747369|
    |    Zimbabwe|         ZWE|1969|5009514|
    |    Zimbabwe|         ZWE|1961|3870756|
    |    Zimbabwe|         ZWE|1962|3999419|
    |    Zimbabwe|         ZWE|1963|4132756|
    |    Zimbabwe|         ZWE|1964|4269863|
    |    Zimbabwe|         ZWE|1965|4410212|
    |    Zimbabwe|         ZWE|1966|4553433|
    |    Zimbabwe|         ZWE|1967|4700041|
    |    Zimbabwe|         ZWE|1968|4851431|
    +------------+------------+----+-------+
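
Note that union matches columns by position and, unlike SQL UNION, does not remove duplicate rows. When both dataframes share the same column names, unionByName (available from Spark 2.3) matches columns by name instead, and duplicates can be dropped afterwards with distinct(). A small sketch:

    df1.unionByName(df2).distinct().count()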
    



How to change Dataframe columns

  • Below method shows how the alias function can be used to rename a column while selecting it

    # Rename Column
    # Method 1
    df.select(col('Country Name').alias('country'), col('Country Code').alias('code')).show(2)
    
    +----------+----+
    |   country|code|
    +----------+----+
    |Arab World| ARB|
    |Arab World| ARB|
    +----------+----+
    only showing top 2 rows
    


  • Another method is withColumnRenamed, which renames an existing column to the name you require

    # Rename Column
    # Method 2
    df.select('Year').distinct().groupBy().count().withColumnRenamed('count', 'year_count').show()
    
    +----------+
    |year_count|
    +----------+
    |        57|
    +----------+
    


  • A handy way to rename multiple columns at once is the toDF method

    # Rename Multiple Columns
    # Method 3
    new_col_names = ['country','code','yr','val']
    df.toDF(*new_col_names).columns
    
    ['country', 'code', 'yr', 'val']
    



DataType Casting

A simple cast method can be used to explicitly cast a column from one datatype to another. Below example shows how to convert the Value column from string to bigint.

    df.select(df['Value'].cast('bigint')).printSchema()
    
    root
     |-- Value: long (nullable = true)
    
Wait! But what if you’d like to cast multiple columns in one shot?

There are several ways to achieve this. I would like to discuss two easy ways which aren’t very tedious. One way is to use a list of the column names and their target datatypes and iterate over them, casting the columns in a single loop. Another simple way is to use Spark SQL to frame a SQL query that casts the columns (see the sketch after the loop example below).

Below example depicts a concise way to cast multiple columns using a single for loop, without having to repetitively call the cast function in the code.

    from pyspark.sql.functions import col
    cols = df.columns #> ['Country Name', 'Country Code', 'Year', 'Value']
    datatypes = ['string', 'string', 'bigint', 'bigint']
    for i in range(len(cols)):
        df = df.withColumn(cols[i], col(cols[i]).cast(datatypes[i]))
    df = df.select(*cols)
    df.printSchema()
    
    root
     |-- Country Name: string (nullable = true)
     |-- Country Code: string (nullable = true)
     |-- Year: long (nullable = true)
     |-- Value: long (nullable = true)
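
The Spark SQL route mentioned earlier can be sketched with selectExpr, which accepts SQL expressions (column names containing spaces are wrapped in backticks). This assumes the original string-typed dataframe, and df_casted is just an illustrative name:

    df_casted = df.selectExpr('`Country Name`', '`Country Code`',
                              'cast(Year as bigint) as Year',
                              'cast(Value as bigint) as Value')
    df_casted.printSchema()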
    



Cache a DataFrame

Caching is useful when a dataframe is referenced multiple times. The dataframe is kept in memory, so the data retrieval latency on subsequent uses is lower

    df.persist()
    
    DataFrame[Country Name: string, Country Code: string, Year: string, Value: string]
    


Whether a dataframe is cached or not can be verified using the storageLevel property.

True values for the useDisk and useMemory flags indicate that the dataframe is cached.

    df.storageLevel
    
    StorageLevel(True, True, False, False, 1)
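
persist() uses the MEMORY_AND_DISK level by default for dataframes (which is what the storageLevel output above reflects). A different storage level can be requested explicitly; a small sketch using the tdf dataframe created earlier, since the level should be chosen before a dataframe is first persisted:

    from pyspark import StorageLevel
    tdf.persist(StorageLevel.DISK_ONLY)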
    



Unpersist Dataframe

    df.unpersist()
    
    DataFrame[Country Name: string, Country Code: string, Year: string, Value: string]
    


StorageLevel after unpersisting the dataframe

    df.storageLevel
    
    StorageLevel(False, False, False, False, 1)
    



Replace Nulls

Replace null values with a user-defined value

    df3 = df.fillna('-99')
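
fillna only replaces nulls in columns whose data type matches the replacement value, so the string '-99' above affects only the string columns. A dictionary can be passed instead to supply a different replacement value per column; a small sketch (the replacement values are just illustrative):

    df3 = df.fillna({'Country Name': 'Unknown', 'Value': 0})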
    



Partition Data

The repartition method can be used to partition the data by column(s) or by a specified number. The repartition algorithm performs a full data shuffle, creating roughly equal chunks of data among the partitions. The resulting dataframe is hash partitioned.

Repartitioning can be done in two ways:

  • Passing an int value to the repartition method creates that many partitions
  • Passing a column name creates the partitions based on the distinct column values
  • Caution: repartition performs a full shuffle on the data. Providing an incorrect input might result in large files getting created or may sometimes result in an out-of-memory error

In our example, we will partition the data by country name and compute the total number of partitions. There are 263 country names in the dataset, but only 200 partitions would be created (and only 200 files written if this dataframe were saved). This is because repartition defaults to the value of spark.sql.shuffle.partitions when an integer value is not explicitly provided.

    df3 = df.repartition('Country Name')
    df3.rdd.getNumPartitions()
    
    200
    


As shown below, the default value of this property is 200. It can be changed using the conf.set method

    spark.conf.get("spark.sql.shuffle.partitions")
    
    '200'
    
    spark.conf.set("spark.sql.shuffle.partitions", "300")
    df3 = df.repartition('Country Name')
    df3.rdd.getNumPartitions()
    
    300
    


    Partition Data according to an integer value

    df4 = df.repartition(50)
    df4.rdd.getNumPartitions()
    
    50
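
When the goal is only to reduce the number of partitions (for instance before writing out fewer files), coalesce avoids the full shuffle that repartition performs. A small sketch shrinking the 50 partitions created above (df5 is just an illustrative name):

    df5 = df4.coalesce(10)
    df5.rdd.getNumPartitions()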
    



DataFrame Write

A dataframe can be saved in multiple formats such as Parquet, ORC and plain delimited text files. Below example illustrates how the dataframe can be saved as a pipe-delimited CSV file

    df.write.format('csv').option('delimiter','|').save('Path-to_file')
    

A dataframe can be saved using one of several modes:

  • append - appends to the existing data in the path
  • overwrite - overwrites the existing data with the dataframe being saved
  • ignore - does nothing if data already exists
  • error (or) errorifexists - raises an exception if data is already present (default)

Below example illustrates how the above save can be performed with overwrite mode

    df.write.format('csv').option('delimiter','|').mode('overwrite').save('Path-to_file')
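
The writer also supports columnar formats and directory partitioning. A small sketch writing the dataframe as Parquet files partitioned by Year, which ties back to the Partition Data section above (the path is a placeholder):

    df.write.mode('overwrite').partitionBy('Year').parquet('Path-to_file')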
    



Create Temp View

A dataframe can be registered as a temporary view, which remains available as long as the Spark session is active

    # Save Dataframe as Temp View
    df.createOrReplaceTempView('population')
    # Above view can be used to perform Spark SQL queries 
    



Spark SQL

  • Display data using Spark SQL

    spark.sql("select * from population limit 5").show()
    
    +------------+------------+----+---------+
    |Country Name|Country Code|Year|    Value|
    +------------+------------+----+---------+
    |  Arab World|         ARB|1960| 92490932|
    |  Arab World|         ARB|1961| 95044497|
    |  Arab World|         ARB|1962| 97682294|
    |  Arab World|         ARB|1963|100411076|
    |  Arab World|         ARB|1964|103239902|
    +------------+------------+----+---------+
    


  • Get the maximum year from the dataset

    spark.sql("select max(year) as max_year from population").show()
    
    +--------+
    |max_year|
    +--------+
    |    2016|
    +--------+
    


  • Display the population of Japan and India for the years between 1990 and 1995

    spark.sql("""select `Country Name`, year,  value as population from population where `Country Name` in ('India','Japan')
              and cast(year as bigint) between 1990 and 1995""").show()
    
    +------------+----+----------+
    |Country Name|year|population|
    +------------+----+----------+
    |       India|1990| 870133480|
    |       India|1991| 888054875|
    |       India|1992| 906021106|
    |       India|1993| 924057817|
    |       India|1994| 942204249|
    |       India|1995| 960482795|
    |       Japan|1990| 123537000|
    |       Japan|1991| 123921000|
    |       Japan|1992| 124229000|
    |       Japan|1993| 124536000|
    |       Japan|1994| 124961000|
    |       Japan|1995| 125439000|
    +------------+----+----------+
    

By now, you should be familiar with performing basic operations on a Spark dataframe. I strongly recommend that you take a random dataset and practice the above operations to get the hang of them.

Feedback and comments are welcome and can be posted in the comment section below. Hope this post was helpful. Cheers!
