X-Git-Url: https://code.communitydata.science/cdsc_reddit.git/blobdiff_plain/e6294b5b90135a5163441c8dc62252dd6a188412..55b75ea6fcf421e95f4fe6b180dcec6e64676619:/ngrams/top_comment_phrases.py?ds=inline

diff --git a/ngrams/top_comment_phrases.py b/ngrams/top_comment_phrases.py
old mode 100644
new mode 100755
index 031cba5..ff1c4f0
--- a/ngrams/top_comment_phrases.py
+++ b/ngrams/top_comment_phrases.py
@@ -1,58 +1,69 @@
+#!/usr/bin/env python3
 from pyspark.sql import functions as f
 from pyspark.sql import Window
 from pyspark.sql import SparkSession
 import numpy as np
+import fire
+from pathlib import Path

-spark = SparkSession.builder.getOrCreate()
-df = spark.read.text("/gscratch/comdata/users/nathante/reddit_comment_ngrams_10p_sample/")
-df = df.withColumnRenamed("value","phrase")
+def main(ngram_dir="/gscratch/comdata/output/reddit_ngrams"):
+    spark = SparkSession.builder.getOrCreate()
+    ngram_dir = Path(ngram_dir)
+    ngram_sample = ngram_dir / "reddit_comment_ngrams_10p_sample"
+    df = spark.read.text(str(ngram_sample))

-# count phrase occurrances
-phrases = df.groupby('phrase').count()
-phrases = phrases.withColumnRenamed('count','phraseCount')
-phrases = phrases.filter(phrases.phraseCount > 10)
+    df = df.withColumnRenamed("value","phrase")

+    # count phrase occurrences
+    phrases = df.groupby('phrase').count()
+    phrases = phrases.withColumnRenamed('count','phraseCount')
+    phrases = phrases.filter(phrases.phraseCount > 10)

-# count overall
-N = phrases.select(f.sum(phrases.phraseCount).alias("phraseCount")).collect()[0].phraseCount
+    # count overall
+    N = phrases.select(f.sum(phrases.phraseCount).alias("phraseCount")).collect()[0].phraseCount

-print(f'analyzing PMI on a sample of {N} phrases')
-logN = np.log(N)
-phrases = phrases.withColumn("phraseLogProb", f.log(f.col("phraseCount")) - logN)
+    print(f'analyzing PMI on a sample of {N} phrases')
+    logN = np.log(N)
+    phrases = phrases.withColumn("phraseLogProb", f.log(f.col("phraseCount")) - logN)

-# count term occurrances
-phrases = phrases.withColumn('terms',f.split(f.col('phrase'),' '))
-terms = phrases.select(['phrase','phraseCount','phraseLogProb',f.explode(phrases.terms).alias('term')])
+    # count term occurrences
+    phrases = phrases.withColumn('terms',f.split(f.col('phrase'),' '))
+    terms = phrases.select(['phrase','phraseCount','phraseLogProb',f.explode(phrases.terms).alias('term')])

-win = Window.partitionBy('term')
-terms = terms.withColumn('termCount',f.sum('phraseCount').over(win))
-terms = terms.withColumnRenamed('count','termCount')
-terms = terms.withColumn('termLogProb',f.log(f.col('termCount')) - logN)
+    win = Window.partitionBy('term')
+    terms = terms.withColumn('termCount',f.sum('phraseCount').over(win))
+    terms = terms.withColumnRenamed('count','termCount')
+    terms = terms.withColumn('termLogProb',f.log(f.col('termCount')) - logN)

-terms = terms.groupBy(terms.phrase, terms.phraseLogProb, terms.phraseCount).sum('termLogProb')
-terms = terms.withColumnRenamed('sum(termLogProb)','termsLogProb')
-terms = terms.withColumn("phrasePWMI", f.col('phraseLogProb') - f.col('termsLogProb'))
+    terms = terms.groupBy(terms.phrase, terms.phraseLogProb, terms.phraseCount).sum('termLogProb')
+    terms = terms.withColumnRenamed('sum(termLogProb)','termsLogProb')
+    terms = terms.withColumn("phrasePWMI", f.col('phraseLogProb') - f.col('termsLogProb'))

-# join phrases to term counts
+    # join phrases to term counts

-df = terms.select(['phrase','phraseCount','phraseLogProb','phrasePWMI'])
+    df = terms.select(['phrase','phraseCount','phraseLogProb','phrasePWMI'])

-df = df.sort(['phrasePWMI'],descending=True)
-df = df.sortWithinPartitions(['phrasePWMI'],descending=True)
-df.write.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet/",mode='overwrite',compression='snappy')
+    df = df.sort(['phrasePWMI'], ascending=False)
+    df = df.sortWithinPartitions(['phrasePWMI'], ascending=False)

-df = spark.read.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet/")
+    pwmi_dir = ngram_dir / "reddit_comment_ngrams_pwmi.parquet/"
+    df.write.parquet(str(pwmi_dir), mode='overwrite', compression='snappy')

-df.write.csv("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.csv/",mode='overwrite',compression='none')
+    df = spark.read.parquet(str(pwmi_dir))

-df = spark.read.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet")
-df = df.select('phrase','phraseCount','phraseLogProb','phrasePWMI')
+    df.write.csv(str(ngram_dir / "reddit_comment_ngrams_pwmi.csv/"), mode='overwrite', compression='none')

-# choosing phrases occurring at least 3500 times in the 10% sample (35000 times) and then with a PWMI of at least 3 yeids about 65000 expressions.
-#
-df = df.filter(f.col('phraseCount') > 3500).filter(f.col("phrasePWMI")>3)
-df = df.toPandas()
-df.to_feather("/gscratch/comdata/users/nathante/reddit_multiword_expressions.feather")
-df.to_csv("/gscratch/comdata/users/nathante/reddit_multiword_expressions.csv")
+    df = spark.read.parquet(str(pwmi_dir))
+    df = df.select('phrase','phraseCount','phraseLogProb','phrasePWMI')
+
+    # choosing phrases occurring at least 3500 times in the 10% sample (roughly 35000 times overall) and then requiring a PWMI of at least 3 yields about 65000 expressions.
+    #
+    df = df.filter(f.col('phraseCount') > 3500).filter(f.col("phrasePWMI") > 3)
+    df = df.toPandas()
+    df.to_feather(ngram_dir / "multiword_expressions.feather")
+    df.to_csv(ngram_dir / "multiword_expressions.csv")
+
+if __name__ == '__main__':
+    fire.Fire(main)
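
For context on the scoring step in the diff above: each phrase gets a pointwise-mutual-information-style score, the log probability of the phrase minus the summed log probabilities of its constituent terms, with both probabilities estimated from the same sampled phrase counts. Below is a minimal sketch of that arithmetic without Spark; the toy counts and helper names (phrase_counts, pwmi) are illustrative assumptions only, not part of the repository.

    # Minimal non-Spark sketch of the PWMI score computed in top_comment_phrases.py.
    # The toy counts below are hypothetical stand-ins for the 10% ngram sample.
    import numpy as np

    phrase_counts = {
        "new york": 50,
        "new car": 10,
        "york city": 20,
    }

    N = sum(phrase_counts.values())   # total phrase occurrences in the sample
    log_n = np.log(N)

    # term count = summed counts of every phrase containing the term
    # (mirrors the window sum partitioned by 'term' in the Spark code)
    term_counts = {}
    for phrase, count in phrase_counts.items():
        for term in phrase.split(" "):
            term_counts[term] = term_counts.get(term, 0) + count

    def pwmi(phrase):
        """log P(phrase) minus the sum of log P(term) over the phrase's terms."""
        phrase_log_prob = np.log(phrase_counts[phrase]) - log_n
        terms_log_prob = sum(np.log(term_counts[t]) - log_n for t in phrase.split(" "))
        return phrase_log_prob - terms_log_prob

    for p in phrase_counts:
        print(p, round(pwmi(p), 3))

Higher values mean a phrase's terms appear together more often than their individual frequencies would predict, which is why the script keeps only phrases with phrasePWMI above 3 (in addition to the raw count threshold).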