from pathlib import Path
import numpy as np
import pandas as pd
from dataclasses import dataclass
from sklearn.metrics import silhouette_score, silhouette_samples
from itertools import product, chain
from multiprocessing import Pool, cpu_count

def sim_to_dist(mat):
    # convert a similarity matrix to a distance matrix with a zero diagonal
    dist = 1 - mat
    dist[dist < 0] = 0
    np.fill_diagonal(dist, 0)
    return dist
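
# illustrative example (not part of the pipeline): sim_to_dist turns a
# similarity matrix with a unit diagonal into a distance matrix.
#
#   >>> sim_to_dist(np.array([[1.0, 0.8], [0.8, 1.0]]))
#   array([[0. , 0.2],
#          [0.2, 0. ]])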

class grid_sweep:
    # run a jobtype over the cartesian product of the parameter values in *args
    def __init__(self, jobtype, inpath, outpath, namer, *args):
        self.jobtype = jobtype
        grid = list(product(*args))
        inpath = Path(inpath)
        outpath = Path(outpath)
        self.hasrun = False
        # one job per parameter combination; namer(*g) builds a unique name for each
        self.grid = [(inpath, outpath, namer(*g)) + g for g in grid]
        self.jobs = [jobtype(*g) for g in self.grid]

    def run(self, cores=20):
        if cores is not None and cores > 1:
            # get_info runs each job lazily if it hasn't run yet
            with Pool(cores) as pool:
                infos = pool.map(self.jobtype.get_info, self.jobs)
        else:
            infos = map(self.jobtype.get_info, self.jobs)

        self.infos = pd.DataFrame(infos)
        self.hasrun = True

    def save(self, outcsv):
        if not self.hasrun:
            self.run()
        outcsv = Path(outcsv)
        outcsv.parent.mkdir(parents=True, exist_ok=True)
        self.infos.to_csv(outcsv)
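
# usage sketch (hypothetical; my_job, the paths, and the parameter values are
# placeholders): sweep two parameter lists with a jobtype that follows the
# clustering_job interface defined below.
#
#   sweep = grid_sweep(my_job,
#                      "similarities/terms.feather",
#                      "clustering/terms",
#                      lambda eps, min_samples: f"eps-{eps}_ms-{min_samples}",
#                      [0.3, 0.5], [2, 5])
#   sweep.run(cores=4)
#   sweep.save("clustering/terms.csv")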

class lsi_grid_sweep(grid_sweep):
    # a grid sweep with one subsweep per LSI dimensionality
    def __init__(self, jobtype, subsweep, inpath, lsi_dimensions, outpath, *args, **kwargs):
        self.jobtype = jobtype
        self.subsweep = subsweep
        inpath = Path(inpath)
        if lsi_dimensions == 'all':
            lsi_paths = list(inpath.glob("*"))
        else:
            lsi_paths = [inpath / (str(dim) + '.feather') for dim in lsi_dimensions]

        # the filename stem encodes the number of LSI dimensions
        lsi_nums = [p.stem for p in lsi_paths]
        self.hasrun = False
        self.subgrids = [self.subsweep(lsi_path, outpath, lsi_dim, *args, **kwargs)
                         for lsi_dim, lsi_path in zip(lsi_nums, lsi_paths)]
        self.jobs = list(chain(*(gs.jobs for gs in self.subgrids)))
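
# for example (hypothetical layout): if inpath contains 100.feather and
# 300.feather, then lsi_dimensions=['100', '300'] (or 'all') yields one
# subsweep per file and lsi_nums == ['100', '300'].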

# this is meant to be an interface, not created directly
class clustering_job:
    def __init__(self, infile, outpath, name, call, *args, **kwargs):
        self.outpath = Path(outpath)
        self.call = call
        self.args = args
        self.kwargs = kwargs
        self.infile = Path(infile)
        self.name = name
        self.hasrun = False

    def run(self):
        self.subreddits, self.mat = self.read_distance_mat(self.infile)
        # self.call is the clustering function; it must return an object
        # with a labels_ attribute (e.g. a fitted sklearn estimator)
        self.clustering = self.call(self.mat, *self.args, **self.kwargs)
        self.cluster_data = self.process_clustering(self.clustering, self.subreddits)
        self.score = self.silhouette()
        self.outpath.mkdir(parents=True, exist_ok=True)
        self.cluster_data.to_feather(self.outpath / (self.name + ".feather"))
        self.hasrun = True

    def get_info(self):
        if not self.hasrun:
            self.run()

        self.result = clustering_result(outpath=str(self.outpath.resolve()),
                                        silhouette_score=self.score,
                                        name=self.name,
                                        n_clusters=self.n_clusters,
                                        n_isolates=self.n_isolates,
                                        silhouette_samples=str(self.silsampout.resolve()))
        return self.result

    def silhouette(self):
        # treat cluster -1 as noise/isolates and exclude it from the overall score
        isolates = self.clustering.labels_ == -1
        scoremat = self.mat[~isolates][:, ~isolates]
        score = silhouette_score(scoremat, self.clustering.labels_[~isolates], metric='precomputed')
        # per-subreddit scores are computed on the full matrix, isolates included
        silhouette_samp = silhouette_samples(self.mat, self.clustering.labels_, metric='precomputed')
        silhouette_samp = pd.DataFrame({'subreddit': self.subreddits, 'score': silhouette_samp})
        self.outpath.mkdir(parents=True, exist_ok=True)
        self.silsampout = self.outpath / ("silhouette_samples-" + self.name + ".feather")
        silhouette_samp.to_feather(self.silsampout)
        return score

    def read_distance_mat(self, similarities, use_threads=True):
        df = pd.read_feather(similarities, use_threads=use_threads)
        mat = np.array(df.drop(columns='_subreddit'))
        n = mat.shape[0]
        # set the diagonal to 1 so that 1 - mat has zero self-distance
        mat[range(n), range(n)] = 1
        return (df._subreddit, 1 - mat)
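
    # the feather file is assumed to hold one '_subreddit' name column plus a
    # square block of pairwise similarities in [0, 1], e.g. (illustrative):
    #
    #   _subreddit   askreddit   science
    #   askreddit         1.00      0.42
    #   science           0.42      1.00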

    def process_clustering(self, clustering, subreddits):
        if hasattr(clustering, 'n_iter_'):
            print(f"clustering took {clustering.n_iter_} iterations")

        clusters = clustering.labels_
        self.n_clusters = len(set(clusters))
        print(f"found {self.n_clusters} clusters")

        cluster_data = pd.DataFrame({'subreddit': subreddits, 'cluster': clustering.labels_})

        cluster_sizes = cluster_data.groupby("cluster").count().reset_index()
        print(f"the largest cluster has {cluster_sizes.loc[cluster_sizes.cluster != -1].subreddit.max()} members")
        print(f"the median cluster has {cluster_sizes.subreddit.median()} members")

        # singleton clusters
        n_isolates1 = (cluster_sizes.subreddit == 1).sum()
        print(f"{n_isolates1} clusters have 1 member")

        # subreddits the clustering algorithm labeled as noise (cluster -1)
        n_isolates2 = cluster_sizes.loc[cluster_sizes.cluster == -1, 'subreddit'].sum()
        print(f"{n_isolates2} subreddits are in cluster -1", flush=True)

        # prefer the explicit noise count when the algorithm produces one
        if n_isolates2 > 0:
            self.n_isolates = n_isolates2
        else:
            self.n_isolates = n_isolates1

        return cluster_data
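
    # illustrative: for labels [-1, -1, 0, 0, 1], cluster_sizes works out to
    #
    #   cluster  subreddit
    #        -1          2
    #         0          2
    #         1          1
    #
    # so n_isolates1 == 1 (one singleton cluster) and n_isolates2 == 2.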

class lsi_mixin:
    # mixin for jobs that run on LSI-reduced similarity matrices
    def set_lsi_dims(self, lsi_dims):
        self.lsi_dims = lsi_dims

@dataclass
class clustering_result:
    outpath: Path
    silhouette_score: float
    name: str
    n_clusters: int
    n_isolates: int
    silhouette_samples: str

@dataclass
class lsi_result_mixin:
    lsi_dimensions: int
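
# end-to-end sketch (hypothetical; the paths and parameter values are
# placeholders): any callable that takes a precomputed distance matrix and
# returns an object with a labels_ attribute satisfies the clustering_job
# interface. sklearn's DBSCAN with metric='precomputed' is one such callable,
# and it marks noise points with label -1, matching the isolate handling above.
#
#   from sklearn.cluster import DBSCAN
#
#   def dbscan_call(mat, eps, min_samples):
#       return DBSCAN(eps=eps, min_samples=min_samples, metric='precomputed').fit(mat)
#
#   job = clustering_job("similarities/terms.feather", "clustering/terms",
#                        "eps-0.5_ms-2", dbscan_call, 0.5, 2)
#   info = job.get_info()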