"""Rank subreddits by number of comments and write the ranking to CSV.

Steps:
  1. From the submissions parquet, compute for each subreddit the mean of
     ``over_18`` cast to double (``prop_nsfw``).
  2. From the comments parquet, count comment ids per subreddit
     (``n_comments``), skipping subreddits whose name starts with ``u_``.
  3. Inner-join the two, keep subreddits with ``prop_nsfw < 0.5``, and add
     ``comments_rank`` (rank 1 = most comments).
  4. Collect to pandas, sort ascending by ``n_comments``, and write a CSV.
"""
from pyspark.sql import functions as f
from pyspark.sql import SparkSession
from pyspark.sql import Window

spark = SparkSession.builder.getOrCreate()
conf = spark.sparkContext.getConf()

# Share of each subreddit's submissions that are flagged over_18.
submissions = spark.read.parquet("/gscratch/comdata/output/reddit_submissions_by_subreddit.parquet")
prop_nsfw = (
    submissions.select(['subreddit', 'over_18'])
    .groupby('subreddit')
    .agg(f.mean(f.col('over_18').astype('double')).alias('prop_nsfw'))
)

df = spark.read.parquet("/gscratch/comdata/output/reddit_comments_by_subreddit.parquet")

# Drop subreddits named "u_..." (presumably user-profile pages -- confirm).
# Use a literal prefix test: in SQL LIKE the underscore is a single-character
# wildcard, so like("u_%") would also discard every subreddit that merely
# starts with the letter "u".
df = df.filter(~df.subreddit.startswith("u_"))

df = df.groupBy('subreddit').agg(f.count('id').alias("n_comments"))

# Inner join: subreddits absent from the submissions table are dropped here.
df = df.join(prop_nsfw, on='subreddit')
df = df.filter(df.prop_nsfw < 0.5)

# Global (unpartitioned) window; acceptable because there is one row per
# subreddit at this point.
win = Window.orderBy(f.col('n_comments').desc())
df = df.withColumn('comments_rank', f.rank().over(win))

# sort_values / to_csv are pandas APIs, so the Spark DataFrame has to be
# collected to the driver as a pandas DataFrame first.
df = df.toPandas()

df = df.sort_values("n_comments")
df.to_csv('/gscratch/comdata/output/reddit_similarity/subreddits_by_num_comments.csv', index=False)