code.communitydata.science - cdsc_reddit.git/blob - clustering/hdbscan_clustering.py
update clustering scripts
[cdsc_reddit.git] / clustering / hdbscan_clustering.py
1 from clustering_base import clustering_result, clustering_job
2 from grid_sweep import grid_sweep
3 from dataclasses import dataclass
4 import hdbscan
5 from sklearn.neighbors import NearestNeighbors
6 import plotnine as pn
7 import numpy as np
8 from itertools import product, starmap, chain
9 import pandas as pd
10 from multiprocessing import cpu_count
11 import fire
12
def test_select_hdbscan_clustering():
    """Ad-hoc run of an HDBSCAN grid sweep over LSI-reduced author similarities.

    Sweeps every combination of the hard-coded hyperparameter lists below,
    runs the sweep with ``gs.run(20)`` and saves the sweep table to
    ``test_hdbscan/lsi_sweep.csv``. Input paths are hard-coded to the
    /gscratch filesystem.

    NOTE(review): ``hdbscan_lsi_grid_sweep`` is neither defined nor imported in
    this module, so calling this function raises NameError unless that name is
    brought into scope from elsewhere — confirm where it lives.
    """
    # select_hdbscan_clustering("/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_30k_LSI",
    #                           "test_hdbscan_author30k",
    #                           min_cluster_sizes=[2],
    #                           min_samples=[1,2],
    #                           cluster_selection_epsilons=[0,0.05,0.1,0.15],
    #                           cluster_selection_methods=['eom','leaf'],
    #                           lsi_dimensions='all')
    inpath = "/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_10k_LSI/"
    outpath = "test_hdbscan"
    min_cluster_sizes = [2, 3, 4]
    min_samples = [1, 2, 3]
    cluster_selection_epsilons = [0, 0.1, 0.3, 0.5]
    cluster_selection_methods = ['eom']
    lsi_dimensions = 'all'
    # Pass the named variable rather than repeating the literal so the two
    # cannot drift apart.
    gs = hdbscan_lsi_grid_sweep(inpath,
                                lsi_dimensions,
                                outpath,
                                min_cluster_sizes,
                                min_samples,
                                cluster_selection_epsilons,
                                cluster_selection_methods)
    gs.run(20)
    gs.save("test_hdbscan/lsi_sweep.csv")
    # job1 = hdbscan_lsi_job(infile=inpath, outpath=outpath, name="test", lsi_dims=500, min_cluster_size=2, min_samples=1,cluster_selection_epsilon=0,cluster_selection_method='eom')
    # job1.run()
    # print(job1.get_info())

    # df = pd.read_csv("test_hdbscan/selection_data.csv")
    # test_select_hdbscan_clustering()
    # check_clusters = pd.read_feather("test_hdbscan/500_2_2_0.1_eom.feather")
    # silscores = pd.read_feather("test_hdbscan/silhouette_samples500_2_2_0.1_eom.feather")
    # c = check_clusters.merge(silscores,on='subreddit')
class hdbscan_grid_sweep(grid_sweep):
    """Grid sweep whose individual runs are ``hdbscan_job`` instances.

    ``inpath`` and ``outpath`` are forwarded to ``grid_sweep.__init__``, which
    also receives ``hdbscan_job`` as the job class and :meth:`namer` as the
    run-naming function. All remaining positional and keyword arguments are
    forwarded unchanged.
    """

    def __init__(self, inpath, outpath, *args, **kwargs):
        super().__init__(hdbscan_job, inpath, outpath, self.namer, *args, **kwargs)

    def namer(self, min_cluster_size, min_samples, cluster_selection_epsilon, cluster_selection_method):
        """Return a run name encoding the four HDBSCAN hyperparameters.

        For example, (2, 1, 0.1, 'eom') gives ``mcs-2_ms-1_cse-0.1_csm-eom``.
        """
        labelled = (
            ("mcs", min_cluster_size),
            ("ms", min_samples),
            ("cse", cluster_selection_epsilon),
            ("csm", cluster_selection_method),
        )
        return "_".join(f"{tag}-{setting}" for tag, setting in labelled)
55
@dataclass
class hdbscan_clustering_result(clustering_result):
    """Clustering result extended with the HDBSCAN hyperparameters of the run.

    These fields follow those inherited from ``clustering_result`` in the
    generated ``__init__``.
    """

    min_cluster_size: int  # HDBSCAN min_cluster_size used for the run
    min_samples: int  # HDBSCAN min_samples used for the run
    cluster_selection_epsilon: float  # HDBSCAN cluster_selection_epsilon used for the run
    cluster_selection_method: str  # HDBSCAN cluster selection method, e.g. 'eom' or 'leaf'
62
class hdbscan_job(clustering_job):
    """One HDBSCAN clustering run over the matrix handled by ``clustering_job``.

    The four hyperparameters are passed to ``clustering_job.__init__`` as
    keyword arguments, together with ``call=hdbscan_job._hdbscan_clustering``.
    Presumably the base class later invokes ``call(mat, **hyperparameters)`` —
    confirm in clustering_base. The hyperparameters are also stored on the
    instance so that :meth:`get_info` can record them in the result.
    """

    def __init__(self, infile, outpath, name, min_cluster_size=2, min_samples=1, cluster_selection_epsilon=0, cluster_selection_method='eom'):
        super().__init__(infile,
                         outpath,
                         name,
                         call=hdbscan_job._hdbscan_clustering,
                         min_cluster_size=min_cluster_size,
                         min_samples=min_samples,
                         cluster_selection_epsilon=cluster_selection_epsilon,
                         cluster_selection_method=cluster_selection_method
                         )

        self.min_cluster_size = min_cluster_size
        self.min_samples = min_samples
        self.cluster_selection_epsilon = cluster_selection_epsilon
        self.cluster_selection_method = cluster_selection_method
        # NOTE(review): a disabled line here once converted similarities to
        # distances (``self.mat = 1 - self.mat``). _hdbscan_clustering uses
        # metric='precomputed', which treats the matrix as distances — confirm
        # that clustering_job already supplies a distance matrix.

    # Declared static because it takes no ``self``. Without the decorator, a
    # call through an instance would bind the instance to ``mat``.
    # ``hdbscan_job._hdbscan_clustering`` still yields the plain function.
    @staticmethod
    def _hdbscan_clustering(mat, *args, **kwargs):
        """Fit HDBSCAN on a precomputed matrix and return the fitted estimator.

        Args:
            mat: array-like with an ``astype`` method, treated by HDBSCAN as a
                precomputed pairwise-distance matrix. It is cast to float64
                before fitting.
            *args, **kwargs: forwarded to ``hdbscan.HDBSCAN`` (e.g.
                min_cluster_size, min_samples, cluster_selection_epsilon,
                cluster_selection_method).

        Returns:
            The value returned by ``HDBSCAN.fit``, which is the fitted estimator.
        """
        print(f"running hdbscan clustering. args:{args}. kwargs:{kwargs}")
        clusterer = hdbscan.HDBSCAN(*args,
                                    metric='precomputed',
                                    # Parallelise core-distance computation over all CPUs.
                                    core_dist_n_jobs=cpu_count(),
                                    **kwargs)
        return clusterer.fit(mat.astype('double'))

    def get_info(self):
        """Return the base result extended with this job's HDBSCAN hyperparameters.

        The ``hdbscan_clustering_result`` is also cached on ``self.result``.
        """
        result = super().get_info()
        self.result = hdbscan_clustering_result(**result.__dict__,
                                                min_cluster_size=self.min_cluster_size,
                                                min_samples=self.min_samples,
                                                cluster_selection_epsilon=self.cluster_selection_epsilon,
                                                cluster_selection_method=self.cluster_selection_method)
        return self.result
102
def _as_param_list(values, convert):
    """Return ``values`` as a list with ``convert`` applied to each element.

    A single scalar or string is wrapped in a one-element list rather than
    iterated. A lone ``'eom'`` therefore yields ``['eom']``, not
    ``['e', 'o', 'm']``, and a lone ``2`` yields ``[2]`` instead of raising
    TypeError.
    """
    if isinstance(values, (str, bytes)) or not hasattr(values, '__iter__'):
        values = [values]
    return [convert(value) for value in values]


def run_hdbscan_grid_sweep(savefile, inpath, outpath, min_cluster_sizes=(2,), min_samples=(1,), cluster_selection_epsilons=(0,), cluster_selection_methods=('eom',)):
    """Run hdbscan clustering once or more with different parameters (a grid sweep).

    Usage:
    hdbscan_clustering.py --savefile=SAVEFILE --inpath=INPATH --outpath=OUTPATH --min_cluster_sizes=<csv> --min_samples=<csv> --cluster_selection_epsilons=<csv> --cluster_selection_methods=<csv "eom"|"leaf">

    Keyword arguments:
    savefile: path to save the metadata and diagnostics.
    inpath: path to feather data containing a labeled matrix of subreddit similarities.
    outpath: path to output the fitted hdbscan clusterings.
    min_cluster_sizes: one or more integers indicating the minimum cluster size.
    min_samples: one or more integers indicating the minimum number of samples used in the algorithm.
    cluster_selection_epsilons: one or more distance thresholds passed as HDBSCAN's cluster_selection_epsilon (clusters closer than this are merged).
    cluster_selection_methods: one or more of "eom" or "leaf"; "eom" gives larger clusters.
    """
    # Lists (not one-shot map iterators) so grid_sweep may iterate them more than once.
    obj = hdbscan_grid_sweep(inpath,
                             outpath,
                             _as_param_list(min_cluster_sizes, int),
                             _as_param_list(min_samples, int),
                             _as_param_list(cluster_selection_epsilons, float),
                             # Selection methods are strings ('eom'/'leaf'); the previous
                             # float() conversion raised ValueError on every valid value.
                             _as_param_list(cluster_selection_methods, str))
    obj.run()
    obj.save(savefile)
126
def KNN_distances_plot(mat, outname, k=2):
    """Plot every point's distance to its k-th nearest neighbour, largest first.

    Args:
        mat: pairwise-distance matrix. NearestNeighbors is fit with
            metric='precomputed' and then queried with the same matrix.
        outname: path of the image file written by plotnine.
        k: number of neighbours requested per point. Because ``mat`` is
            queried against itself, each point's first neighbour is
            presumably the point itself (assuming a zero diagonal), so k=2
            gives the nearest *other* point — confirm for your matrices.

    Note:
        The y-axis breaks are fixed on [0, 1) (major every 0.1, minor every
        0.02), so distances are assumed to lie roughly in [0, 1] — TODO
        confirm.
    """
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='auto', metric='precomputed').fit(mat)
    distances, _ = nbrs.kneighbors(mat)
    # Last column = distance to the k-th neighbour of each point.
    df = pd.DataFrame({'dist': distances[:, -1]})
    df = df.sort_values("dist", ascending=False)
    # Assigned from an ndarray, so this is positional: rank 1 = largest distance.
    df['idx'] = np.arange(1, len(df) + 1)
    # Explicit ggplot/aes/geom_line instead of the qplot convenience wrapper,
    # which is not available in every plotnine release.
    p = (pn.ggplot(df, pn.aes(x='idx', y='dist'))
         + pn.geom_line()
         + pn.scale_y_continuous(minor_breaks=np.arange(0, 50) / 50,
                                 breaks=np.arange(0, 10) / 10))
    p.save(outname, width=16, height=10)
137     
def make_KNN_plots(k=2):
    """Write k-th-nearest-neighbour distance plots for three similarity matrices.

    For each hard-coded similarity file, the matrix is loaded, converted with
    ``sim_to_dist`` and passed to ``KNN_distances_plot``. The plot is written
    to ``<label>_knn_dist<k>.png`` relative to the current working directory.

    NOTE(review): ``read_similarity_mat`` and ``sim_to_dist`` are neither
    defined nor imported in this module. They must be brought into scope
    (presumably from the project's clustering helpers — confirm), or this
    function raises NameError.

    Args:
        k: neighbour count passed to KNN_distances_plot. The default of 2
            reproduces the original ``*_knn_dist2.png`` outputs.
    """
    similarity_dir = "/gscratch/comdata/output/reddit_similarity"
    inputs = (
        ("subreddit_comment_terms_10k.feather", "terms"),
        ("subreddit_comment_authors_10k.feather", "authors"),
        ("subreddit_comment_authors-tf_10k.feather", "authors-tf"),
    )
    for filename, label in inputs:
        _, mat = read_similarity_mat(f"{similarity_dir}/{filename}")
        mat = sim_to_dist(mat)
        KNN_distances_plot(mat, k=k, outname=f"{label}_knn_dist{k}.png")
154
if __name__ == "__main__":
    # Command-line entry point: fire maps the CLI flags shown in the
    # run_hdbscan_grid_sweep docstring onto that function's parameters.
    fire.Fire(run_hdbscan_grid_sweep)

Community Data Science Collective || Want to submit a patch?