Problem Statement:

Given purchase records of the form (name, item, weight), compute the total weight of each item per person, then collect each person's (item, total weight) pairs into a single array column.

(Problem statement image: Pasted image 20241228171033.png)

from pyspark.sql.types import *
from pyspark.sql.functions import *

# Sample purchase records: (name, item, weight)
data = [
    ("john", "tomato", 2),
    ("bill", "apple", 2),
    ("john", "banana", 2),
    ("john", "tomato", 3),
    ("bill", "taco", 2),
    ("bill", "apple", 2),
]
schema = "name string, item string, weight int"
df = spark.createDataFrame(data, schema)
df.show(truncate=False)

# Step 1: total weight per (name, item) pair
df_agg = df.groupBy("name", "item").agg(sum(col("weight")).alias("Ttl_wht"))
df_agg.show()
+----+------+-------+
|name|  item|Ttl_wht|
+----+------+-------+
|john|tomato|      5|
|bill| apple|      4|
|john|banana|      2|
|bill|  taco|      2|
+----+------+-------+
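
For reference, the same per-(name, item) aggregation can also be written in Spark SQL. A minimal sketch, assuming the existing spark session and registering the DataFrame under the arbitrary view name "purchases":

df.createOrReplaceTempView("purchases")

# Equivalent of the groupBy/agg above: total weight per (name, item).
spark.sql("""
    SELECT name, item, SUM(weight) AS Ttl_wht
    FROM purchases
    GROUP BY name, item
""").show()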

Step 2: use collect_list() with struct() to gather each person's (item, total weight) pairs into a single array column. Note that collect_list() does not guarantee the order of the collected elements.

df_final = df_agg.groupBy("name").agg(collect_list(struct("item", "Ttl_wht")).alias("Items"))
df_final.show(truncate=False)
+----+--------------------------+
|name|Items                     |
+----+--------------------------+
|john|[{tomato, 5}, {banana, 2}]|
|bill|[{apple, 4}, {taco, 2}]   |
+----+--------------------------+
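
If a per-person key-value lookup is more convenient than an array of structs, the collected pairs can be converted into a map column. A minimal sketch using map_from_entries() (available since Spark 2.4), applied to the df_final computed above; the column name "Items_map" is just an illustrative choice:

# Each (item, Ttl_wht) struct becomes one map entry: item -> total weight.
df_map = df_final.withColumn("Items_map", map_from_entries(col("Items")))
df_map.show(truncate=False)

For john this yields a map along the lines of {tomato -> 5, banana -> 2}.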