]> code.communitydata.science - cdsc_reddit.git/blob - clustering/kmeans_clustering_lsi.py
update clustering scripts
[cdsc_reddit.git] / clustering / kmeans_clustering_lsi.py
1 import fire
2 from dataclasses import dataclass
3 from kmeans_clustering import kmeans_job, kmeans_clustering_result, kmeans_grid_sweep
4 from lsi_base import lsi_mixin, lsi_result_mixin, lsi_grid_sweep
5 from grid_sweep import grid_sweep
6
@dataclass
class kmeans_clustering_result_lsi(kmeans_clustering_result, lsi_result_mixin):
    """K-means clustering result combined with the LSI result mixin.

    Declares no fields of its own; everything comes from the two base
    classes (the mixin presumably supplies ``lsi_dimensions`` -- confirm in
    ``lsi_base``).
    """
10
class kmeans_lsi_job(kmeans_job, lsi_mixin):
    """K-means clustering job that additionally tracks the LSI dimensions."""

    def __init__(self, infile, outpath, name, lsi_dims, *args, **kwargs):
        """Set up the base job, then register ``lsi_dims`` through the mixin.

        ``infile``, ``outpath``, ``name`` and any extra positional/keyword
        arguments are forwarded untouched to the parent constructor.
        """
        super().__init__(infile, outpath, name, *args, **kwargs)
        super().set_lsi_dims(lsi_dims)

    def get_info(self):
        """Return the parent's result re-wrapped with ``lsi_dimensions`` added.

        The new result object is also stored on ``self.result``.
        """
        base_info = super().get_info()
        # Copy the parent's result fields and extend them with the LSI setting.
        fields = dict(base_info.__dict__)
        self.result = kmeans_clustering_result_lsi(**fields,
                                                   lsi_dimensions=self.lsi_dims)
        return self.result
25
class _kmeans_lsi_grid_sweep(grid_sweep):
    """Grid sweep of ``kmeans_lsi_job`` runs for one fixed LSI dimension."""

    def __init__(self,
                 inpath,
                 outpath,
                 lsi_dim,
                 *args,
                 **kwargs):
        """Configure the sweep for a single ``lsi_dim``.

        ``lsi_dim`` is wrapped in a one-element list and handed to the base
        ``grid_sweep`` ahead of the remaining parameter grids in ``args`` /
        ``kwargs``; ``self.namer`` is supplied as the naming callback.
        """
        # NOTE: leftover debug print() calls for args/kwargs were removed;
        # they wrote to stdout once per LSI dimension during a sweep.
        self.lsi_dim = lsi_dim
        self.jobtype = kmeans_lsi_job
        super().__init__(self.jobtype, inpath, outpath, self.namer, [self.lsi_dim], *args, **kwargs)

    def namer(self, *args, **kwargs):
        """Build a job name: the k-means sweep name plus an ``_lsi-<dim>`` suffix."""
        # The first positional argument is skipped before delegating --
        # presumably it is the LSI dimension passed as the leading grid in
        # __init__ (confirm against grid_sweep).
        base_name = kmeans_grid_sweep.namer(self, *args[1:], **kwargs)
        return f"{base_name}_lsi-{self.lsi_dim}"
43
class kmeans_lsi_grid_sweep(lsi_grid_sweep):
    """LSI grid sweep specialised for k-means clustering jobs."""

    def __init__(self, inpath, lsi_dims, outpath, n_clusters, n_inits, max_iters):
        """Forward the sweep configuration to ``lsi_grid_sweep``.

        ``kmeans_lsi_job`` and ``_kmeans_lsi_grid_sweep`` are supplied as the
        job type and the per-dimension sweep type; all remaining arguments
        are passed through positionally in the order received.
        """
        sweep_args = (inpath, lsi_dims, outpath, n_clusters, n_inits, max_iters)
        super().__init__(kmeans_lsi_job, _kmeans_lsi_grid_sweep, *sweep_args)
63
def _as_int_list(values):
    """Normalise a command-line parameter value into a list of ints.

    Accepts a single number (Fire parses ``--n_clusters=500`` as a bare int),
    a comma-separated string, or any iterable of values convertible by
    ``int``.
    """
    if isinstance(values, str):
        # Iterating a string would yield single characters ("500" -> 5, 0, 0).
        values = [part for part in values.split(",") if part.strip()]
    elif isinstance(values, (int, float)):
        values = [values]
    return [int(value) for value in values]


def run_kmeans_lsi_grid_sweep(savefile, inpath, outpath,  n_clusters=[500], n_inits=[1], max_iters=[3000], lsi_dimensions="all"):
    """Run kmeans clustering once or more with different parameters.

    Usage:
    kmeans_clustering_lsi.py --savefile=SAVEFILE --inpath=INPATH --outpath=OUTPATH --lsi_dimensions=<"all"|csv number of LSI dimensions to use> --n_clusters=<csv number of clusters> --n_inits=<csv> --max_iters=<csv>

    Keyword arguments:
    savefile: path to save the metadata and diagnostics
    inpath: path to folder containing feather files with LSI similarity labeled matrices of subreddit similarities.
    outpath: path to output fit kmeans clusterings.
    lsi_dimensions: either "all" or one or more available lsi similarity dimensions at INPATH.
    n_clusters: one or more numbers of kmeans clusters to select.
    n_inits: one or more numbers of different initializations to use for each clustering.
    max_iters: one or more numbers of different maximum iterations.
    """

    # A single value (e.g. --n_clusters=500) arrives as a scalar rather than
    # a sequence, so every grid parameter is normalised before use. The list
    # defaults are never mutated: each one is copied into a fresh list here.
    obj = kmeans_lsi_grid_sweep(inpath,
                                lsi_dimensions,
                                outpath,
                                _as_int_list(n_clusters),
                                _as_int_list(n_inits),
                                _as_int_list(max_iters)
                                )

    # The meaning of the argument 1 is defined by lsi_grid_sweep.run
    # (presumably the number of parallel workers -- confirm).
    obj.run(1)
    obj.save(savefile)
90
91
# Script entry point: expose run_kmeans_lsi_grid_sweep as a command-line
# interface through Python Fire.
if __name__ == "__main__":
    fire.Fire(component=run_kmeans_lsi_grid_sweep)

Community Data Science Collective || Want to submit a patch?