[cdsc_reddit.git] / ngrams / top_comment_phrases.py
#!/usr/bin/env python3
from pyspark.sql import functions as f
from pyspark.sql import Window
from pyspark.sql import SparkSession
import numpy as np
import fire
from pathlib import Path


def main(ngram_dir="/gscratch/comdata/output/reddit_ngrams"):
    spark = SparkSession.builder.getOrCreate()
    ngram_dir = Path(ngram_dir)
    ngram_sample = ngram_dir / "reddit_comment_ngrams_10p_sample"
    df = spark.read.text(str(ngram_sample))
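    # spark.read.text yields a single string column named 'value';
    # each line of the 10% sample is one ngram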

    df = df.withColumnRenamed("value","phrase")

    # count phrase occurrences
    phrases = df.groupby('phrase').count()
    phrases = phrases.withColumnRenamed('count','phraseCount')
    phrases = phrases.filter(phrases.phraseCount > 10)

    # count overall
    N = phrases.select(f.sum(phrases.phraseCount).alias("phraseCount")).collect()[0].phraseCount

    print(f'analyzing PMI on a sample of {N} phrase occurrences')
    logN = np.log(N)
    phrases = phrases.withColumn("phraseLogProb", f.log(f.col("phraseCount")) - logN)
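    # phraseLogProb is the log relative frequency of the phrase:
    # log P(phrase) = log(phraseCount) - log(N)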

    # count term occurrences
    phrases = phrases.withColumn('terms',f.split(f.col('phrase'),' '))
    terms = phrases.select(['phrase','phraseCount','phraseLogProb',f.explode(phrases.terms).alias('term')])

    win = Window.partitionBy('term')
    terms = terms.withColumn('termCount',f.sum('phraseCount').over(win))
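    # a term's count is the sum of the counts of every sampled phrase that contains it,
    # computed with a window partitioned by term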
    terms = terms.withColumn('termLogProb',f.log(f.col('termCount')) - logN)

    terms = terms.groupBy(terms.phrase, terms.phraseLogProb, terms.phraseCount).sum('termLogProb')
    terms = terms.withColumnRenamed('sum(termLogProb)','termsLogProb')
    terms = terms.withColumn("phrasePWMI", f.col('phraseLogProb') - f.col('termsLogProb'))
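    # pointwise mutual information of the whole phrase against its terms:
    # PWMI(phrase) = log P(phrase) - sum_i log P(term_i);
    # large positive values mean the terms co-occur more often than independence would predict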

    # keep one row per phrase with its count, log probability, and PWMI
    df = terms.select(['phrase','phraseCount','phraseLogProb','phrasePWMI'])

    df = df.sort('phrasePWMI', ascending=False)
    df = df.sortWithinPartitions('phrasePWMI', ascending=False)

    pwmi_dir = ngram_dir / "reddit_comment_ngrams_pwmi.parquet/"
    df.write.parquet(str(pwmi_dir), mode='overwrite', compression='snappy')

    df = spark.read.parquet(str(pwmi_dir))
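    # the parquet output is read back (rather than reusing the in-memory lineage),
    # presumably to materialize the result once before the CSV export and the filtering below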

    df.write.csv(str(ngram_dir / "reddit_comment_ngrams_pwmi.csv/"),mode='overwrite',compression='none')

    df = spark.read.parquet(str(pwmi_dir))
    df = df.select('phrase','phraseCount','phraseLogProb','phrasePWMI')

    # choosing phrases occurring more than 3500 times in the 10% sample (roughly 35000 times overall)
    # and then requiring a PWMI greater than 3 yields about 65000 expressions.
    df = df.filter(f.col('phraseCount') > 3500).filter(f.col("phrasePWMI")>3)
    df = df.toPandas()
    df.to_feather(ngram_dir / "multiword_expressions.feather")
    df.to_csv(ngram_dir / "multiword_expressions.csv")

if __name__ == '__main__':
    fire.Fire(main)
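
# Usage note: fire.Fire(main) turns main's keyword argument into a command-line flag, e.g.
#   python3 top_comment_phrases.py --ngram_dir=/gscratch/comdata/output/reddit_ngrams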
