cdsc_reddit.git / clustering / clustering_base.py
import pickle
from pathlib import Path
import numpy as np
import pandas as pd
from dataclasses import dataclass
from sklearn.metrics import silhouette_score, silhouette_samples
from collections import Counter

# This class is meant to be an interface/base class, not instantiated directly.
class clustering_job:
    def __init__(self, infile, outpath, name, call, *args, **kwargs):
        self.outpath = Path(outpath)
        self.call = call
        self.args = args
        self.kwargs = kwargs
        self.infile = Path(infile)
        self.name = name
        self.hasrun = False

    def run(self):
        self.subreddits, self.mat = self.read_distance_mat(self.infile)
        self.clustering = self.call(self.mat, *self.args, **self.kwargs)
        self.cluster_data = self.process_clustering(self.clustering, self.subreddits)
        self.outpath.mkdir(parents=True, exist_ok=True)
        self.cluster_data.to_feather(self.outpath/(self.name + ".feather"))

        self.hasrun = True
        self.cleanup()

    def cleanup(self):
        # Release large intermediates so finished jobs don't hold memory.
        self.cluster_data = None
        self.mat = None
        self.clustering = None
        self.subreddits = None

    def get_info(self):
        if not self.hasrun:
            self.run()

        self.result = clustering_result(outpath=str(self.outpath.resolve()),
                                        silhouette_score=self.score,
                                        name=self.name,
                                        n_clusters=self.n_clusters,
                                        n_isolates=self.n_isolates,
                                        silhouette_samples=self.silsampout)
        return self.result

    def silhouette(self):
        # Exclude noise points (label -1) and singleton clusters from the
        # overall score, since they say nothing about cluster cohesion.
        counts = Counter(self.clustering.labels_)
        singletons = [key for key, value in counts.items() if value == 1]
        isolates = (self.clustering.labels_ == -1) | np.isin(self.clustering.labels_, np.array(singletons))
        scoremat = self.mat[~isolates][:, ~isolates]
        if self.n_clusters > 1:
            score = silhouette_score(scoremat, self.clustering.labels_[~isolates], metric='precomputed')
            # Per-subreddit scores are computed on the full matrix so every
            # subreddit, isolates included, gets a row in the output.
            silhouette_samp = silhouette_samples(self.mat, self.clustering.labels_, metric='precomputed')
            silhouette_samp = pd.DataFrame({'subreddit': self.subreddits, 'score': silhouette_samp})
            self.outpath.mkdir(parents=True, exist_ok=True)
            silsampout = self.outpath / ("silhouette_samples-" + self.name + ".feather")
            self.silsampout = silsampout.resolve()
            silhouette_samp.to_feather(self.silsampout)
        else:
            score = None
            self.silsampout = None

        return score

    def read_distance_mat(self, similarities, use_threads=True):
        # Log which similarity file is being read.
        print(similarities)
        df = pd.read_feather(similarities, use_threads=use_threads)
        # The feather file holds a square similarity matrix plus a
        # `_subreddit` name column; drop the names, force the diagonal to 1,
        # and convert similarities to distances.
        mat = np.array(df.drop(columns='_subreddit'))
        n = mat.shape[0]
        mat[range(n), range(n)] = 1
        return (df._subreddit, 1 - mat)

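    # A tiny worked example of the conversion above (comment added for
    # clarity): a similarity matrix [[1.0, 0.25], [0.25, 1.0]] becomes the
    # distance matrix [[0.0, 0.75], [0.75, 0.0]], with zeros on the diagonal
    # as the precomputed silhouette metrics expect.
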
    def process_clustering(self, clustering, subreddits):

        if hasattr(clustering, 'n_iter_'):
            print(f"clustering took {clustering.n_iter_} iterations")

        clusters = clustering.labels_
        self.n_clusters = len(set(clusters))

        print(f"found {self.n_clusters} clusters")
        cluster_data = pd.DataFrame({'subreddit': subreddits, 'cluster': clustering.labels_})

        self.score = self.silhouette()
        print(f"silhouette_score:{self.score}")

        cluster_sizes = cluster_data.groupby("cluster").count().reset_index()
        print(f"the largest cluster has {cluster_sizes.loc[cluster_sizes.cluster!=-1].subreddit.max()} members")
        print(f"the median cluster has {cluster_sizes.subreddit.median()} members")

        # Isolates are either singleton clusters or, for algorithms like
        # DBSCAN that emit a noise label, members of cluster -1.
        n_isolates1 = (cluster_sizes.subreddit == 1).sum()
        print(f"{n_isolates1} clusters have 1 member")

        n_isolates2 = cluster_sizes.loc[cluster_sizes.cluster == -1, :]['subreddit'].to_list()
        n_isolates2 = n_isolates2[0] if len(n_isolates2) > 0 else 0
        print(f"{n_isolates2} subreddits are in cluster -1", flush=True)

        if n_isolates1 == 0:
            self.n_isolates = n_isolates2
        else:
            self.n_isolates = n_isolates1

        return cluster_data
111
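# A minimal usage sketch (added for illustration, not part of the original
# module). `call` is any callable that takes the distance matrix and returns
# a fitted object exposing `labels_`; sklearn's AffinityPropagation with a
# precomputed affinity is one option. The input path is hypothetical.
#
#   from sklearn.cluster import AffinityPropagation
#
#   def affinity_call(mat, **kwargs):
#       # AffinityPropagation expects similarities, so negate the distances.
#       return AffinityPropagation(affinity='precomputed', **kwargs).fit(-mat)
#
#   job = clustering_job(infile="similarities/subreddits.feather",
#                        outpath="clusters",
#                        name="affinity_example",
#                        call=affinity_call,
#                        damping=0.85)
#   result = job.get_info()  # runs the job and returns a clustering_result
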
class twoway_clustering_job(clustering_job):
    def __init__(self, infile, outpath, name, call1, call2, args1, args2):
        self.outpath = Path(outpath)
        self.call1 = call1
        self.args1 = args1
        self.call2 = call2
        self.args2 = args2
        self.infile = Path(infile)
        self.name = name
        self.hasrun = False
        self.args = args1 | args2

    def run(self):
        self.subreddits, self.mat = self.read_distance_mat(self.infile)
        self.step1 = self.call1(self.mat, **self.args1)
        self.clustering = self.call2(self.mat, self.step1, **self.args2)
        self.cluster_data = self.process_clustering(self.clustering, self.subreddits)
        self.hasrun = True
        self.after_run()
        self.cleanup()

    def after_run(self):
        self.score = self.silhouette()
        self.outpath.mkdir(parents=True, exist_ok=True)
        print(self.outpath/(self.name + ".feather"))
        self.cluster_data.to_feather(self.outpath/(self.name + ".feather"))

    def cleanup(self):
        super().cleanup()
        self.step1 = None

@dataclass
class clustering_result:
    outpath: Path
    silhouette_score: float
    name: str
    n_clusters: int
    n_isolates: int
    silhouette_samples: str
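
# A usage sketch for twoway_clustering_job (added for illustration; the
# embedding/clustering pair below is an assumption, not the only option).
# call1 reduces the distance matrix to an embedding; call2 clusters it and
# must return an object with `labels_`. Paths and parameters are hypothetical.
#
#   import umap
#   import hdbscan
#
#   def embed(mat, **kwargs):
#       return umap.UMAP(metric='precomputed', **kwargs).fit_transform(mat)
#
#   def cluster(mat, embedding, **kwargs):
#       return hdbscan.HDBSCAN(**kwargs).fit(embedding)
#
#   job = twoway_clustering_job(infile="similarities/subreddits.feather",
#                               outpath="clusters",
#                               name="umap_hdbscan_example",
#                               call1=embed,
#                               call2=cluster,
#                               args1={'n_neighbors': 15},
#                               args2={'min_cluster_size': 2})
#   result = job.get_info()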
