X-Git-Url: https://code.communitydata.science/cdsc_reddit.git/blobdiff_plain/36cb0a5546d220bb19c0029eb7d4365059822f84..2d21ff1137dfaf83c5a51fdcd8900503c50a06ab:/timeseries/cluster_timeseries.py diff --git a/timeseries/cluster_timeseries.py b/timeseries/cluster_timeseries.py index 07507d7..91fa705 100644 --- a/timeseries/cluster_timeseries.py +++ b/timeseries/cluster_timeseries.py @@ -2,11 +2,11 @@ import pandas as pd import numpy as np from pyspark.sql import functions as f from pyspark.sql import SparkSession -from choose_clusters import load_clusters, load_densities +from .choose_clusters import load_clusters, load_densities import fire from pathlib import Path -def main(term_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_terms_10000.feather", +def build_cluster_timeseries(term_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_terms_10000.feather", author_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_authors_10000.feather", term_densities_path="/gscratch/comdata/output/reddit_density/comment_terms_10000.feather", author_densities_path="/gscratch/comdata/output/reddit_density/comment_authors_10000.feather", @@ -34,4 +34,4 @@ def main(term_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_ ts.write.parquet(output, mode='overwrite') if __name__ == "__main__": - fire.Fire(main) + fire.Fire(build_cluster_timeseries)