]> code.communitydata.science - cdsc_reddit.git/blob - clustering/affinity_clustering.py
Merge branch 'excise_reindex' of code:cdsc_reddit into excise_reindex
[cdsc_reddit.git] / clustering / affinity_clustering.py
1 from sklearn.cluster import AffinityPropagation
2 from dataclasses import dataclass
3 from clustering_base import clustering_result, clustering_job
4 from grid_sweep import grid_sweep
5 from pathlib import Path
6 from itertools import product, starmap
7 import fire
8 import sys
9 import numpy as np
10
# Silhouette is the only quality score that doesn't need the feature matrix, so it's probably the only one worth trying.
@dataclass
class affinity_clustering_result(clustering_result):
    """Result record for one affinity-propagation run.

    Extends the base ``clustering_result`` fields with the hyperparameters
    used by ``affinity_job`` and the preference value derived from them.
    """
    # Damping factor forwarded to AffinityPropagation.
    damping:float
    # Iterations without change in the exemplar set before stopping.
    convergence_iter:int
    # Quantile of the (transformed) input matrix used to choose `preference`.
    preference_quantile:float
    # Actual preference value computed from `preference_quantile`.
    preference:float
    # Upper bound on AffinityPropagation iterations.
    max_iter:int
19
class affinity_job(clustering_job):
    """A single affinity-propagation clustering job.

    Hyperparameters are forwarded to ``clustering_job`` together with
    ``call=self._affinity_clustering``; presumably the base class loads the
    matrix from ``infile`` and invokes ``call`` with it plus the remaining
    keyword arguments -- confirm in clustering_base.
    """

    def __init__(self, infile, outpath, name, damping=0.9, max_iter=100000, convergence_iter=30, preference_quantile=0.5, random_state=1968, verbose=True):
        # Initialise before super().__init__ so get_info() never raises
        # AttributeError, whether or not the base class fits immediately.
        self.preference = None
        super().__init__(infile,
                         outpath,
                         name,
                         call=self._affinity_clustering,
                         preference_quantile=preference_quantile,
                         damping=damping,
                         max_iter=max_iter,
                         convergence_iter=convergence_iter,
                         # Was hard-coded to 1968, silently ignoring the argument.
                         random_state=random_state,
                         verbose=verbose)
        self.damping = damping
        self.max_iter = max_iter
        self.convergence_iter = convergence_iter
        self.preference_quantile = preference_quantile

    def _affinity_clustering(self, mat, preference_quantile, *args, **kwargs):
        """Fit AffinityPropagation on ``1 - mat`` and return the fitted model.

        The preference (self-similarity of every point) is set to the
        ``preference_quantile`` quantile of the transformed matrix and stored
        on ``self.preference`` so get_info() can report it. Extra positional
        and keyword arguments go straight to AffinityPropagation.
        """
        # NOTE(review): affinity='precomputed' expects similarities (larger =
        # more alike), so ``1 - mat`` assumes the base class hands us a
        # distance-like matrix -- confirm in clustering_base.
        mat = 1 - mat
        preference = np.quantile(mat, preference_quantile)
        self.preference = preference
        print(f"preference is {preference}")
        print("data loaded")
        sys.stdout.flush()
        # copy=False is safe: ``mat`` is the new object created by ``1 - mat``
        # above, not the caller's matrix.
        clustering = AffinityPropagation(*args,
                                         preference=preference,
                                         affinity='precomputed',
                                         copy=False,
                                         **kwargs).fit(mat)
        return clustering

    def get_info(self):
        """Return the base result extended with this job's hyperparameters.

        The combined ``affinity_clustering_result`` is also cached on
        ``self.result``.
        """
        result = super().get_info()
        self.result = affinity_clustering_result(**result.__dict__,
                                                 damping=self.damping,
                                                 max_iter=self.max_iter,
                                                 convergence_iter=self.convergence_iter,
                                                 preference_quantile=self.preference_quantile,
                                                 preference=self.preference)

        return self.result
58                                                preference=self.preference)
59
60         return self.result
61
62 class affinity_grid_sweep(grid_sweep):
63     def __init__(self,
64                  inpath,
65                  outpath,
66                  *args,
67                  **kwargs):
68
69         super().__init__(affinity_job,
70                          _afffinity_grid_sweep,
71                          inpath,
72                          outpath,
73                          self.namer,
74                          *args,
75                          **kwargs)
76     def namer(self,
77               damping,
78               max_iter,
79               convergence_iter,
80               preference_quantile):
81
82         return f"damp-{damping}_maxit-{max_iter}_convit-{convergence_iter}_prefq-{preference_quantile}"
83
84 def run_affinity_grid_sweep(savefile, inpath, outpath, dampings=[0.8], max_iters=[3000], convergence_iters=[30], preference_quantiles=[0.5],n_cores=10):
85     """Run affinity clustering once or more with different parameters.
86     
87     Usage:
88     affinity_clustering.py --savefile=SAVEFILE --inpath=INPATH --outpath=OUTPATH --max_iters=<csv> --dampings=<csv> --preference_quantiles=<csv>
89
90     Keword arguments:
91     savefile: path to save the metadata and diagnostics 
92     inpath: path to feather data containing a labeled matrix of subreddit similarities.
93     outpath: path to output fit kmeans clusterings.
94     dampings:one or more numbers in [0.5, 1). damping parameter in affinity propagatin clustering. 
95     preference_quantiles:one or more numbers in (0,1) for selecting the 'preference' parameter.
96     convergence_iters:one or more integers of number of iterations without improvement before stopping.
97     max_iters: one or more numbers of different maximum interations.
98     """
99     obj = affinity_grid_sweep(inpath,
100                          outpath,
101                          map(float,dampings),
102                          map(int,max_iters),
103                          map(int,convergence_iters),
104                          map(float,preference_quantiles))
105     obj.run(n_cores)
106     obj.save(savefile)
107     
108 def test_select_affinity_clustering():
109     # select_hdbscan_clustering("/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_30k_LSI",
110     #                           "test_hdbscan_author30k",
111     #                           min_cluster_sizes=[2],
112     #                           min_samples=[1,2],
113     #                           cluster_selection_epsilons=[0,0.05,0.1,0.15],
114     #                           cluster_selection_methods=['eom','leaf'],
115     #                           lsi_dimensions='all')
116     inpath = "/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_10k_LSI/"
117     outpath = "test_affinity";
118     dampings=[0.8,0.9]
119     max_iters=[100000]
120     convergence_iters=[15]
121     preference_quantiles=[0.5,0.7]
122     
123     gs = affinity_lsi_grid_sweep(inpath, 'all', outpath, dampings, max_iters, convergence_iters, preference_quantiles)
124     gs.run(20)
125     gs.save("test_affinity/lsi_sweep.csv")
126
127
128 if __name__ == "__main__":
129     fire.Fire(run_affinity_grid_sweep)

Community Data Science Collective || Want to submit a patch?