]> code.communitydata.science - cdsc_reddit.git/blob - clustering/clustering_base.py
refactor clustering in object-oriented style
[cdsc_reddit.git] / clustering / clustering_base.py
1 from pathlib import Path
2 import numpy as np
3 import pandas as pd
4 from dataclasses import dataclass
5 from sklearn.metrics import silhouette_score, silhouette_samples
6 from itertools import product, chain
7 from multiprocessing import Pool, cpu_count
8
def sim_to_dist(mat):
    """Convert a similarity matrix into a distance matrix.

    Distance is 1 - similarity, clamped at zero (similarities above 1
    would otherwise produce negative distances), with an exactly-zero
    diagonal so each point is at distance 0 from itself.
    """
    dist = 1 - mat
    np.fill_diagonal(dist, 0)
    return np.where(dist < 0, 0, dist)
14
class grid_sweep:
    """Run one job per combination of hyperparameter values.

    The cartesian product of the trailing ``*args`` iterables defines the
    grid; ``namer`` maps each combination to the job's output name, and
    ``jobtype`` is instantiated as ``jobtype(inpath, outpath, name, *combo)``.
    """

    def __init__(self, jobtype, inpath, outpath, namer, *args):
        self.jobtype = jobtype
        self.namer = namer
        self.hasrun = False
        in_path = Path(inpath)
        out_path = Path(outpath)
        combos = list(product(*args))
        self.grid = [(in_path, out_path, namer(*combo)) + combo for combo in combos]
        self.jobs = [jobtype(*params) for params in self.grid]

    def run(self, cores=20):
        """Execute every job (in a process pool when cores > 1) and collect
        their get_info() results into self.infos."""
        parallel = cores is not None and cores > 1
        if parallel:
            with Pool(cores) as pool:
                infos = pool.map(self.jobtype.get_info, self.jobs)
        else:
            infos = map(self.jobtype.get_info, self.jobs)

        self.infos = pd.DataFrame(infos)
        self.hasrun = True

    def save(self, outcsv):
        """Write the collected job infos to a csv, running the sweep first
        if it hasn't run yet."""
        if not self.hasrun:
            self.run()
        target = Path(outcsv)
        target.parent.mkdir(parents=True, exist_ok=True)
        self.infos.to_csv(target)
42
43
class lsi_grid_sweep(grid_sweep):
    """A grid_sweep across similarity matrices built at several LSI
    dimensionalities.

    ``inpath`` is a directory holding one ``<dim>.feather`` file per
    dimensionality. ``lsi_dimensions`` is either the string 'all' (use every
    file found) or an iterable of dimensionalities. One subsweep is built per
    dimensionality and their jobs are flattened into ``self.jobs``.
    """
    def __init__(self, jobtype, subsweep, inpath, lsi_dimensions, outpath, *args, **kwargs):
        self.jobtype = jobtype
        self.subsweep = subsweep
        inpath = Path(inpath)
        if lsi_dimensions == 'all':
            # every feather under inpath corresponds to one dimensionality
            lsi_paths = list(inpath.glob("*"))
        else:
            # bug fix: coerce each dimension to str — the original
            # `dim + '.feather'` raised TypeError for integer dimensions
            lsi_paths = [inpath / (str(dim) + '.feather') for dim in lsi_dimensions]

        # file stem (the dimensionality) becomes the subsweep's lsi_dim label
        lsi_nums = [p.stem for p in lsi_paths]
        self.hasrun = False
        self.subgrids = [self.subsweep(lsi_path, outpath, lsi_dim, *args, **kwargs)
                         for lsi_dim, lsi_path in zip(lsi_nums, lsi_paths)]
        # flatten each subsweep's job list into a single job list
        self.jobs = list(chain(*map(lambda gs: gs.jobs, self.subgrids)))
58
59
60 # this is meant to be an interface, not created directly
class clustering_job:
    """One clustering run: read a similarity feather, convert it to
    distances, cluster it with ``call``, score the result, and write the
    cluster assignments to ``outpath/name.feather``.

    Meant as a base class — subclasses supply ``call`` (the clustering
    entry point) and its arguments.
    """
    def __init__(self, infile, outpath, name, call, *args, **kwargs):
        self.outpath = Path(outpath)  # directory that receives output feathers
        self.call = call              # clustering callable applied to the distance matrix
        self.args = args
        self.kwargs = kwargs
        self.infile = Path(infile)    # feather file of pairwise similarities
        self.name = name              # basename for output files
        self.hasrun = False

    def run(self):
        """Cluster the distance matrix and persist the cluster assignments."""
        self.subreddits, self.mat = self.read_distance_mat(self.infile)
        self.clustering = self.call(self.mat, *self.args, **self.kwargs)
        self.cluster_data = self.process_clustering(self.clustering, self.subreddits)
        self.score = self.silhouette()
        self.outpath.mkdir(parents=True, exist_ok=True)
        self.cluster_data.to_feather(self.outpath / (self.name + ".feather"))
        self.hasrun = True

    def get_info(self):
        """Return a clustering_result summarizing this run, running it
        first if needed."""
        if not self.hasrun:
            self.run()

        self.result = clustering_result(outpath=str(self.outpath.resolve()),
                                        silhouette_score=self.score,
                                        name=self.name,
                                        n_clusters=self.n_clusters,
                                        n_isolates=self.n_isolates,
                                        silhouette_samples=str(self.silsampout.resolve())
                                        )
        return self.result

    def silhouette(self):
        """Compute the overall silhouette score and write per-subreddit
        silhouette samples to a feather file.

        Isolates (label -1) are excluded from the overall score, but the
        per-subreddit samples cover every row of the matrix.
        """
        isolates = self.clustering.labels_ == -1
        scoremat = self.mat[~isolates][:, ~isolates]
        score = silhouette_score(scoremat, self.clustering.labels_[~isolates], metric='precomputed')
        silhouette_samp = silhouette_samples(self.mat, self.clustering.labels_, metric='precomputed')
        silhouette_samp = pd.DataFrame({'subreddit': self.subreddits, 'score': silhouette_samp})
        self.outpath.mkdir(parents=True, exist_ok=True)
        self.silsampout = self.outpath / ("silhouette_samples-" + self.name + ".feather")
        silhouette_samp.to_feather(self.silsampout)
        return score

    def read_distance_mat(self, similarities, use_threads=True):
        """Read a similarity feather and return (subreddit names, distance matrix).

        The '_subreddit' column holds names; the remaining columns are the
        similarity matrix. The diagonal is forced to similarity 1 so that
        self-distance is exactly 0 after the 1 - mat conversion.
        """
        df = pd.read_feather(similarities, use_threads=use_threads)
        # bug fix: drop(columns=...) — the positional axis argument
        # (df.drop('_subreddit', 1)) was removed in pandas 2.0
        mat = np.array(df.drop(columns='_subreddit'))
        n = mat.shape[0]
        mat[range(n), range(n)] = 1
        return (df._subreddit, 1 - mat)

    def process_clustering(self, clustering, subreddits):
        """Summarize a fitted clustering: log cluster sizes, record
        self.n_clusters and self.n_isolates, and return a
        subreddit -> cluster DataFrame."""
        if hasattr(clustering, 'n_iter_'):
            print(f"clustering took {clustering.n_iter_} iterations")

        clusters = clustering.labels_
        self.n_clusters = len(set(clusters))

        print(f"found {self.n_clusters} clusters")

        cluster_data = pd.DataFrame({'subreddit': subreddits, 'cluster': clustering.labels_})

        cluster_sizes = cluster_data.groupby("cluster").count().reset_index()
        print(f"the largest cluster has {cluster_sizes.loc[cluster_sizes.cluster!=-1].subreddit.max()} members")

        print(f"the median cluster has {cluster_sizes.subreddit.median()} members")
        n_isolates1 = (cluster_sizes.subreddit == 1).sum()

        print(f"{n_isolates1} clusters have 1 member")

        # bug fix: take the scalar count of subreddits labeled -1; the
        # original stored a one-row DataFrame slice here, which then leaked
        # into the int-typed n_isolates field of clustering_result
        n_isolates2 = int(cluster_sizes.loc[cluster_sizes.cluster == -1, 'subreddit'].sum())

        print(f"{n_isolates2} subreddits are in cluster -1", flush=True)

        if n_isolates1 == 0:
            self.n_isolates = n_isolates2
        else:
            self.n_isolates = n_isolates1

        return cluster_data
141
142
class lsi_mixin():
    """Mixin recording the LSI dimensionality a job was run against."""

    def set_lsi_dims(self, lsi_dims):
        # stash the dimensionality so result objects can report it later
        self.lsi_dims = lsi_dims
146
@dataclass
class clustering_result:
    """Summary record for one clustering run, built by clustering_job.get_info()."""
    outpath:Path              # directory the cluster assignments were written to
    silhouette_score:float    # overall silhouette score (isolates excluded)
    name:str                  # job name / output file basename
    n_clusters:int            # number of distinct cluster labels (including -1 if present)
    n_isolates:int            # singleton-cluster count, or count of items labeled -1
    silhouette_samples:str    # path to the per-subreddit silhouette samples feather
155
@dataclass
class lsi_result_mixin:
    """Extra result field for jobs run against an LSI-reduced similarity matrix."""
    lsi_dimensions:int  # LSI dimensionality of the input similarity matrix

Community Data Science Collective || Want to submit a patch?