from pyspark.sql import functions as f
from pyspark.sql import Window
from pyspark.sql import SparkSession
import math

spark = SparkSession.builder.getOrCreate()
 
df = spark.read.text("/gscratch/comdata/users/nathante/reddit_comment_ngrams_10p_sample/")
df = df.withColumnRenamed("value","phrase")
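# each row of the input is one n-gram ("phrase") drawn from the 10% sample of Reddit comments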
 
# count phrase occurrences
phrases = df.groupby('phrase').count()
phrases = phrases.withColumnRenamed('count','phraseCount')
phrases = phrases.filter(phrases.phraseCount > 10)
 
N = phrases.select(f.sum(phrases.phraseCount).alias("phraseCount")).collect()[0].phraseCount
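# N: total number of phrase occurrences remaining after the rare-phrase filter; it normalizes counts into probabilities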
 
print(f'analyzing PMI on a sample of {N} phrase occurrences')

logN = math.log(N)
phrases = phrases.withColumn("phraseLogProb", f.log(f.col("phraseCount")) - logN)
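# i.e. phraseLogProb = log(phraseCount / N), the log probability of observing the phrase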
 
# count term occurrences
phrases = phrases.withColumn('terms',f.split(f.col('phrase'),' '))
terms = phrases.select(['phrase','phraseCount','phraseLogProb',f.explode(phrases.terms).alias('term')])
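# after the explode there is one row per occurrence of a term within a phrase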
 
win = Window.partitionBy('term')
terms = terms.withColumn('termCount',f.sum('phraseCount').over(win))
terms = terms.withColumn('termLogProb',f.log(f.col('termCount')) - logN)
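# termCount sums phraseCount over every row sharing the term, so termLogProb = log(termCount / N)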
 
terms = terms.groupBy(terms.phrase, terms.phraseLogProb, terms.phraseCount).sum('termLogProb')
terms = terms.withColumnRenamed('sum(termLogProb)','termsLogProb')
terms = terms.withColumn("phrasePWMI", f.col('phraseLogProb') - f.col('termsLogProb'))
 
# keep the phrase-level columns and the PWMI score
df = terms.select(['phrase','phraseCount','phraseLogProb','phrasePWMI'])
 
# rank phrases from highest to lowest PWMI
df = df.sort(['phrasePWMI'],ascending=False)
df = df.sortWithinPartitions(['phrasePWMI'],ascending=False)
df.write.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet/",mode='overwrite',compression='snappy')
 
df = spark.read.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet/")
df.write.csv("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.csv/",mode='overwrite',compression='none')

df = spark.read.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet")
df = df.select('phrase','phraseCount','phraseLogProb','phrasePWMI')
 
# choosing phrases occurring more than 3500 times in the 10% sample (roughly 35000 times overall) and with a PWMI above 3 yields about 65000 expressions
df = df.filter(f.col('phraseCount') > 3500).filter(f.col("phrasePWMI")>3)
 
# to_feather/to_csv are pandas methods, so collect the filtered table to the driver first
df = df.toPandas()
df.to_feather("/gscratch/comdata/users/nathante/reddit_multiword_expressions.feather")
df.to_csv("/gscratch/comdata/users/nathante/reddit_multiword_expressions.csv")