This tutorial explains various approaches, with examples, for modifying / updating existing column values in a DataFrame. Each topic listed below is explained with an example on this page; click an item in the list to jump to the respective section of the page:
# Load the sample CSV (with duplicate rows) into a DataFrame; header=True uses
# the first row as column names. All columns are read as strings by default.
df = spark.read.csv("file:///path_to_files/csv_file_with_duplicates.csv", header=True)
df.show()
+-----+---------+-------+
|db_id| db_name|db_type|
+-----+---------+-------+
| 12| Teradata| RDBMS|
| 14|Snowflake|CloudDB|
| 15| Vertica| RDBMS|
| 12| Teradata| RDBMS|
| 22| Mysql| RDBMS|
+-----+---------+-------+
# Load a second DataFrame with the same schema; used later in the
# join-based column-update example.
df_other = spark.read.csv("file:///path_to_files/join_example_file_2.csv", header=True)
df_other.show()
+-----+-----------+-------+
|db_id| db_name|db_type|
+-----+-----------+-------+
| 17| Oracle| RDBMS|
| 19| MongoDB| NOSQL|
| 21|SingleStore| RDBMS|
| 22| Mysql| RDBMS|
| 14| Snowflake| RDBMS|
+-----+-----------+-------+
withColumn(columnName, columnLogic/columnExpression)
This function takes 2 parameters: the 1st parameter is the name of a new or existing column, and the 2nd parameter is the column logic / column expression for the new or existing column named in the 1st parameter.
from pyspark.sql.functions import col,lit
# Overwrite every value in the existing db_type column with a constant;
# lit() wraps the hardcoded string as a column expression.
df_updated = df.withColumn("db_type",lit("Relation Database"))
df_updated.show()
+-----+---------+-----------------+
|db_id| db_name| db_type|
+-----+---------+-----------------+
| 12| Teradata|Relation Database|
| 14|Snowflake|Relation Database|
| 15| Vertica|Relation Database|
| 12| Teradata|Relation Database|
| 22| Mysql|Relation Database|
+-----+---------+-----------------+
lit() function is used to pass literals i.e. hardcoded (default) values, spark won't take hardcoded values directly.
from pyspark.sql.functions import col,lit
# Derive db_id from its own current value using the modulo operator.
# db_id was read as a string, so Spark implicitly casts it to double
# for the arithmetic — hence the 2.0 / 4.0 style results below.
df_updated = df.withColumn("db_id", col("db_id")%10)
df_updated.show()
+-----+---------+-------+
|db_id| db_name|db_type|
+-----+---------+-------+
| 2.0| Teradata| RDBMS|
| 4.0|Snowflake|CloudDB|
| 5.0| Vertica| RDBMS|
| 2.0| Teradata| RDBMS|
| 2.0| Mysql| RDBMS|
+-----+---------+-------+
col() function is used to refer to an existing column by name; arithmetic operators such as % can be applied directly to the column expression. Since db_id was read as a string, Spark implicitly casts it to double for the arithmetic, which is why the results appear as 2.0, 4.0, etc.
from pyspark.sql.functions import col,lit,substring
# Replace db_name with its first 4 characters; substring() positions
# are 1-based (start at position 1, take 4 characters).
df_updated = df.withColumn("db_name",substring("db_name",1,4))
df_updated.show()
+-----+-------+-------+
|db_id|db_name|db_type|
+-----+-------+-------+
| 12| Tera| RDBMS|
| 14| Snow|CloudDB|
| 15| Vert| RDBMS|
| 12| Tera| RDBMS|
| 22| Mysq| RDBMS|
+-----+-------+-------+
from pyspark.sql.functions import col,lit
# Same constant-value update done via select(): pass the unchanged columns
# through and replace db_type with a literal, aliased back to the same name.
df_updated = df.select("db_id", "db_name", lit("Relation Database").alias("db_type"))
df_updated.show()
+-----+---------+-----------------+
|db_id| db_name| db_type|
+-----+---------+-----------------+
| 12| Teradata|Relation Database|
| 14|Snowflake|Relation Database|
| 15| Vertica|Relation Database|
| 12| Teradata|Relation Database|
| 22| Mysql|Relation Database|
+-----+---------+-----------------+
from pyspark.sql.functions import col,lit
# Same modulo-derived update done via select(): the expression is aliased
# back to "db_id" so the output keeps the original column name.
df_updated = df.select((col("db_id")%10).alias("db_id"), "db_name", "db_type")
df_updated.show()
+-----+---------+-------+
|db_id| db_name|db_type|
+-----+---------+-------+
| 2.0| Teradata| RDBMS|
| 4.0|Snowflake|CloudDB|
| 5.0| Vertica| RDBMS|
| 2.0| Teradata| RDBMS|
| 2.0| Mysql| RDBMS|
+-----+---------+-------+
col() function is used to refer to an existing column; the modulo operator (%) is applied directly to the column expression to derive the new db_id values shown above.
from pyspark.sql.functions import col,lit,substring
# Same substring-based update done via select(): the truncated expression
# is aliased back to "db_name" to preserve the column name.
df_updated = df.select("db_id", substring("db_name",1,4).alias("db_name"), "db_type")
df_updated.show()
+-----+-------+-------+
|db_id|db_name|db_type|
+-----+-------+-------+
| 12| Tera| RDBMS|
| 14| Snow|CloudDB|
| 15| Vert| RDBMS|
| 12| Tera| RDBMS|
| 22| Mysq| RDBMS|
+-----+-------+-------+
substring() function is used to get a part of string from the db_name column.
from pyspark.sql.functions import col,lit,substring
# Multiple columns can be updated in a single select(): db_id via modulo
# and db_name via substring, each aliased back to its original name.
df_updated = df.select( (col("db_id")%10).alias("db_id"), substring("db_name",1,4).alias("db_name"), "db_type" )
df_updated.show()
+-----+-------+-------+
|db_id|db_name|db_type|
+-----+-------+-------+
| 2.0| Tera| RDBMS|
| 4.0| Snow|CloudDB|
| 5.0| Vert| RDBMS|
| 2.0| Tera| RDBMS|
| 2.0| Mysq| RDBMS|
+-----+-------+-------+
when(condition, value to return if condition is true)
otherwise(value to return if none of the conditions are met)
→ when() function takes 2 parameters, 1st parameter is the condition which will evaluate to True/False and 2nd parameter is the value to be returned if condition is evaluated to true.
from pyspark.sql.functions import col,lit,when
# Conditional update via select(): chained when() clauses map RDBMS -> "On Premise"
# and CloudDB -> "Cloud"; otherwise() supplies the fallback for any other value.
df_updated = df.select("db_id", "db_name", when( col("db_type")=="RDBMS", "On Premise").when( col("db_type")=="CloudDB","Cloud" ).otherwise( "Not Known" ).alias("db_type"))
df_updated.show()
+-----+---------+----------+
|db_id| db_name| db_type|
+-----+---------+----------+
| 12| Teradata|On Premise|
| 14|Snowflake| Cloud|
| 15| Vertica|On Premise|
| 12| Teradata|On Premise|
| 22| Mysql|On Premise|
+-----+---------+----------+
from pyspark.sql.functions import col,when
# Same conditional update via withColumn(); the alias() here is redundant
# since withColumn already names the result column "db_type".
df_updated = df.withColumn("db_type", when( col("db_type")=="RDBMS", "On Premise").when( col("db_type")=="CloudDB","Cloud" ).otherwise( "Not Known" ).alias("db_type"))
df_updated.show()
+-----+---------+----------+
|db_id| db_name| db_type|
+-----+---------+----------+
| 12| Teradata|On Premise|
| 14|Snowflake| Cloud|
| 15| Vertica|On Premise|
| 12| Teradata|On Premise|
| 22| Mysql|On Premise|
+-----+---------+----------+
from pyspark.sql.functions import col,when
# Update db_type from another DataFrame: left-join on db_id, then prefer
# df_other's db_type when a match exists (non-null); fall back to df's
# original db_type for rows with no match in df_other.
df_updated = df.join(df_other, "db_id", "left").select("db_id", df.db_name, when(df_other.db_type.isNull(), df.db_type).otherwise(df_other.db_type).alias("db_type"))
df_updated.show()
+-----+---------+-------+
|db_id| db_name|db_type|
+-----+---------+-------+
| 12| Teradata| RDBMS|
| 14|Snowflake| RDBMS|
| 15| Vertica| RDBMS|
| 12| Teradata| RDBMS|
| 22| Mysql| RDBMS|
+-----+---------+-------+
# Show the current schema — every column was read from CSV as a string.
df.printSchema()
root
|-- db_id: string (nullable = true)
|-- db_name: string (nullable = true)
|-- db_type: string (nullable = true)
# Change a column's data type: astype() (an alias of cast()) converts the
# string db_id column to integer, as confirmed by the schema below.
df_updated= df.select(col("db_id").astype("integer"), "db_name", "db_type")
df_updated.printSchema()
root
|-- db_id: integer (nullable = true)
|-- db_name: string (nullable = true)
|-- db_type: string (nullable = true)
# Casting an incompatible column: db_name holds strings like "Teradata"
# that cannot be converted to integer, so Spark returns null instead of
# raising an error — matching the null db_name values in the output below.
# (The original snippet mistakenly repeated the db_id cast from above.)
df_updated= df.select("db_id", col("db_name").astype("integer").alias("db_name"), "db_type")
df_updated.show()
+-----+-------+-------+
|db_id|db_name|db_type|
+-----+-------+-------+
| 12| null| RDBMS|
| 14| null|CloudDB|
| 15| null| RDBMS|
| 12| null| RDBMS|
| 22| null| RDBMS|
+-----+-------+-------+