]> code.communitydata.science - cdsc_reddit.git/blob - clustering/hdbscan_clustering_lsi.py
Merge branch 'excise_reindex' of code:cdsc_reddit into excise_reindex
[cdsc_reddit.git] / clustering / hdbscan_clustering_lsi.py
1 from hdbscan_clustering import hdbscan_job, hdbscan_grid_sweep, hdbscan_clustering_result
2 from lsi_base import lsi_grid_sweep, lsi_mixin, lsi_result_mixin
3 from grid_sweep import grid_sweep
4 import fire
5 from dataclasses import dataclass
6
@dataclass
class hdbscan_clustering_result_lsi(hdbscan_clustering_result, lsi_result_mixin):
    """HDBSCAN clustering result extended with LSI metadata.

    Combines the fields of ``hdbscan_clustering_result`` with those of
    ``lsi_result_mixin``; ``hdbscan_lsi_job.get_info`` in this module builds
    it from the base result's fields plus ``lsi_dimensions``.
    """
10
class hdbscan_lsi_job(hdbscan_job, lsi_mixin):
    """HDBSCAN job on an LSI similarity matrix, tagged with its LSI dimensionality."""

    def __init__(self, infile, outpath, name, lsi_dims, *args, **kwargs):
        # Everything except lsi_dims goes to the base job unchanged; the
        # dimensionality is recorded through set_lsi_dims (presumably
        # provided by lsi_mixin, which is later in the MRO).
        super().__init__(infile, outpath, name, *args, **kwargs)
        super().set_lsi_dims(lsi_dims)

    def get_info(self):
        """Return the base job's result extended with ``lsi_dimensions``.

        The extended result is also stored on ``self.result``.
        """
        base_result = super().get_info()
        base_fields = base_result.__dict__
        self.result = hdbscan_clustering_result_lsi(**base_fields,
                                                    lsi_dimensions=self.lsi_dims)
        return self.result
26
class hdbscan_lsi_grid_sweep(lsi_grid_sweep):
    """Grid sweep of HDBSCAN parameters over one or more LSI dimensionalities.

    Delegates to ``lsi_grid_sweep`` with ``hdbscan_lsi_job`` as the job type
    and ``_hdbscan_lsi_grid_sweep`` as the per-dimension sweep type.
    """

    def __init__(self,
                 inpath,
                 lsi_dims,
                 outpath,
                 min_cluster_sizes,
                 min_samples,
                 cluster_selection_epsilons,
                 cluster_selection_methods
                 ):
        # HDBSCAN hyperparameter grids, forwarded positionally (in this
        # order) right after outpath.
        param_grids = (min_cluster_sizes,
                       min_samples,
                       cluster_selection_epsilons,
                       cluster_selection_methods)
        super().__init__(hdbscan_lsi_job,
                         _hdbscan_lsi_grid_sweep,
                         inpath,
                         lsi_dims,
                         outpath,
                         *param_grids)
47         
48
49
class _hdbscan_lsi_grid_sweep(grid_sweep):
    """HDBSCAN grid sweep for a single, fixed LSI dimensionality.

    Builds ``hdbscan_lsi_job`` jobs through ``grid_sweep`` and names each
    job with an ``_lsi-<lsi_dim>`` suffix.
    """

    def __init__(self,
                 inpath,
                 outpath,
                 lsi_dim,
                 *args,
                 **kwargs):
        self.lsi_dim = lsi_dim
        self.jobtype = hdbscan_lsi_job
        # [self.lsi_dim] is passed as a one-element grid ahead of the HDBSCAN
        # grids, presumably so every job's lsi_dims argument is this sweep's
        # dimensionality -- confirm against grid_sweep.
        super().__init__(self.jobtype, inpath, outpath, self.namer,
                         [self.lsi_dim], *args, **kwargs)

    def namer(self, *args, **kwargs):
        """Return the HDBSCAN job name with an ``_lsi-<lsi_dim>`` suffix.

        The first positional argument is dropped before delegating to
        ``hdbscan_grid_sweep.namer``; presumably it is the value drawn from
        the ``[lsi_dim]`` grid, which is appended as a suffix instead.
        """
        s = hdbscan_grid_sweep.namer(self, *args[1:], **kwargs)
        s += f"_lsi-{self.lsi_dim}"
        return s
69
def _as_list(values):
    """Return *values* as a list, wrapping a lone scalar or string in one.

    python-fire typically hands a single command-line value such as
    ``--min_samples=5`` or ``--cluster_selection_methods=eom`` to the
    function as a bare int/str, while ``--min_samples=1,5`` arrives as a
    tuple. Strings are wrapped rather than iterated so that ``"eom"`` does
    not become ``["e", "o", "m"]``.
    """
    if isinstance(values, (str, bytes)):
        return [values]
    try:
        return list(values)
    except TypeError:  # not iterable: a single scalar such as 5 or 0.1
        return [values]


def run_hdbscan_lsi_grid_sweep(savefile, inpath, outpath, min_cluster_sizes=(2,), min_samples=(1,), cluster_selection_epsilons=(0,), cluster_selection_methods=('eom',), lsi_dimensions='all', cores=10):
    """Run hdbscan clustering once or more with different parameters.

    Usage:
    hdbscan_clustering_lsi --savefile=SAVEFILE --inpath=INPATH --outpath=OUTPATH --min_cluster_sizes=<csv> --min_samples=<csv> --cluster_selection_epsilons=<csv> --cluster_selection_methods=<csv of eom|leaf> --lsi_dimensions=<all|csv> [--cores=N]

    Keyword arguments:
    savefile: path to save the metadata and diagnostics.
    inpath: path to folder containing feather files with LSI similarity labeled matrices of subreddit similarities.
    outpath: path to output fit clusterings.
    min_cluster_sizes: one or more integers indicating the minimum cluster size.
    min_samples: one or more integers indicating the minimum number of samples used in the algorithm.
    cluster_selection_epsilons: one or more distance thresholds (floats); hdbscan merges clusters below this threshold, behaving more like DBSCAN.
    cluster_selection_methods: one or more of "eom" or "leaf" ("eom", the default, tends to give larger clusters).
    lsi_dimensions: either "all" or one or more available lsi similarity dimensions at INPATH.
    cores: value passed to the sweep's run() method, presumably the number of parallel workers (default 10).

    A single value may be given for any of the list-valued arguments; it is
    treated as a one-element list.
    """
    obj = hdbscan_lsi_grid_sweep(inpath,
                                 lsi_dimensions,
                                 outpath,
                                 [int(v) for v in _as_list(min_cluster_sizes)],
                                 [int(v) for v in _as_list(min_samples)],
                                 [float(v) for v in _as_list(cluster_selection_epsilons)],
                                 _as_list(cluster_selection_methods))

    obj.run(cores)
    obj.save(savefile)


if __name__ == "__main__":
    fire.Fire(run_hdbscan_lsi_grid_sweep)

Community Data Science Collective || Want to submit a patch?