#!/usr/bin/env python3
"""Score Reddit comment n-grams by pointwise mutual information (PWMI).

Reads a text sample of comment n-grams (one space-separated phrase per line)
and computes, for each sufficiently frequent phrase,

    PWMI(phrase) = log P(phrase) - sum over terms t in phrase of log P(t)

where both probabilities use the total retained phrase count as denominator.
All scored phrases are written as parquet and CSV.  Phrases that pass a
frequency and PWMI threshold are additionally written as feather and CSV,
as candidate multi-word expressions.
"""
from pathlib import Path

import fire
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql import functions as f


def main(ngram_dir="/gscratch/comdata/output/reddit_ngrams",
         min_sample_count=10,
         min_phrase_count=3500,
         min_pwmi=3):
    """Compute phrase PWMI scores and extract candidate multi-word expressions.

    Args:
        ngram_dir: Directory containing the ``reddit_comment_ngrams_10p_sample``
            text input.  All outputs are written here as well.
        min_sample_count: Phrases must occur strictly more than this many times
            in the sample to be scored at all.
        min_phrase_count: Scored phrases must occur strictly more than this
            many times to be kept as multi-word expressions.
        min_pwmi: Scored phrases must have PWMI strictly greater than this to
            be kept as multi-word expressions.

    Raises:
        ValueError: If no phrase survives the ``min_sample_count`` filter.
    """
    spark = SparkSession.builder.getOrCreate()
    ngram_dir = Path(ngram_dir)
    ngram_sample = ngram_dir / "reddit_comment_ngrams_10p_sample"
    df = spark.read.text(str(ngram_sample))

    df = df.withColumnRenamed("value", "phrase")

    # Count phrase occurrences.  Rare phrases are dropped because their
    # probability estimates are too noisy.
    phrases = df.groupby('phrase').count()
    phrases = phrases.withColumnRenamed('count', 'phraseCount')
    phrases = phrases.filter(phrases.phraseCount > min_sample_count)

    # Total number of retained phrase occurrences.  Summing zero rows yields
    # NULL (None), which would otherwise fail obscurely inside np.log.
    N = phrases.select(f.sum(phrases.phraseCount).alias("phraseCount")).collect()[0].phraseCount
    if N is None:
        raise ValueError(
            f"no phrases in {ngram_sample} occur more than {min_sample_count} times")

    print(f'analyzing PMI on a sample of {N} phrases')
    logN = np.log(N)
    phrases = phrases.withColumn("phraseLogProb", f.log(f.col("phraseCount")) - logN)

    # Explode each phrase into its terms.  A term repeated within a phrase
    # produces one row per occurrence.
    phrases = phrases.withColumn('terms', f.split(f.col('phrase'), ' '))
    terms = phrases.select(['phrase', 'phraseCount', 'phraseLogProb',
                            f.explode(phrases.terms).alias('term')])

    # A term's count is the summed count of every retained phrase containing it.
    win = Window.partitionBy('term')
    terms = terms.withColumn('termCount', f.sum('phraseCount').over(win))
    terms = terms.withColumn('termLogProb', f.log(f.col('termCount')) - logN)

    # Collapse back to one row per phrase, summing its terms' log-probabilities.
    terms = terms.groupBy(terms.phrase, terms.phraseLogProb, terms.phraseCount).sum('termLogProb')
    terms = terms.withColumnRenamed('sum(termLogProb)', 'termsLogProb')
    terms = terms.withColumn("phrasePWMI", f.col('phraseLogProb') - f.col('termsLogProb'))

    df = terms.select(['phrase', 'phraseCount', 'phraseLogProb', 'phrasePWMI'])

    # Highest PWMI first.  DataFrame.sort only understands an `ascending=`
    # keyword and silently ignores anything else (e.g. `descending=True`),
    # so the direction is expressed on the column itself.
    df = df.sort(f.desc('phrasePWMI'))

    pwmi_dir = ngram_dir / "reddit_comment_ngrams_pwmi.parquet"
    df.write.parquet(str(pwmi_dir), mode='overwrite', compression='snappy')

    # Read back the materialized result so later writes do not recompute the
    # whole pipeline above.
    df = spark.read.parquet(str(pwmi_dir))

    df.write.csv(str(ngram_dir / "reddit_comment_ngrams_pwmi.csv"),
                 mode='overwrite', compression='none')

    # With the defaults (more than 3500 occurrences in the 10% sample, i.e.
    # roughly 35000 overall, and PWMI above 3), the original analysis found
    # about 65000 expressions.
    mwe = (df.filter(f.col('phraseCount') > min_phrase_count)
             .filter(f.col("phrasePWMI") > min_pwmi))
    mwe = mwe.toPandas()
    mwe.to_feather(ngram_dir / "multiword_expressions.feather")
    mwe.to_csv(ngram_dir / "multiword_expressions.csv")


if __name__ == '__main__':
    fire.Fire(main)