# [cdsc_reddit.git] / clustering / kmeans_clustering.py

from sklearn.cluster import KMeans
import fire
from pathlib import Path
from dataclasses import dataclass
from clustering_base import clustering_result, clustering_job
from grid_sweep import grid_sweep

@dataclass
class kmeans_clustering_result(clustering_result):
    n_clusters:int
    n_init:int
    max_iter:int

class kmeans_job(clustering_job):
    def __init__(self, infile, outpath, name, n_clusters, n_init=10, max_iter=100000, random_state=1968, verbose=True):
        super().__init__(infile,
                         outpath,
                         name,
                         call=kmeans_job._kmeans_clustering,
                         n_clusters=n_clusters,
                         n_init=n_init,
                         max_iter=max_iter,
                         random_state=random_state,
                         verbose=verbose)

        self.n_clusters=n_clusters
        self.n_init=n_init
        self.max_iter=max_iter

    @staticmethod
    def _kmeans_clustering(mat, *args, **kwargs):
        # Fit scikit-learn's KMeans on the labeled similarity matrix and return the fitted estimator.
        clustering = KMeans(*args,
                            **kwargs,
                            ).fit(mat)

        return clustering

    def get_info(self):
        # Extend the generic clustering result with the kmeans-specific parameters.
        result = super().get_info()
        self.result = kmeans_clustering_result(**result.__dict__,
                                               n_init=self.n_init,
                                               max_iter=self.max_iter)
        return self.result

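# A minimal usage sketch of kmeans_job on its own (hypothetical paths; assumes the
# clustering_job base class loads the labeled similarity matrix from `infile` and
# exposes a run() method):
#
#   job = kmeans_job("similarities.feather", "kmeans_out", "example",
#                    n_clusters=300, n_init=3, max_iter=100000)
#   job.run()
#   info = job.get_info()  # kmeans_clustering_result with n_init, max_iter, etc.
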
class kmeans_grid_sweep(grid_sweep):

    def __init__(self,
                 inpath,
                 outpath,
                 *args,
                 **kwargs):
        super().__init__(kmeans_job, inpath, outpath, self.namer, *args, **kwargs)

    def namer(self,
              n_clusters,
              n_init,
              max_iter):
        return f"nclusters-{n_clusters}_nit-{n_init}_maxit-{max_iter}"

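# For example, namer(n_clusters=300, n_init=3, max_iter=100000) yields
# "nclusters-300_nit-3_maxit-100000", so each parameter combination in the
# sweep gets a distinct name.
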
def test_select_kmeans_clustering():
    # A small sweep over a few parameter combinations, using the 10k LSI
    # subreddit similarity matrix as input.
    inpath = "/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_10k_LSI/"
    outpath = "test_kmeans"
    n_clusters = [200, 300, 400]
    n_init = [1, 2, 3]
    max_iter = [100000]

    gs = kmeans_grid_sweep(inpath, outpath, n_clusters, n_init, max_iter)
    gs.run(1)

def run_kmeans_grid_sweep(savefile, inpath, outpath, n_clusters=[500], n_inits=[1], max_iters=[3000]):
    """Run kmeans clustering once or more with different parameters.

    Usage:
    kmeans_clustering.py --savefile=SAVEFILE --inpath=INPATH --outpath=OUTPATH --n_clusters=<csv number of clusters> --n_inits=<csv> --max_iters=<csv>

    Keyword arguments:
    savefile: path to save the metadata and diagnostics
    inpath: path to feather data containing a labeled matrix of subreddit similarities.
    outpath: path to output fit kmeans clusterings.
    n_clusters: one or more numbers of kmeans clusters to select.
    n_inits: one or more numbers of different initializations to use for each clustering.
    max_iters: one or more numbers of different maximum iterations.
    """

    obj = kmeans_grid_sweep(inpath,
                            outpath,
                            map(int,n_clusters),
                            map(int,n_inits),
                            map(int,max_iters))

    obj.run(1)
    obj.save(savefile)

if __name__ == "__main__":
    fire.Fire(run_kmeans_grid_sweep)
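
# Example command-line invocation via fire (savefile and outpath are illustrative,
# not part of the repository; csv-style arguments are parsed into tuples by fire):
#   python3 kmeans_clustering.py --savefile=kmeans_sweep.csv \
#       --inpath=/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_10k_LSI/ \
#       --outpath=kmeans_out --n_clusters=100,200,300 --n_inits=1,3 --max_iters=10000,100000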
