This tutorial explains, with examples, how to partition a dataframe randomly or based on specified column(s) of a dataframe.
df = spark.read.csv("file:///path_to_files/csv_with_duplicates_and_nulls.csv",header=True)
df.show()
+-----+---------+-------+
|db_id| db_name|db_type|
+-----+---------+-------+
| 12| Teradata| RDBMS|
| 14|Snowflake|CloudDB|
| 15| Vertica| RDBMS|
| 12| Teradata| RDBMS|
| 22| Mysql| null|
| 50|Snowflake| RDBMS|
| 51| null|CloudDB|
+-----+---------+-------+
df.rdd.getNumPartitions()
Output: 3
from pyspark.sql.functions import spark_partition_id
df_update = df.repartition(4)
df_update.rdd.getNumPartitions()
Output: 4
df_update.select("db_id", "db_name", spark_partition_id().alias("partition#") ).show()
+-----+---------+----------+
|db_id| db_name|partition#|
+-----+---------+----------+
| 51| null| 0|
| 14|Snowflake| 0|
| 12| Teradata| 1|
| 22| Mysql| 1|
| 12| Teradata| 2|
| 50|Snowflake| 3|
| 15| Vertica| 3|
+-----+---------+----------+
from pyspark.sql.functions import spark_partition_id
df_update = df.repartition(4)
df_update.rdd.getNumPartitions()
Output: 4
df_update = df_update.withColumn("partition#", spark_partition_id())
df_update.show()
+-----+---------+-------+----------+
|db_id| db_name|db_type|partition#|
+-----+---------+-------+----------+
| 22| Mysql| null| 0|
| 12| Teradata| RDBMS| 0|
| 51| null|CloudDB| 1|
| 15| Vertica| RDBMS| 1|
| 14|Snowflake|CloudDB| 2|
| 50|Snowflake| RDBMS| 3|
| 12| Teradata| RDBMS| 3|
+-----+---------+-------+----------+
from pyspark.sql.functions import spark_partition_id
df_update = df.repartition(4)
df_update.rdd.getNumPartitions()
Output: 4
df_update = df_update.select("db_name",spark_partition_id()).filter(spark_partition_id().isin(0,1))
df_update.show()
+-------+--------------------+
|db_name|SPARK_PARTITION_ID()|
+-------+--------------------+
|  Mysql|                   1|
|Vertica|                   0|
+-------+--------------------+
repartition(numPartitions, *cols)
df.rdd.getNumPartitions()
Output: 3
df_update = df.repartition(3)
df_update.rdd.getNumPartitions()
Output: 3
df_update = df.repartition("db_name")
df_update.select("db_name",spark_partition_id()).show()
+---------+--------------------+
| db_name|SPARK_PARTITION_ID()|
+---------+--------------------+
| null| 42|
| Mysql| 69|
| Vertica| 107|
|Snowflake| 176|
|Snowflake| 176|
| Teradata| 191|
| Teradata| 191|
+---------+--------------------+
df_update.rdd.getNumPartitions()
Output: 200
from pyspark.sql.functions import spark_partition_id
df_update = df.repartition("db_name", "db_id")
df_update.select("db_name","db_id",spark_partition_id()).show()
+---------+-----+--------------------+
| db_name|db_id|SPARK_PARTITION_ID()|
+---------+-----+--------------------+
| null| 51| 3|
| Teradata| 12| 51|
| Teradata| 12| 51|
|Snowflake| 50| 55|
| Vertica| 15| 77|
| Mysql| 22| 118|
|Snowflake| 14| 124|
+---------+-----+--------------------+
df_update.rdd.getNumPartitions()
Output: 200
from pyspark.sql.functions import spark_partition_id
df_update = df.repartition(2, "db_name")
df_update.select("db_name",spark_partition_id()).show()
+---------+--------------------+
| db_name|SPARK_PARTITION_ID()|
+---------+--------------------+
|Snowflake| 0|
|Snowflake| 0|
| null| 0|
| Teradata| 1|
| Vertica| 1|
| Teradata| 1|
| Mysql| 1|
+---------+--------------------+
df_update.rdd.getNumPartitions()
Output: 2
from pyspark.sql.functions import col, spark_partition_id
df_update = df.repartition(col("db_name").substr(1,1))
df_update.select("db_name",spark_partition_id()).show()
+---------+--------------------+
| db_name|SPARK_PARTITION_ID()|
+---------+--------------------+
| null| 42|
| Teradata| 44|
| Teradata| 44|
| Mysql| 68|
| Vertica| 69|
|Snowflake| 124|
|Snowflake| 124|
+---------+--------------------+
df_update.rdd.getNumPartitions()
Output: 200
from pyspark.sql.functions import col, spark_partition_id
df_update = df.repartition(3, col("db_id")%5)
df_update.select("db_id","db_name",spark_partition_id()).show()
+-----+---------+--------------------+
|db_id| db_name|SPARK_PARTITION_ID()|
+-----+---------+--------------------+
| 12| Teradata| 0|
| 12| Teradata| 0|
| 22| Mysql| 0|
| 15| Vertica| 1|
| 50|Snowflake| 1|
| 14|Snowflake| 2|
| 51| null| 2|
+-----+---------+--------------------+
df_update.rdd.getNumPartitions()
Output: 3