Merge remote-tracking branch 'refs/remotes/origin/excise_reindex' into excise_reindex
[cdsc_reddit.git] / timeseries / cluster_timeseries.py
index 07507d74c037ca870bc57e83357d8847d9506a46..2286ab0cad083307fbe977344f96a35f8b6a1c41 100644
@@ -2,20 +2,16 @@ import pandas as pd
 import numpy as np
 from pyspark.sql import functions as f
 from pyspark.sql import SparkSession
-from choose_clusters import load_clusters, load_densities
+from .choose_clusters import load_clusters, load_densities
 import fire
 from pathlib import Path
 
-def main(term_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_terms_10000.feather",
+def build_cluster_timeseries(term_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_terms_10000.feather",
          author_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_authors_10000.feather",
          term_densities_path="/gscratch/comdata/output/reddit_density/comment_terms_10000.feather",
          author_densities_path="/gscratch/comdata/output/reddit_density/comment_authors_10000.feather",
          output="data/subreddit_timeseries.parquet"):
 
-
-    clusters = load_clusters(term_clusters_path, author_clusters_path)
-    densities = load_densities(term_densities_path, author_densities_path)
-    
     spark = SparkSession.builder.getOrCreate()
     
     df = spark.read.parquet("/gscratch/comdata/output/reddit_comments_by_subreddit.parquet")
@@ -26,12 +22,16 @@ def main(term_clusters_path="/gscratch/comdata/output/reddit_clustering/comment_
     ts = df.select(['subreddit','week','author']).distinct().groupby(['subreddit','week']).count()
     
     ts = ts.repartition('subreddit')
-    spk_clusters = spark.createDataFrame(clusters)
+
+    if term_densities_path is not None and author_densities_path is not None:
+        densities = load_densities(term_densities_path, author_densities_path)
+        spk_densities = spark.createDataFrame(densities)
+        ts = ts.join(spk_densities, on='subreddit', how='inner')
     
+    clusters = load_clusters(term_clusters_path, author_clusters_path)
+    spk_clusters = spark.createDataFrame(clusters)
     ts = ts.join(spk_clusters, on='subreddit', how='inner')
-    spk_densities = spark.createDataFrame(densities)
-    ts = ts.join(spk_densities, on='subreddit', how='inner')
     ts.write.parquet(output, mode='overwrite')
 
 if __name__ == "__main__":
-    fire.Fire(main)
+    fire.Fire(build_cluster_timeseries)
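
Below is a minimal usage sketch of the renamed entry point after this commit. The import path is an assumption (it presumes the repository is importable as a cdsc_reddit package); only the function name build_cluster_timeseries, its keyword arguments, and their defaults come from the diff above. Skipping the density join by passing None for both density paths follows from the new `is not None` guard.

# Hypothetical usage sketch; not part of the commit above.
# The import path assumes the repo is installed as a package named
# cdsc_reddit -- only the function name and its keyword arguments are
# taken from the diff.
from cdsc_reddit.timeseries.cluster_timeseries import build_cluster_timeseries

# Default call: builds the weekly per-subreddit author-count timeseries,
# joins both the cluster assignments and the density tables, and writes
# the result to parquet.
build_cluster_timeseries(output="data/subreddit_timeseries.parquet")

# The new `is not None` guard skips the density join when both density
# paths are None, so a clusters-only timeseries can be produced:
build_cluster_timeseries(term_densities_path=None,
                         author_densities_path=None,
                         output="data/subreddit_timeseries_no_density.parquet")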
