X-Git-Url: https://code.communitydata.science/cdsc_reddit.git/blobdiff_plain/628a70734b19c25da5df2e0c9ed4f1616c2b5c10..7b14db67de8650e4858d3f102fbeab813a30ee29:/timeseries/cluster_timeseries.py?ds=inline diff --git a/timeseries/cluster_timeseries.py b/timeseries/cluster_timeseries.py index 07507d7..91fa705 100644 --- a/timeseries/cluster_timeseries.py +++ b/timeseries/cluster_timeseries.py @@ -2,11 +2,11 @@ import pandas as pd import numpy as np from pyspark.sql import functions as f from pyspark.sql import SparkSession -from choose_clusters import load_clusters, load_densities +from .choose_clusters import load_clusters, load_densities import fire from pathlib import Path -def main(term_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_terms_10000.feather", +def build_cluster_timeseries(term_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_terms_10000.feather", author_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_authors_10000.feather", term_densities_path="/gscratch/comdata/output/reddit_density/comment_terms_10000.feather", author_densities_path="/gscratch/comdata/output/reddit_density/comment_authors_10000.feather", @@ -34,4 +34,4 @@ def main(term_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_ ts.write.parquet(output, mode='overwrite') if __name__ == "__main__": - fire.Fire(main) + fire.Fire(build_cluster_timeseries)