From 529b7f051133df6cccd3fbf360f8b196bf5b5996 Mon Sep 17 00:00:00 2001 From: Nate E TeBlunthuis Date: Sun, 9 Aug 2020 02:34:42 -0700 Subject: [PATCH] Bugfix --- top_comment_phrases.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 top_comment_phrases.py diff --git a/top_comment_phrases.py b/top_comment_phrases.py new file mode 100644 index 0000000..707ceaa --- /dev/null +++ b/top_comment_phrases.py @@ -0,0 +1,45 @@ +from pyspark.sql import functions as f +from pyspark.sql import Window +from pyspark.sql import SparkSession +import numpy as np + +spark = SparkSession.builder.getOrCreate() +df = spark.read.text("/gscratch/comdata/users/nathante/reddit_comment_ngrams_10p_sample/") + +df = df.withColumnRenamed("value","phrase") + + +# count overall +N = df.count() +print(f'analyzing PMI on a sample of {N} phrases') +logN = np.log(N) + +# count phrase occurrances +phrases = df.groupby('phrase').count() +phrases = phrases.withColumnRenamed('count','phraseCount') +phrases = phrases.withColumn("phraseLogProb", f.log(f.col("phraseCount")) - logN) + + +# count term occurrances +phrases = phrases.withColumn('terms',f.split(f.col('phrase'),' ')) +terms = phrases.select(['phrase','phraseCount','phraseLogProb',f.explode(phrases.terms).alias('term')]) + +win = Window.partitionBy('term') +terms = terms.withColumn('termCount',f.sum('phraseCount').over(win)) +terms = terms.withColumnRenamed('count','termCount') +terms = terms.withColumn('termLogProb',f.log(f.col('termCount')) - logN) + +terms = terms.groupBy(terms.phrase, terms.phraseLogProb, terms.phraseCount).sum('termLogProb') +terms = terms.withColumnRenamed('sum(termLogProb)','termsLogProb') +terms = terms.withColumn("phrasePWMI", f.col('phraseLogProb') - f.col('termsLogProb')) + +# join phrases to term counts + + +df = terms.select(['phrase','phraseCount','phraseLogProb','phrasePWMI']) + +df = df.repartition('phrasePWMI') +df = df.sort(['phrasePWMI'],descending=True) +df = df.sortWithinPartitions(['phrasePWMI'],descending=True) +df.write.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet/",mode='overwrite',compression='snappy') +df.write.csv("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.csv/",mode='overwrite',compression='none') -- 2.39.5