pass keyword arg to DataFrame.drop
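
This commit swaps the positional second argument to DataFrame.drop for the explicit axis= keyword: recent pandas releases deprecated passing anything other than the labels positionally, and newer versions reject the old form. A minimal illustration, with a made-up frame standing in for the real similarities data:

import pandas as pd

# Toy stand-in for the similarities feather; '_subreddit' is the label column.
df = pd.DataFrame({"_subreddit": ["a", "b"], "a": [1.0, 0.5], "b": [0.5, 1.0]})

# Old spelling (df.drop('_subreddit', 1)) is rejected by newer pandas.
mat = df.drop("_subreddit", axis=1).to_numpy()
print(mat.shape)  # (2, 2)
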
[cdsc_reddit.git] / clustering / clustering_base.py
index 3778fc3fa91259f49a1b1470e7d0901e1f3ee6ba..2f37b686660ab077e879b8f092697d1652b85a9c 100644
@@ -1,3 +1,4 @@
+import pickle
 from pathlib import Path
 import numpy as np
 import pandas as pd
@@ -20,10 +21,17 @@ class clustering_job:
         self.subreddits, self.mat = self.read_distance_mat(self.infile)
         self.clustering = self.call(self.mat, *self.args, **self.kwargs)
         self.cluster_data = self.process_clustering(self.clustering, self.subreddits)
-        self.score = self.silhouette()
         self.outpath.mkdir(parents=True, exist_ok=True)
         self.cluster_data.to_feather(self.outpath/(self.name + ".feather"))
+
         self.hasrun = True
+        self.cleanup()
+
+    def cleanup(self):
+        self.cluster_data = None
+        self.mat = None
+        self.clustering = None
+        self.subreddits = None
         
     def get_info(self):
         if not self.hasrun:
@@ -54,11 +62,13 @@ class clustering_job:
         else:
             score = None
             self.silsampout = None
+
         return score
 
     def read_distance_mat(self, similarities, use_threads=True):
+        print(similarities)  # log which similarities file is being read
         df = pd.read_feather(similarities, use_threads=use_threads)
-        mat = np.array(df.drop('_subreddit',1))
+        mat = np.array(df.drop('_subreddit',axis=1))
         n = mat.shape[0]
         mat[range(n),range(n)] = 1
         return (df._subreddit,1-mat)
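
For reference, read_distance_mat in the hunk above turns the stored similarity matrix into the distance matrix the clustering calls expect: it drops the `_subreddit` label column, pins the diagonal similarity to 1, and returns 1 - similarity. A small sketch of that conversion on a made-up array:

import numpy as np

# Illustrative similarity values only; the real matrix comes from the feather file.
sim = np.array([[0.9, 0.5],
                [0.5, 0.8]])
n = sim.shape[0]
sim[range(n), range(n)] = 1   # self-similarity forced to exactly 1
dist = 1 - sim                # returned as `mat`; diagonal distances are 0
print(dist)
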
@@ -72,9 +82,13 @@ class clustering_job:
         self.n_clusters = len(set(clusters))
 
         print(f"found {self.n_clusters} clusters")
-
         cluster_data = pd.DataFrame({'subreddit': subreddits,'cluster':clustering.labels_})
 
+
+        self.score = self.silhouette()
+        print(f"silhouette_score:{self.score}")
+
+
         cluster_sizes = cluster_data.groupby("cluster").count().reset_index()
         print(f"the largest cluster has {cluster_sizes.loc[cluster_sizes.cluster!=-1].subreddit.max()} members")
 
@@ -95,6 +109,38 @@ class clustering_job:
 
         return cluster_data
 
+class twoway_clustering_job(clustering_job):
+    def __init__(self, infile, outpath, name, call1, call2, args1, args2):
+        self.outpath = Path(outpath)
+        self.call1 = call1
+        self.args1 = args1
+        self.call2 = call2
+        self.args2 = args2
+        self.infile = Path(infile)
+        self.name = name
+        self.hasrun = False
+        self.args = args1|args2
+
+    def run(self):
+        self.subreddits, self.mat = self.read_distance_mat(self.infile)
+        self.step1 = self.call1(self.mat, **self.args1)
+        self.clustering = self.call2(self.mat, self.step1, **self.args2)
+        self.cluster_data = self.process_clustering(self.clustering, self.subreddits)
+        self.hasrun = True
+        self.after_run()
+        self.cleanup()
+
+    def after_run(self):
+        self.score = self.silhouette()
+        self.outpath.mkdir(parents=True, exist_ok=True)
+        print(self.outpath/(self.name+".feather"))
+        self.cluster_data.to_feather(self.outpath/(self.name + ".feather"))
+
+
+    def cleanup(self):
+        super().cleanup()
+        self.step1 = None
+
 @dataclass
 class clustering_result:
     outpath:Path
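
The new twoway_clustering_job splits the work between two callables: call1 runs on the distance matrix alone, and call2 receives both the matrix and call1's output and must return an object exposing a labels_ attribute (which process_clustering reads). A minimal sketch of that contract with hypothetical stand-in steps; the real callers in this repository supply their own call1/call2 and argument dicts:

import numpy as np
from types import SimpleNamespace

def embed_step(mat, n_components=2):
    # stand-in for call1: derive some intermediate representation from the distances
    return mat[:, :n_components]

def cluster_step(mat, embedding, threshold=0.5):
    # stand-in for call2: produce an object with a .labels_ attribute
    labels = (embedding[:, 0] > threshold).astype(int)
    return SimpleNamespace(labels_=labels)

mat = np.random.rand(10, 10)            # placeholder distance matrix
step1 = embed_step(mat, n_components=2)
clustering = cluster_step(mat, step1, threshold=0.5)
print(clustering.labels_)

Note also that, assuming args1 and args2 are plain dicts, self.args = args1|args2 relies on the dict union operator and therefore needs Python 3.9 or newer.
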
