from pyspark.sql import functions as f
from pyspark.sql import SparkSession
import fire


def main(inparquet, outparquet, colname, num_partitions=2000):
    """Globally sort a parquet dataset by (colname, week, subreddit) and rewrite it.

    Args:
        inparquet: Path of the input parquet dataset. It must contain the
            columns ``colname``, ``'week'`` and ``'subreddit'``.
        outparquet: Path to write the sorted dataset to. Any existing data
            there is overwritten; output is snappy-compressed.
        colname: Name of the primary sort/partition column (e.g. ``'term'``).
        num_partitions: Target number of output partitions (default 2000).
            Range partitioning may produce fewer if the key has too few
            distinct values.
    """
    spark = SparkSession.builder.getOrCreate()
    df = spark.read.parquet(inparquet)

    sort_cols = [colname, 'week', 'subreddit']

    # One range shuffle on the full sort key places contiguous key ranges in
    # consecutive partitions. Sorting each partition then gives a total order
    # across the output.
    #
    # This replaces repartition(...) + sort(...) + sortWithinPartitions(...).
    # In that chain the global sort re-shuffled the data using
    # spark.sql.shuffle.partitions. That discarded the requested partition
    # count and shuffled the data twice.
    df = df.repartitionByRange(num_partitions, *sort_cols)
    df = df.sortWithinPartitions(sort_cols)

    df.write.parquet(outparquet, mode='overwrite', compression='snappy')


if __name__ == '__main__':
    fire.Fire(main)