This tutorial explains, with examples, various approaches to adding new columns to a PySpark DataFrame and to modifying existing ones. First, create a sample DataFrame to work with:
df = spark.read.csv("file:///path_to_files/csv_file_with_duplicates.csv", header=True)
df.show()
+-----+---------+-------+
|db_id| db_name|db_type|
+-----+---------+-------+
| 12| Teradata| RDBMS|
| 14|Snowflake|CloudDB|
| 15| Vertica| RDBMS|
| 12| Teradata| RDBMS|
| 22| Mysql| RDBMS|
+-----+---------+-------+
df.columns
Output: ['db_id', 'db_name', 'db_type']
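A quick aside: the file was read with header=True but without inferSchema=True, so all three columns are loaded as strings. This can be confirmed with printSchema(), and it explains the decimal bucket values in a later example:

df.printSchema()
Output:
root
 |-- db_id: string (nullable = true)
 |-- db_name: string (nullable = true)
 |-- db_type: string (nullable = true)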
withColumn(columnName, columnLogic/columnExpression)
This function takes two parameters: the first is the name of a new or existing column, and the second is the column logic / Column expression that computes its value. If the name matches an existing column, that column is replaced; otherwise a new column is added. For example, to overwrite the existing db_type column with a literal value:
from pyspark.sql.functions import col,lit
df_updated = df.withColumn("db_type",lit("Relation Database"))
df_updated.show()
+-----+---------+-----------------+
|db_id| db_name| db_type|
+-----+---------+-----------------+
| 12| Teradata|Relation Database|
| 14|Snowflake|Relation Database|
| 15| Vertica|Relation Database|
| 12| Teradata|Relation Database|
| 22| Mysql|Relation Database|
+-----+---------+-----------------+
The lit() function wraps a literal, i.e. a hardcoded value, in a Column expression; Spark will not accept a raw Python value where a Column is expected (see the sketch below).
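For illustration (this snippet is not among the original examples), passing the raw string without lit() fails; the exact error class and message vary by PySpark version:

try:
    df.withColumn("db_type", "Relation Database")  # missing lit()
except Exception as e:
    print(e)  # e.g. "col should be Column" on many PySpark versions

As noted above, a column name that does not already exist is added as a new column rather than replacing one: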
from pyspark.sql.functions import col,lit
df_updated = df.withColumn("db_type_cd",lit("Relation Database"))
df_updated.show()
+-----+---------+-------+-----------------+
|db_id| db_name|db_type| db_type_cd|
+-----+---------+-------+-----------------+
| 12| Teradata| RDBMS|Relation Database|
| 14|Snowflake|CloudDB|Relation Database|
| 15| Vertica| RDBMS|Relation Database|
| 12| Teradata| RDBMS|Relation Database|
| 22| Mysql| RDBMS|Relation Database|
+-----+---------+-------+-----------------+
Existing column values can also be transformed with built-in functions; for example, substring() derives a new column from the first four characters of db_name:
from pyspark.sql.functions import col,lit,substring
df_updated = df.withColumn("db_name_first_4_char",substring("db_name",1,4))
df_updated.show()
+-----+---------+-------+--------------------+
|db_id| db_name|db_type|db_name_first_4_char|
+-----+---------+-------+--------------------+
| 12| Teradata| RDBMS| Tera|
| 14|Snowflake|CloudDB| Snow|
| 15| Vertica| RDBMS| Vert|
| 12| Teradata| RDBMS| Tera|
| 22| Mysql| RDBMS| Mysq|
+-----+---------+-------+--------------------+
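Because withColumn() returns a new DataFrame, calls can be chained to add or replace several columns in one statement. A minimal sketch combining the expressions shown above (df_multi is an illustrative name):

from pyspark.sql.functions import lit, substring

df_multi = (df
    .withColumn("db_type_cd", lit("Relation Database"))               # new literal column
    .withColumn("db_name_first_4_char", substring("db_name", 1, 4)))  # first 4 chars of db_name
df_multi.show()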
The select() function can produce the same results: project the columns to keep and include the new or modified column expression alongside them. To modify db_type, alias a literal to the existing column name:

from pyspark.sql.functions import col,lit
df_updated = df.select("db_id", "db_name", lit("Relation Database").alias("db_type"))
df_updated.show()
+-----+---------+-----------------+
|db_id| db_name| db_type|
+-----+---------+-----------------+
| 12| Teradata|Relation Database|
| 14|Snowflake|Relation Database|
| 15| Vertica|Relation Database|
| 12| Teradata|Relation Database|
| 22| Mysql|Relation Database|
+-----+---------+-----------------+
To add a new column instead, keep all existing columns and alias the literal to a new name:

from pyspark.sql.functions import col,lit
df_updated = df.select("db_id", "db_name", "db_type", lit("Relation Database").alias("db_type_cd"))
df_updated.show()
+-----+---------+-------+-----------------+
|db_id| db_name|db_type| db_type_cd|
+-----+---------+-------+-----------------+
| 12| Teradata| RDBMS|Relation Database|
| 14|Snowflake|CloudDB|Relation Database|
| 15| Vertica| RDBMS|Relation Database|
| 12| Teradata| RDBMS|Relation Database|
| 22| Mysql| RDBMS|Relation Database|
+-----+---------+-------+-----------------+
# Same example as above, but building the column list programmatically
column_li = df.columns
column_li.append(lit("Relation Database").alias("db_type_cd"))
df_updated = df.select(column_li)
df_updated.show()
+-----+---------+-------+-----------------+
|db_id| db_name|db_type| db_type_cd|
+-----+---------+-------+-----------------+
| 12| Teradata| RDBMS|Relation Database|
| 14|Snowflake|CloudDB|Relation Database|
| 15| Vertica| RDBMS|Relation Database|
| 12| Teradata| RDBMS|Relation Database|
| 22| Mysql| RDBMS|Relation Database|
+-----+---------+-------+-----------------+
Unlike withColumn(), select() returns only the columns listed, so the existing columns must be included explicitly (here by starting from df.columns) to keep them in the result. The list passed to select() may freely mix column-name strings and Column objects.
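As an aside (not part of the original examples), selectExpr() can express the same projection in Spark SQL syntax:

df_updated = df.selectExpr("*", "'Relation Database' as db_type_cd")
df_updated.show()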
Derived columns work the same way with select(); here substring() extracts part of db_name:

from pyspark.sql.functions import col,lit,substring
df_updated = df.select("db_id", "db_name", "db_type", substring("db_name",1,4).alias("db_name_first_4_char"))
df_updated.show()
+-----+---------+-------+--------------------+
|db_id| db_name|db_type|db_name_first_4_char|
+-----+---------+-------+--------------------+
| 12| Teradata| RDBMS| Tera|
| 14|Snowflake|CloudDB| Snow|
| 15| Vertica| RDBMS| Vert|
| 12| Teradata| RDBMS| Tera|
| 22| Mysql| RDBMS| Mysq|
+-----+---------+-------+--------------------+
# Same example as above, but building the column list programmatically
column_li = df.columns
column_li.append(substring("db_name",1,4).alias("db_name_first_4_char"))
df_updated = df.select(column_li)
df_updated.show()
+-----+---------+-------+--------------------+
|db_id| db_name|db_type|db_name_first_4_char|
+-----+---------+-------+--------------------+
| 12| Teradata| RDBMS| Tera|
| 14|Snowflake|CloudDB| Snow|
| 15| Vertica| RDBMS| Vert|
| 12| Teradata| RDBMS| Tera|
| 22| Mysql| RDBMS| Mysq|
+-----+---------+-------+--------------------+
The substring(str, pos, len) function returns part of a string; positions are 1-based, so substring("db_name",1,4) returns the first four characters of the db_name column.
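The same expression can also be written in Spark SQL syntax with expr() (a variation, not shown in the original examples):

from pyspark.sql.functions import expr

df_updated = df.withColumn("db_name_first_4_char", expr("substring(db_name, 1, 4)"))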
Multiple new columns can be added in a single select(); here an arithmetic expression is projected alongside the substring:

from pyspark.sql.functions import col,lit,substring
df_updated = df.select("db_id", "db_name", "db_type", substring("db_name",1,4).alias("db_name_first_4_char"), (col("db_id")%10).alias("bucket") )
df_updated.show()
+-----+---------+-------+--------------------+------+
|db_id| db_name|db_type|db_name_first_4_char|bucket|
+-----+---------+-------+--------------------+------+
| 12| Teradata| RDBMS| Tera| 2.0|
| 14|Snowflake|CloudDB| Snow| 4.0|
| 15| Vertica| RDBMS| Vert| 5.0|
| 12| Teradata| RDBMS| Tera| 2.0|
| 22| Mysql| RDBMS| Mysq| 2.0|
+-----+---------+-------+--------------------+------+
# Same example as above, but building the column list programmatically
column_li = df.columns
column_li.append(substring("db_name",1,4).alias("db_name_first_4_char"))  # add 1st new column to the list
column_li.append((col("db_id")%10).alias("bucket"))  # add 2nd new column to the list
df_updated = df.select(column_li)
df_updated.show()
+-----+---------+-------+--------------------+------+
|db_id| db_name|db_type|db_name_first_4_char|bucket|
+-----+---------+-------+--------------------+------+
| 12| Teradata| RDBMS| Tera| 2.0|
| 14|Snowflake|CloudDB| Snow| 4.0|
| 15| Vertica| RDBMS| Vert| 5.0|
| 12| Teradata| RDBMS| Tera| 2.0|
| 22| Mysql| RDBMS| Mysq| 2.0|
+-----+---------+-------+--------------------+------+
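Notice that bucket shows 2.0 rather than 2: db_id was read from the CSV as a string, and Spark implicitly casts strings to double for arithmetic. Casting explicitly keeps the result integral (a small variation on the example above; df_int_bucket is an illustrative name):

from pyspark.sql.functions import col

df_int_bucket = df.withColumn("bucket", col("db_id").cast("int") % 10)
df_int_bucket.show()  # bucket column now shows 2, 4, 5, ... instead of 2.0, 4.0, 5.0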
Columns can also be brought in from a second DataFrame with a join. Here df_other is assumed to be another DataFrame sharing the db_id, db_name, and db_type columns; the attribute syntax (df.db_name vs df_other.db_name) disambiguates the duplicated names:

df_updated = (df.join(df_other, "db_id", "left")
    .select("db_id", df.db_name, df.db_type,
            df_other.db_name.alias("other_db_name"),
            df_other.db_type.alias("other_db_type")))
df_updated.show()
+-----+---------+-------+-------------+-------------+
|db_id| db_name|db_type|other_db_name|other_db_type|
+-----+---------+-------+-------------+-------------+
| 12| Teradata| RDBMS| null| null|
| 14|Snowflake|CloudDB| Snowflake| RDBMS|
| 15| Vertica| RDBMS| null| null|
| 12| Teradata| RDBMS| null| null|
| 22| Mysql| RDBMS| Mysql| RDBMS|
+-----+---------+-------+-------------+-------------+
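Finally, the when() function, paired with otherwise(), is the usual way to derive a column conditionally. A minimal sketch, not among the original examples; the is_cloud column name is illustrative:

from pyspark.sql.functions import col, lit, when

# Flag cloud databases; every non-matching row falls through to "No".
df_flagged = df.withColumn(
    "is_cloud",
    when(col("db_type") == "CloudDB", lit("Yes")).otherwise(lit("No")))
df_flagged.show()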