Problem statement:

Given a dataset of employee punch records (flag "I" = punch in, "O" = punch out), find the total time each employee spends inside the office.

from datetime import datetime

from pyspark.sql import Window
from pyspark.sql import functions as F
from pyspark.sql.functions import col, lead, unix_timestamp
from pyspark.sql.types import StructType, StructField, TimestampType, LongType, StringType

_TIME_FORMAT = "%H:%M:%S.%f"

# Raw sample punches: (emp_id, time of day, flag), flag "I" = in, "O" = out.
_raw_punches = [
    (11114, "08:30:00.000000", "I"),
    (11114, "10:30:00.000000", "O"),
    (11114, "11:30:00.000000", "I"),
    (11114, "15:30:00.000000", "O"),
    (11115, "09:30:00.000000", "I"),
    (11115, "17:30:00.000000", "O"),
]

# Only a time of day is parsed, so every datetime carries strptime's default
# date, 1900-01-01 (as seen in the df.show() output).
_data = [
    (emp_id, datetime.strptime(clock, _TIME_FORMAT), flag)
    for emp_id, clock, flag in _raw_punches
]

# Explicit schema: employee id, punch timestamp, and the in/out flag.
_schema = StructType(
    [
        StructField("emp_id", LongType(), True),
        StructField("punch_time", TimestampType(), True),
        StructField("flag", StringType(), True),
    ]
)

df = spark.createDataFrame(data=_data, schema=_schema)
df.show()
+------+-------------------+----+
|emp_id|         punch_time|flag|
+------+-------------------+----+
| 11114|1900-01-01 08:30:00|   I|
| 11114|1900-01-01 10:30:00|   O|
| 11114|1900-01-01 11:30:00|   I|
| 11114|1900-01-01 15:30:00|   O|
| 11115|1900-01-01 09:30:00|   I|
| 11115|1900-01-01 17:30:00|   O|
+------+-------------------+----+

Step 1 - pair each punch-in ("I") with the employee's next punch using a window and lead():

# One partition per employee, punches in chronological order, so lead()
# returns the same employee's next punch.
window_spec = Window.partitionBy("emp_id").orderBy("punch_time")

# lead() must be evaluated before filtering: it needs the "O" rows present
# to pick them up as the next punch of each "I" row.
# NOTE(review): assumes each employee's punches strictly alternate I, O, I, O.
# A trailing "I" gets a null Next_punch, and two consecutive "I" rows would be
# paired with each other.
with_next_punch = df.withColumn("Next_punch", lead("punch_time").over(window_spec))
df_lead = with_next_punch.filter(col("flag") == "I")

df_lead.show()
+------+-------------------+----+-------------------+
|emp_id|         punch_time|flag|         Next_punch|
+------+-------------------+----+-------------------+
| 11114|1900-01-01 08:30:00|   I|1900-01-01 10:30:00|
| 11114|1900-01-01 11:30:00|   I|1900-01-01 15:30:00|
| 11115|1900-01-01 09:30:00|   I|1900-01-01 17:30:00|
+------+-------------------+----+-------------------+

Step 2 - convert each in/out pair to hours with unix_timestamp():

# Hours inside the office for each in/out pair. unix_timestamp() turns each
# timestamp into epoch seconds, so the difference is in seconds and dividing by
# 3600 gives (fractional) hours as a double.
# The column is referenced by its exact name "Next_punch" so this does not
# depend on Spark's case-insensitive column resolution (spark.sql.caseSensitive).
df_diff = df_lead.withColumn(
    "ttdiff",
    (unix_timestamp("Next_punch") - unix_timestamp("punch_time")) / 3600,
)
df_diff.show()
+------+-------------------+----+-------------------+------+
|emp_id|         punch_time|flag|         Next_punch|ttdiff|
+------+-------------------+----+-------------------+------+
| 11114|1900-01-01 08:30:00|   I|1900-01-01 10:30:00|   2.0|
| 11114|1900-01-01 11:30:00|   I|1900-01-01 15:30:00|   4.0|
| 11115|1900-01-01 09:30:00|   I|1900-01-01 17:30:00|   8.0|
+------+-------------------+----+-------------------+------+

Alternatively, subtract the timestamps directly:

    # Direct-subtraction alternative. Subtracting two timestamp columns yields an
    # interval type in Spark 3.x rather than a number of seconds. Casting that
    # interval to long is version-dependent, and it would also truncate
    # fractional hours. Casting each timestamp to long first gives epoch
    # seconds, so the difference / 3600 is hours as a double, matching the
    # unix_timestamp() version.
    df_diff = df_lead.withColumn(
        "ttdiff",
        (col("Next_punch").cast("long") - col("punch_time").cast("long")) / 3600,
    )
    df_diff.show()
# Total hours inside the office per employee.
# F.sum is Spark's aggregate function. A bare `sum` is Python's builtin unless
# pyspark's sum has been star-imported over it, and the builtin cannot
# aggregate a column. groupBy/agg only read emp_id and ttdiff, so no prior
# select() is needed.
df_final = (
    df_diff
    .groupBy("emp_id")
    .agg(F.sum("ttdiff").alias("total_time"))
)
df_final.show()
+------+----------+
|emp_id|total_time|
+------+----------+
| 11114|       6.0|
| 11115|       8.0|
+------+----------+