]> code.communitydata.science - cdsc_reddit.git/blob - clustering/kmeans_clustering.py
refactor clustering in object-oriented style
[cdsc_reddit.git] / clustering / kmeans_clustering.py
1 from sklearn.cluster import KMeans
2 import fire
3 from pathlib import Path
4 from multiprocessing import cpu_count
5 from dataclasses import dataclass
6 from clustering_base import sim_to_dist, process_clustering_result, clustering_result, read_similarity_mat
7 from clustering_base import lsi_result_mixin, lsi_mixin, clustering_job, grid_sweep, lsi_grid_sweep
8
9
@dataclass
class kmeans_clustering_result(clustering_result):
    """A ``clustering_result`` extended with k-means settings.

    The three extra fields share their names with the keyword arguments that
    ``kmeans_job`` hands to ``KMeans``.
    """

    n_clusters: int
    n_init: int
    max_iter: int
15
@dataclass
class kmeans_clustering_result_lsi(kmeans_clustering_result, lsi_result_mixin):
    """K-means result record combined with the fields of ``lsi_result_mixin``.

    Declares no fields of its own; everything comes from the two bases.
    """
19
class kmeans_job(clustering_job):
    """Clustering job whose clustering call is scikit-learn's ``KMeans``.

    The k-means hyperparameters are forwarded to the ``clustering_job``
    constructor as keyword arguments together with ``call``; the ones that
    are reported by ``get_info`` are also stored on the instance.
    """

    def __init__(self, infile, outpath, name, n_clusters, n_init=10, max_iter=100000, random_state=1968, verbose=True):
        """Configure the job.

        Args:
            infile: passed through to ``clustering_job``.
            outpath: passed through to ``clustering_job``.
            name: passed through to ``clustering_job``.
            n_clusters: forwarded as the ``KMeans`` ``n_clusters`` argument.
            n_init: forwarded as the ``KMeans`` ``n_init`` argument.
            max_iter: forwarded as the ``KMeans`` ``max_iter`` argument.
            random_state: forwarded as the ``KMeans`` ``random_state`` argument.
            verbose: forwarded as the ``KMeans`` ``verbose`` argument.
        """
        super().__init__(infile,
                         outpath,
                         name,
                         call=kmeans_job._kmeans_clustering,
                         n_clusters=n_clusters,
                         n_init=n_init,
                         max_iter=max_iter,
                         random_state=random_state,
                         verbose=verbose)

        self.n_clusters = n_clusters
        self.n_init = n_init
        self.max_iter = max_iter

    @staticmethod
    def _kmeans_clustering(mat, *args, **kwargs):
        """Fit ``KMeans(*args, **kwargs)`` on ``mat`` and return the fitted estimator.

        Declared static because it takes no ``self``; without the decorator,
        calling it through an instance would bind the instance to ``mat``.
        """
        return KMeans(*args, **kwargs).fit(mat)

    def get_info(self):
        """Build, store on ``self.result``, and return a ``kmeans_clustering_result``.

        The record starts from the attributes of the parent's ``get_info()``
        result and adds this job's k-means settings.
        """
        result = super().get_info()
        fields = dict(result.__dict__)
        # kmeans_clustering_result declares an n_clusters field. Whether the
        # parent result already carries one is not visible from this module,
        # so only fill it in when it is absent.
        fields.setdefault("n_clusters", self.n_clusters)
        # The settings live on the instance; the bare names n_init / max_iter
        # are not defined in this scope and would raise NameError.
        self.result = kmeans_clustering_result(**fields,
                                               n_init=self.n_init,
                                               max_iter=self.max_iter)
        return self.result
51
52
class kmeans_lsi_job(kmeans_job, lsi_mixin):
    """``kmeans_job`` variant that also records an LSI dimensionality.

    ``lsi_dims`` is handed to ``set_lsi_dims`` (presumably provided by
    ``lsi_mixin``) and reported by ``get_info`` as ``lsi_dimensions``.
    """

    def __init__(self, infile, outpath, name, lsi_dims, *args, **kwargs):
        """Initialise the k-means job, then register ``lsi_dims``.

        Remaining positional and keyword arguments go to ``kmeans_job``.
        """
        super().__init__(infile, outpath, name, *args, **kwargs)
        super().set_lsi_dims(lsi_dims)

    def get_info(self):
        """Return the parent's result re-wrapped with ``lsi_dimensions`` added.

        The new record is also stored on ``self.result``.
        """
        base_info = super().get_info()
        self.result = kmeans_clustering_result_lsi(
            lsi_dimensions=self.lsi_dims,
            **vars(base_info),
        )
        return self.result
67     
68
class kmeans_grid_sweep(grid_sweep):
    """``grid_sweep`` specialised to ``kmeans_job`` with a k-means run namer."""

    def __init__(self, inpath, outpath, *args, **kwargs):
        """Forward to ``grid_sweep`` with ``kmeans_job`` and ``self.namer`` filled in."""
        super().__init__(kmeans_job, inpath, outpath, self.namer, *args, **kwargs)

    def namer(self, n_clusters, n_init, max_iter):
        """Return the run name encoding the three k-means settings."""
        pieces = (
            f"nclusters-{n_clusters}",
            f"nit-{n_init}",
            f"maxit-{max_iter}",
        )
        return "_".join(pieces)
82
class _kmeans_lsi_grid_sweep(grid_sweep):
    """``grid_sweep`` over ``kmeans_lsi_job`` runs for one fixed ``lsi_dim``."""

    def __init__(self, inpath, outpath, lsi_dim, *args, **kwargs):
        """Record ``lsi_dim`` and the job type, then defer to ``grid_sweep``.

        ``lsi_dim`` is passed on as the first positional argument after the
        namer, ahead of the caller's own ``*args``.
        """
        self.lsi_dim = lsi_dim
        self.jobtype = kmeans_lsi_job
        super().__init__(
            self.jobtype,
            inpath,
            outpath,
            self.namer,
            self.lsi_dim,
            *args,
            **kwargs,
        )

    def namer(self, *args, **kwargs):
        """Return the k-means run name with an ``_lsi-<lsi_dim>`` suffix.

        The first positional argument is dropped before delegating to
        ``kmeans_grid_sweep.namer``.
        """
        kmeans_name = kmeans_grid_sweep.namer(self, *args[1:], **kwargs)
        return kmeans_name + f"_lsi-{self.lsi_dim}"
98
class kmeans_lsi_grid_sweep(lsi_grid_sweep):
    """``lsi_grid_sweep`` wired to the k-means LSI job and sub-sweep classes."""

    def __init__(self, inpath, lsi_dims, outpath, n_clusters, n_inits, max_iters):
        """Forward everything to ``lsi_grid_sweep``.

        ``kmeans_lsi_job`` and ``_kmeans_lsi_grid_sweep`` are prepended; the
        remaining arguments are passed positionally in their original order.
        """
        kmeans_grid = (n_clusters, n_inits, max_iters)
        super().__init__(
            kmeans_lsi_job,
            _kmeans_lsi_grid_sweep,
            inpath,
            lsi_dims,
            outpath,
            *kmeans_grid,
        )
117
def test_select_kmeans_clustering():
    """Run a small ``kmeans_lsi_grid_sweep`` as a manual smoke test.

    Uses a hard-coded input path and writes under ``test_kmeans``.
    """
    inpath = "/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_10k_LSI/"
    outpath = "test_kmeans"
    n_clusters = [200, 300, 400]
    n_init = [1, 2, 3]
    max_iter = [100000]

    gs = kmeans_lsi_grid_sweep(inpath, 'all', outpath, n_clusters, n_init, max_iter)
    gs.run(1)

    # NOTE(review): the previous tail of this function was leftover hdbscan
    # code (hdbscan_lsi_grid_sweep, min_cluster_sizes, min_samples) that is
    # neither defined nor imported in this module, so it always raised
    # NameError here. It ended by saving to "test_hdbscan/lsi_sweep.csv"; if
    # the k-means sweep should be saved too, add the equivalent gs.save(...)
    # call once the save API on lsi_grid_sweep is confirmed.
141
142
if __name__ == "__main__":
    # Expose the sweep and job classes as named Fire subcommands.
    # fire.Fire must be *called* with the dict, and every dict entry needs a
    # separating comma; the previous form was a syntax error.
    fire.Fire({'grid_sweep': kmeans_grid_sweep,
               'grid_sweep_lsi': kmeans_lsi_grid_sweep,
               'cluster': kmeans_job,
               'cluster_lsi': kmeans_lsi_job})

Community Data Science Collective || Want to submit a patch?