X-Git-Url: https://code.communitydata.science/cdsc_reddit.git/blobdiff_plain/7b14db67de8650e4858d3f102fbeab813a30ee29..5a40465a629a1d7d95dbec9730d3950842bcb4f5:/clustering/clustering_base.py?ds=sidebyside

diff --git a/clustering/clustering_base.py b/clustering/clustering_base.py
index 3778fc3..ced627d 100644
--- a/clustering/clustering_base.py
+++ b/clustering/clustering_base.py
@@ -1,3 +1,4 @@
+import pickle
 from pathlib import Path
 import numpy as np
 import pandas as pd
@@ -24,6 +25,13 @@ class clustering_job:
         self.outpath.mkdir(parents=True, exist_ok=True)
         self.cluster_data.to_feather(self.outpath/(self.name + ".feather"))
         self.hasrun = True
+        self.cleanup()
+
+    def cleanup(self):
+        self.cluster_data = None
+        self.mat = None
+        self.clustering=None
+        self.subreddits=None
 
     def get_info(self):
         if not self.hasrun:
@@ -57,6 +65,7 @@ class clustering_job:
         return score
 
     def read_distance_mat(self, similarities, use_threads=True):
+        print(similarities)
         df = pd.read_feather(similarities, use_threads=use_threads)
         mat = np.array(df.drop('_subreddit',1))
         n = mat.shape[0]
@@ -95,6 +104,49 @@ class clustering_job:
 
         return cluster_data
 
+class twoway_clustering_job(clustering_job):
+    """Clustering job that chains two clustering calls over one matrix.
+
+    ``call1`` is run on the distance matrix first; ``call2`` then receives
+    the matrix together with the result of ``call1``.
+    """
+
+    def __init__(self, infile, outpath, name, call1, call2, args1, args2):
+        self.outpath = Path(outpath)
+        self.call1 = call1
+        self.args1 = args1
+        self.call2 = call2
+        self.args2 = args2
+        self.infile = Path(infile)
+        self.name = name
+        self.hasrun = False
+        self.args = args1|args2
+
+    def run(self):
+        """Read the matrix, run both steps, save the result and clean up."""
+        self.subreddits, self.mat = self.read_distance_mat(self.infile)
+        self.step1 = self.call1(self.mat, **self.args1)
+        self.clustering = self.call2(self.mat, self.step1, **self.args2)
+        self.cluster_data = self.process_clustering(self.clustering, self.subreddits)
+        self.hasrun = True
+        self.after_run()
+        self.cleanup()
+
+    # Takes ``self``: run() calls this as a bound method, so a zero-argument
+    # signature would raise TypeError on every run.
+    def after_run(self):
+        """Score the clustering and write the cluster table to feather."""
+        self.score = self.silhouette()
+        self.outpath.mkdir(parents=True, exist_ok=True)
+        print(self.outpath/(self.name+".feather"))
+        self.cluster_data.to_feather(self.outpath/(self.name + ".feather"))
+
+
+    def cleanup(self):
+        """Reset the base-class data attributes and the step-1 result."""
+        super().cleanup()
+        self.step1 = None
+
 @dataclass
 class clustering_result:
     outpath:Path