+#!/usr/bin/env python3
 from pyspark.sql import functions as f
 from pyspark.sql import Window
 from pyspark.sql import SparkSession
 import numpy as np
+import fire
+from pathlib import Path
 
-spark = SparkSession.builder.getOrCreate()
-df = spark.read.text("/gscratch/comdata/users/nathante/reddit_comment_ngrams_10p_sample/")
 
-df = df.withColumnRenamed("value","phrase")
+def main(ngram_dir="/gscratch/comdata/output/reddit_ngrams"):
+    """Score sampled n-gram phrases by pointwise mutual information (PWMI).
+
+    Reads the text files under ``<ngram_dir>/reddit_comment_ngrams_10p_sample``
+    (one phrase per line), counts each distinct phrase, and keeps phrases
+    seen more than 10 times. For every retained phrase it computes:
+
+    * ``phraseLogProb`` = log(phraseCount) - log(N), where N is the sum of
+      the retained phrase counts;
+    * ``phrasePWMI`` = phraseLogProb minus the sum of ``termLogProb`` over
+      the phrase's space-separated terms, where a term's count is the sum
+      of ``phraseCount`` over every retained phrase row containing it.
+
+    Outputs, all written under ``ngram_dir``:
+
+    * ``reddit_comment_ngrams_pwmi.parquet`` (snappy) and
+      ``reddit_comment_ngrams_pwmi.csv`` (uncompressed): all scored phrases,
+      sorted by ``phrasePWMI`` in descending order;
+    * ``multiword_expressions.feather`` and ``multiword_expressions.csv``:
+      phrases with ``phraseCount > 3500`` and ``phrasePWMI > 3``.
+
+    Args:
+        ngram_dir: Directory holding the n-gram sample; also the output
+            directory.
+    """
+    spark = SparkSession.builder.getOrCreate()
+    ngram_dir = Path(ngram_dir)
+    ngram_sample = ngram_dir / "reddit_comment_ngrams_10p_sample"
+    df = spark.read.text(str(ngram_sample))
 
-# count phrase occurrances
-phrases = df.groupby('phrase').count()
-phrases = phrases.withColumnRenamed('count','phraseCount')
-phrases = phrases.filter(phrases.phraseCount > 10)
+    df = df.withColumnRenamed("value","phrase")
 
+    # count phrase occurrences
+    phrases = df.groupby('phrase').count()
+    phrases = phrases.withColumnRenamed('count','phraseCount')
+    phrases = phrases.filter(phrases.phraseCount > 10)
 
-# count overall
-N = phrases.select(f.sum(phrases.phraseCount).alias("phraseCount")).collect()[0].phraseCount
+    # count overall
+    N = phrases.select(f.sum(phrases.phraseCount).alias("phraseCount")).collect()[0].phraseCount
 
-print(f'analyzing PMI on a sample of {N} phrases') 
-logN = np.log(N)
-phrases = phrases.withColumn("phraseLogProb", f.log(f.col("phraseCount")) - logN)
+    print(f'analyzing PMI on a sample of {N} phrases')
+    logN = np.log(N)
+    phrases = phrases.withColumn("phraseLogProb", f.log(f.col("phraseCount")) - logN)
 
-# count term occurrances
-phrases = phrases.withColumn('terms',f.split(f.col('phrase'),' '))
-terms = phrases.select(['phrase','phraseCount','phraseLogProb',f.explode(phrases.terms).alias('term')])
+    # count term occurrences
+    phrases = phrases.withColumn('terms',f.split(f.col('phrase'),' '))
+    terms = phrases.select(['phrase','phraseCount','phraseLogProb',f.explode(phrases.terms).alias('term')])
 
-win = Window.partitionBy('term')
-terms = terms.withColumn('termCount',f.sum('phraseCount').over(win))
-terms = terms.withColumnRenamed('count','termCount')
-terms = terms.withColumn('termLogProb',f.log(f.col('termCount')) - logN)
+    # termCount is produced directly by the window sum; there is no 'count'
+    # column at this point, so no rename is needed.
+    win = Window.partitionBy('term')
+    terms = terms.withColumn('termCount',f.sum('phraseCount').over(win))
+    terms = terms.withColumn('termLogProb',f.log(f.col('termCount')) - logN)
 
-terms = terms.groupBy(terms.phrase, terms.phraseLogProb, terms.phraseCount).sum('termLogProb')
-terms = terms.withColumnRenamed('sum(termLogProb)','termsLogProb')
-terms = terms.withColumn("phrasePWMI", f.col('phraseLogProb') - f.col('termsLogProb'))
+    terms = terms.groupBy(terms.phrase, terms.phraseLogProb, terms.phraseCount).sum('termLogProb')
+    terms = terms.withColumnRenamed('sum(termLogProb)','termsLogProb')
+    terms = terms.withColumn("phrasePWMI", f.col('phraseLogProb') - f.col('termsLogProb'))
 
-# join phrases to term counts
+    # join phrases to term counts
 
 
-df = terms.select(['phrase','phraseCount','phraseLogProb','phrasePWMI'])
+    df = terms.select(['phrase','phraseCount','phraseLogProb','phrasePWMI'])
 
-df = df.sort(['phrasePWMI'],descending=True)
-df = df.sortWithinPartitions(['phrasePWMI'],descending=True)
-df.write.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet/",mode='overwrite',compression='snappy')
+    # PySpark only understands the `ascending` keyword; an unrecognised
+    # `descending=True` is silently ignored and yields an ascending sort.
+    df = df.sort(['phrasePWMI'],ascending=False)
+    df = df.sortWithinPartitions(['phrasePWMI'],ascending=False)
 
-df = spark.read.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet/")
+    pwmi_dir = ngram_dir / "reddit_comment_ngrams_pwmi.parquet/"
+    df.write.parquet(str(pwmi_dir), mode='overwrite', compression='snappy')
 
-df.write.csv("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.csv/",mode='overwrite',compression='none')
+    df = spark.read.parquet(str(pwmi_dir))
 
-df = spark.read.parquet("/gscratch/comdata/users/nathante/reddit_comment_ngrams_pwmi.parquet")
-df = df.select('phrase','phraseCount','phraseLogProb','phrasePWMI')
+    df.write.csv(str(ngram_dir / "reddit_comment_ngrams_pwmi.csv/"),mode='overwrite',compression='none')
 
-# choosing phrases occurring at least 3500 times in the 10% sample (35000 times) and then with a PWMI of at least 3 yeids about 65000 expressions.
-#
-df = df.filter(f.col('phraseCount') > 3500).filter(f.col("phrasePWMI")>3)
-df = df.toPandas()
-df.to_feather("/gscratch/comdata/users/nathante/reddit_multiword_expressions.feather")
-df.to_csv("/gscratch/comdata/users/nathante/reddit_multiword_expressions.csv")
+    df = spark.read.parquet(str(pwmi_dir))
+    df = df.select('phrase','phraseCount','phraseLogProb','phrasePWMI')
+
+    # choosing phrases occurring more than 3500 times in the 10% sample (35000 times) and then with a PWMI above 3 yields about 65000 expressions.
+    #
+    df = df.filter(f.col('phraseCount') > 3500).filter(f.col("phrasePWMI")>3)
+    df = df.toPandas()
+    df.to_feather(ngram_dir / "multiword_expressions.feather")
+    df.to_csv(ngram_dir / "multiword_expressions.csv")
+
+# When run as a script, hand main() to fire.Fire to drive it from the command line.
+if __name__ == '__main__':
+    fire.Fire(main)