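"""Grid search over HDBSCAN clusterings of subreddit similarity matrices.

For each combination of LSI dimensionality, min_cluster_size, min_samples,
cluster_selection_epsilon, and cluster_selection_method, this script fits
HDBSCAN on the precomputed distance matrix, writes the cluster assignments
and per-subreddit silhouette scores to feather files, and collects summary
statistics for all runs in selection_data.csv.
"""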
from clustering_base import sim_to_dist, process_clustering_result, clustering_result, read_similarity_mat
from dataclasses import dataclass
import hdbscan
from sklearn.neighbors import NearestNeighbors
import plotnine as pn
import numpy as np
from itertools import product, starmap
import pandas as pd
from sklearn.metrics import silhouette_score, silhouette_samples
from pathlib import Path
from multiprocessing import Pool, cpu_count
import fire
from pyarrow.feather import write_feather

def test_select_hdbscan_clustering():
    select_hdbscan_clustering("/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_30k_LSI",
                              "test_hdbscan_author30k",
                              min_cluster_sizes=[2],
                              min_samples=[1,2],
                              cluster_selection_epsilons=[0,0.05,0.1,0.15],
                              cluster_selection_methods=['eom','leaf'],
                              lsi_dimensions='all')
    # unused alternative parameter values, presumably kept for interactive testing
    inpath = "/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_30k_LSI"
    outpath = "test_hdbscan"
    min_cluster_sizes = [2,3,4]
    min_samples = [1,2,3]
    cluster_selection_epsilons = [0,0.1,0.3,0.5]
    cluster_selection_methods = ['eom']
    lsi_dimensions = 'all'

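# extends the generic clustering_result from clustering_base with the
# HDBSCAN-specific parameters and outputs recorded for each run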
@dataclass
class hdbscan_clustering_result(clustering_result):
    min_cluster_size:int
    min_samples:int
    cluster_selection_epsilon:float
    cluster_selection_method:str
    lsi_dimensions:int
    n_isolates:int
    silhouette_samples:str

def select_hdbscan_clustering(inpath,
                              outpath,
                              outfile=None,
                              min_cluster_sizes=[2],
                              min_samples=[1],
                              cluster_selection_epsilons=[0],
                              cluster_selection_methods=['eom'],
                              lsi_dimensions='all'
                              ):

    inpath = Path(inpath)
    outpath = Path(outpath)
    outpath.mkdir(exist_ok=True, parents=True)

    # one similarity matrix per LSI dimensionality
    if lsi_dimensions == 'all':
        lsi_paths = list(inpath.glob("*"))
    else:
        lsi_paths = [inpath / (str(dim) + '.feather') for dim in lsi_dimensions]

    lsi_nums = [p.stem for p in lsi_paths]
    grid = list(product(lsi_nums,
                        min_cluster_sizes,
                        min_samples,
                        cluster_selection_epsilons,
                        cluster_selection_methods))

    # derive the output file names from the parameter combinations
    names = list(map(lambda t:'_'.join(map(str,t)),grid))

    grid = [(inpath/(str(t[0])+'.feather'),outpath/(name + '.feather'), t[0], name) + t[1:] for t, name in zip(grid, names)]

    # run each parameter combination in its own worker process
    with Pool(int(cpu_count()/4)) as pool:
        mods = pool.starmap(hdbscan_clustering, grid)

    res = pd.DataFrame(mods)
    if outfile is None:
        outfile = outpath / "selection_data.csv"

    res.to_csv(outfile)

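# Example call (the output directory name here is hypothetical):
# select_hdbscan_clustering("/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_30k_LSI",
#                           "authors-tf_30k_LSI_hdbscan",
#                           min_cluster_sizes=[2,3],
#                           min_samples=[1,2],
#                           cluster_selection_epsilons=[0,0.1],
#                           cluster_selection_methods=['eom'],
#                           lsi_dimensions='all')
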
def hdbscan_clustering(similarities, output, lsi_dim, name, min_cluster_size=2, min_samples=1, cluster_selection_epsilon=0, cluster_selection_method='eom'):
    subreddits, mat = read_similarity_mat(similarities)
    mat = sim_to_dist(mat)
    clustering = _hdbscan_clustering(mat,
                                     min_cluster_size=min_cluster_size,
                                     min_samples=min_samples,
                                     cluster_selection_epsilon=cluster_selection_epsilon,
                                     cluster_selection_method=cluster_selection_method,
                                     metric='precomputed',
                                     core_dist_n_jobs=cpu_count()
                                     )

    cluster_data = process_clustering_result(clustering, subreddits)

    # HDBSCAN labels noise points -1; score only the non-isolate subreddits
    isolates = clustering.labels_ == -1
    scoremat = mat[~isolates][:,~isolates]
    score = silhouette_score(scoremat, clustering.labels_[~isolates], metric='precomputed')
    cluster_data.to_feather(output)

    # per-subreddit silhouette scores, written alongside the cluster assignments
    silhouette_samp = silhouette_samples(mat, clustering.labels_, metric='precomputed')
    silhouette_samp = pd.DataFrame({'subreddit':subreddits,'score':silhouette_samp})
    silsampout = output.parent / ("silhouette_samples" + output.name)
    silhouette_samp.to_feather(silsampout)

    result = hdbscan_clustering_result(outpath=output,
                                       max_iter=None,
                                       silhouette_samples=silsampout,
                                       silhouette_score=score,
                                       alt_silhouette_score=score,
                                       name=name,
                                       min_cluster_size=min_cluster_size,
                                       min_samples=min_samples,
                                       cluster_selection_epsilon=cluster_selection_epsilon,
                                       cluster_selection_method=cluster_selection_method,
                                       lsi_dimensions=lsi_dim,
                                       n_isolates=isolates.sum(),
                                       n_clusters=len(set(clustering.labels_))
                                       )

    return(result)

# for all runs we should try cluster_selection_epsilon = None
# for terms we should try cluster_selection_epsilon around 0.56-0.66
# for authors we should try cluster_selection_epsilon around 0.98-0.99
def _hdbscan_clustering(mat, *args, **kwargs):
    print(f"running hdbscan clustering. args:{args}. kwargs:{kwargs}")

    print(mat)
    clusterer = hdbscan.HDBSCAN(*args,
                                **kwargs,
                                )

    clustering = clusterer.fit(mat.astype('double'))

    return(clustering)

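# The k-nearest-neighbor distance plots below appear to be aids for choosing
# cluster_selection_epsilon: each subreddit's distance to its kth neighbor is
# sorted in descending order so that an elbow in the curve can be read off.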
def KNN_distances_plot(mat,outname,k=2):
    nbrs = NearestNeighbors(n_neighbors=k,algorithm='auto',metric='precomputed').fit(mat)
    distances, indices = nbrs.kneighbors(mat)
    d2 = distances[:,-1]
    df = pd.DataFrame({'dist':d2})
    df = df.sort_values("dist",ascending=False)
    df['idx'] = np.arange(0,d2.shape[0]) + 1
    p = pn.qplot(x='idx',y='dist',data=df,geom='line') + pn.scales.scale_y_continuous(minor_breaks = np.arange(0,50)/50,
                                                                                      breaks = np.arange(0,10)/10)
    p.save(outname,width=16,height=10)

def make_KNN_plots():
    similarities = "/gscratch/comdata/output/reddit_similarity/subreddit_comment_terms_10k.feather"
    subreddits, mat = read_similarity_mat(similarities)
    mat = sim_to_dist(mat)

    KNN_distances_plot(mat,k=2,outname='terms_knn_dist2.png')

    similarities = "/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors_10k.feather"
    subreddits, mat = read_similarity_mat(similarities)
    mat = sim_to_dist(mat)
    KNN_distances_plot(mat,k=2,outname='authors_knn_dist2.png')

    similarities = "/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_10k.feather"
    subreddits, mat = read_similarity_mat(similarities)
    mat = sim_to_dist(mat)
    KNN_distances_plot(mat,k=2,outname='authors-tf_knn_dist2.png')

if __name__ == "__main__":
    df = pd.read_csv("test_hdbscan/selection_data.csv")
    test_select_hdbscan_clustering()
    check_clusters = pd.read_feather("test_hdbscan/500_2_2_0.1_eom.feather")
    silscores = pd.read_feather("test_hdbscan/silhouette_samples500_2_2_0.1_eom.feather")
    c = check_clusters.merge(silscores,on='subreddit')
    # fire.Fire(select_hdbscan_clustering)
