pivot() in PySpark:
The pivot() function reshapes a DataFrame by turning the unique values of one column into separate output columns. It is called on the result of groupBy() and followed by an aggregation on the grouped data.
Syntax:
GroupedData.pivot(pivot_col, values=None)
Note that pivot() is a method of GroupedData (the object returned by groupBy()), not of DataFrame itself.
Parameters:
pivot_col: The column whose unique values will become new columns.
values (optional): A list of values to pivot on. If not specified, Spark first computes all distinct values of pivot_col, which costs an extra pass over the data.
Example:
from pyspark.sql import SparkSession
from pyspark.sql.functions import sum as sum_  # avoid shadowing built-in sum

spark = SparkSession.builder.getOrCreate()

# Sample DataFrame
df = spark.createDataFrame([
    ('A', 'cat', 1),
    ('A', 'dog', 2),
    ('B', 'cat', 3),
    ('B', 'dog', 4)
], ['ID', 'Animal', 'Count'])

# Pivot the data: one output column per distinct Animal value
pivoted_df = df.groupBy("ID").pivot("Animal").agg(sum_("Count"))
pivoted_df.show()
+---+---+---+
| ID|cat|dog|
+---+---+---+
| A| 1| 2|
| B| 3| 4|
+---+---+---+
Explanation:
groupBy("ID"): Groups the data by the ID column.
pivot("Animal"): Pivots on the Animal column, turning values like "cat" and "dog" into separate columns.
agg(sum("Count")): Aggregates the Count column with sum() for each (ID, Animal) pair.