]> code.communitydata.science - cdsc_reddit.git/blob - clustering/clustering_base.py
1d24533b520865d8e3f8bd53bad8a344178d8741
[cdsc_reddit.git] / clustering / clustering_base.py
1 from pathlib import Path
2 import numpy as np
3 import pandas as pd
4 from dataclasses import dataclass
5 from sklearn.metrics import silhouette_score, silhouette_samples
6
# this is meant to be an interface, not created directly
class clustering_job:
    """Base class for one clustering run over a precomputed similarity matrix.

    A job reads a feather file of pairwise subreddit similarities, converts it
    to a distance matrix (``1 - similarity``), clusters it with
    ``call(mat, *args, **kwargs)``, scores the result with the silhouette
    coefficient and writes the cluster assignments (and per-subreddit
    silhouette samples) as feather files under ``outpath``.
    """

    def __init__(self, infile, outpath, name, call, *args, **kwargs):
        """Store the job configuration; no work is done until ``run``.

        Args:
            infile: path to the similarity feather file; it must contain a
                ``_subreddit`` column plus the similarity columns (assumed to
                form a square matrix aligned with the rows — TODO confirm).
            outpath: directory where output feather files are written.
            name: base name for the output files.
            call: clustering callable, invoked as ``call(mat, *args, **kwargs)``;
                its return value must expose ``labels_``, where ``-1`` marks an
                isolate.
        """
        self.outpath = Path(outpath)
        self.call = call
        self.args = args
        self.kwargs = kwargs
        self.infile = Path(infile)
        self.name = name
        self.hasrun = False

    def run(self):
        """Load the matrix, cluster it, score it and write ``<outpath>/<name>.feather``.

        Sets ``subreddits``, ``mat``, ``clustering``, ``cluster_data`` and
        ``score`` on the instance and marks the job as run.
        """
        self.subreddits, self.mat = self.read_distance_mat(self.infile)
        self.clustering = self.call(self.mat, *self.args, **self.kwargs)
        self.cluster_data = self.process_clustering(self.clustering, self.subreddits)
        self.score = self.silhouette()
        self.outpath.mkdir(parents=True, exist_ok=True)
        self.cluster_data.to_feather(self.outpath / (self.name + ".feather"))
        self.hasrun = True

    def get_info(self):
        """Return a ``clustering_result`` for this job, running it first if needed.

        The result is also stored on ``self.result``.
        """
        if not self.hasrun:
            self.run()

        # clustering_result declares silhouette_samples as str (like outpath);
        # keep None when no silhouette score could be computed.
        silsampout = None if self.silsampout is None else str(self.silsampout)
        self.result = clustering_result(outpath=str(self.outpath.resolve()),
                                        silhouette_score=self.score,
                                        name=self.name,
                                        n_clusters=self.n_clusters,
                                        n_isolates=self.n_isolates,
                                        silhouette_samples=silsampout)
        return self.result

    def silhouette(self):
        """Compute the silhouette score of the clustering, ignoring isolates.

        The score uses only points whose label is not ``-1``. When a score is
        computed, per-point silhouette samples over *all* points (``-1`` counted
        as its own label) are written to
        ``<outpath>/silhouette_samples-<name>.feather``. The resolved path is
        stored in ``self.silsampout``.

        Returns:
            The silhouette score. Returns None when the score is undefined:
            there are no non-isolate points, or the number of distinct
            non-isolate labels is outside ``[2, n_points - 1]``. In that case
            ``self.silsampout`` is None.
        """
        labels = np.asarray(self.clustering.labels_)
        isolates = labels == -1
        scoremat = self.mat[~isolates][:, ~isolates]
        n_points = scoremat.shape[0]
        n_labels = len(np.unique(labels[~isolates]))

        # sklearn's silhouette functions raise ValueError unless
        # 2 <= n_labels <= n_samples - 1. If the non-isolate subset satisfies
        # this, the full matrix (with -1 as one extra label) does too.
        if not 2 <= n_labels <= n_points - 1:
            self.silsampout = None
            return None

        score = silhouette_score(scoremat, labels[~isolates], metric='precomputed')
        samples = silhouette_samples(self.mat, labels, metric='precomputed')
        samples = pd.DataFrame({'subreddit': self.subreddits, 'score': samples})
        self.outpath.mkdir(parents=True, exist_ok=True)
        silsampout = self.outpath / ("silhouette_samples-" + self.name + ".feather")
        self.silsampout = silsampout.resolve()
        samples.to_feather(self.silsampout)
        return score

    def read_distance_mat(self, similarities, use_threads=True):
        """Load a similarity matrix from feather and convert it to distances.

        Args:
            similarities: path to a feather file with a ``_subreddit`` column
                and the similarity columns.
            use_threads: forwarded to ``pd.read_feather``.

        Returns:
            ``(subreddits, distances)``, where ``subreddits`` is the
            ``_subreddit`` column and ``distances`` is ``1 - similarity`` with a
            zero diagonal.
        """
        df = pd.read_feather(similarities, use_threads=use_threads)
        return self._similarities_to_distances(df)

    @staticmethod
    def _similarities_to_distances(df):
        """Convert a similarity DataFrame with a ``_subreddit`` column into ``(subreddits, 1 - sim)``."""
        # pandas 2.0 made DataFrame.drop's axis keyword-only, so the old
        # positional form drop('_subreddit', 1) raises TypeError there.
        mat = np.array(df.drop(columns='_subreddit'))
        n = mat.shape[0]
        # force self-similarity to 1 so each subreddit's distance to itself is 0
        mat[range(n), range(n)] = 1
        return (df._subreddit, 1 - mat)

    def process_clustering(self, clustering, subreddits):
        """Print diagnostics for a fitted clustering and build the assignment table.

        Sets ``self.n_clusters`` to the number of distinct labels (``-1``
        included, if present). Sets ``self.n_isolates`` to the number of
        single-member clusters or, when there are none, the number of
        subreddits labelled ``-1``.

        Returns:
            DataFrame with columns ``subreddit`` and ``cluster``.
        """
        if hasattr(clustering, 'n_iter_'):
            print(f"clustering took {clustering.n_iter_} iterations")

        clusters = clustering.labels_
        self.n_clusters = len(set(clusters))

        print(f"found {self.n_clusters} clusters")

        cluster_data = pd.DataFrame({'subreddit': subreddits, 'cluster': clusters})

        cluster_sizes = cluster_data.groupby("cluster").count().reset_index()
        print(f"the largest cluster has {cluster_sizes.loc[cluster_sizes.cluster != -1].subreddit.max()} members")

        print(f"the median cluster has {cluster_sizes.subreddit.median()} members")
        n_isolates1 = int((cluster_sizes.subreddit == 1).sum())

        print(f"{n_isolates1} clusters have 1 member")

        # size of the -1 group as a plain int (0 when no subreddit is labelled
        # -1), so it fits clustering_result.n_isolates rather than being a
        # one-cell DataFrame
        n_isolates2 = int(cluster_sizes.loc[cluster_sizes.cluster == -1, 'subreddit'].sum())

        print(f"{n_isolates2} subreddits are in cluster -1", flush=True)

        self.n_isolates = n_isolates2 if n_isolates1 == 0 else n_isolates1

        return cluster_data
93
@dataclass
class clustering_result:
    """Summary of one finished clustering job, as built by ``clustering_job.get_info``."""

    # directory the job wrote its output files into
    outpath: Path
    # silhouette coefficient over non-isolate points; may be None when it could not be computed
    silhouette_score: float
    # base name of the job's output files
    name: str
    # number of distinct cluster labels found
    n_clusters: int
    # isolate count as computed by clustering_job.process_clustering
    n_isolates: int
    # location of the per-subreddit silhouette samples feather file, or None when no score was computed
    silhouette_samples: str

Community Data Science Collective || Want to submit a patch?