]> code.communitydata.science - cdsc_reddit.git/blob - clustering/kmeans_clustering.py
8822e9f0cd67c2ece76d9552d7c2c77883bac76f
[cdsc_reddit.git] / clustering / kmeans_clustering.py
from dataclasses import dataclass
from functools import partial
from itertools import product, starmap
from multiprocessing import cpu_count
from pathlib import Path
import sys

import fire
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

from clustering_base import sim_to_dist, process_clustering_result, clustering_result, read_similarity_mat
7
@dataclass
class kmeans_clustering_result(clustering_result):
    """Result record for one k-means run.

    Extends ``clustering_result`` with the k-means hyperparameters that
    produced the run.
    """

    n_clusters: int  # number of clusters requested from KMeans
    n_init: int  # number of KMeans initializations (n_init)
12
def kmeans_clustering(similarities, *args, **kwargs):
    """Cluster the items of a similarity matrix with k-means.

    The matrix is loaded with ``read_similarity_mat(similarities)`` and
    converted with ``sim_to_dist``. The converted matrix plus any extra
    ``*args``/``**kwargs`` are forwarded to ``_kmeans_clustering``. Note that
    its next positional parameter is ``output``, followed by ``n_clusters``.

    Returns whatever ``process_clustering_result`` builds from the fitted
    model and the row labels.
    """
    subreddits, sims = read_similarity_mat(similarities)
    model = _kmeans_clustering(sim_to_dist(sims), *args, **kwargs)
    return process_clustering_result(model, subreddits)
19
def _kmeans_clustering(mat, output, n_clusters, n_init=10, max_iter=100000, random_state=1968, verbose=True):
    """Fit scikit-learn ``KMeans`` on the rows of *mat* and return the fitted model.

    ``output`` is accepted for call-site compatibility but is not used here.
    The remaining arguments are passed straight to the ``KMeans`` constructor.
    """
    params = dict(
        n_clusters=n_clusters,
        n_init=n_init,
        max_iter=max_iter,
        random_state=random_state,
        verbose=verbose,
    )
    model = KMeans(**params)
    # KMeans.fit returns the estimator itself.
    return model.fit(mat)
30
def do_clustering(n_clusters, n_init, name, mat, subreddits, max_iter, outdir: Path, random_state, verbose, alt_mat, overwrite=False):
    """Fit one k-means configuration, save its assignments, and score it.

    Parameters
    ----------
    n_clusters, n_init, max_iter, random_state, verbose
        Forwarded to ``_kmeans_clustering`` / ``KMeans``.
    name : str or None
        Run label; the output file is ``outdir / f"{name}.feather"``. When
        ``None``, a label is derived from the hyperparameters.
    mat
        Similarity matrix. It is converted with ``sim_to_dist`` before fitting
        and scoring.
    subreddits
        Row labels, passed to ``process_clustering_result``.
    outdir : Path
        Directory that receives the feather file (created if missing).
    alt_mat
        Optional alternative similarity matrix. When given, the silhouette
        coefficient of the same labels is also computed against its
        distances.
    overwrite
        Currently unused; kept for interface compatibility.

    Returns
    -------
    kmeans_clustering_result
        The output path, hyperparameters, silhouette scores (``None`` when
        sklearn rejects the labelling, e.g. a single cluster) and the name.
    """
    if name is None:
        # The previous default referenced undefined affinity-propagation
        # variables and raised NameError.
        name = f"nClusters-{n_clusters}_nInit-{n_init}_maxIter-{max_iter}"
    print(name, flush=True)
    outpath = outdir / (str(name) + ".feather")
    print(outpath)

    dist = sim_to_dist(mat)
    clustering = _kmeans_clustering(dist, outpath, n_clusters, n_init, max_iter, random_state, verbose)

    # Build the assignment table BEFORE writing it (previously written while unbound).
    cluster_data = process_clustering_result(clustering, subreddits)
    outpath.parent.mkdir(parents=True, exist_ok=True)
    cluster_data.to_feather(outpath)

    # silhouette_score raises ValueError for degenerate labellings
    # (fewer than 2 or more than n_samples - 1 distinct labels).
    try:
        score = silhouette_score(dist, clustering.labels_, metric='precomputed')
    except ValueError:
        score = None

    alt_score = None
    if alt_mat is not None:
        # metric='precomputed' expects distances, so convert the alternative
        # similarities the same way as the primary matrix.
        alt_distances = sim_to_dist(alt_mat)
        try:
            alt_score = silhouette_score(alt_distances, clustering.labels_, metric='precomputed')
        except ValueError:
            alt_score = None

    return kmeans_clustering_result(outpath=outpath,
                                    max_iter=max_iter,
                                    n_clusters=n_clusters,
                                    n_init=n_init,
                                    silhouette_score=score,
                                    alt_silhouette_score=alt_score,
                                    name=str(name))
66
67
68 # alt similiarities is for checking the silhouette coefficient of an alternative measure of similarity (e.g., topic similarities for user clustering).
69 def select_kmeans_clustering(similarities, outdir, outinfo, n_clusters=[1000], max_iter=100000, n_init=10, random_state=1968, verbose=True, alt_similarities=None):
70
71     n_clusters = list(map(int,n_clusters))
72     n_init  = list(map(int,n_init))
73
74     if type(outdir) is str:
75         outdir = Path(outdir)
76
77     outdir.mkdir(parents=True,exist_ok=True)
78
79     subreddits, mat = read_similarity_mat(similarities,use_threads=True)
80
81     if alt_similarities is not None:
82         alt_mat = read_similarity_mat(alt_similarities,use_threads=True)
83     else:
84         alt_mat = None
85
86     # get list of tuples: the combinations of hyperparameters
87     hyper_grid = product(n_clusters, n_init)
88     hyper_grid = (t + (str(i),) for i, t in enumerate(hyper_grid))
89
90     _do_clustering = partial(do_clustering,  mat=mat, subreddits=subreddits, outdir=outdir, max_iter=max_iter, random_state=random_state, verbose=verbose, alt_mat=alt_mat)
91
92     # call starmap
93     print("running clustering selection")
94     clustering_data = starmap(_do_clustering, hyper_grid)
95     clustering_data = pd.DataFrame(list(clustering_data))
96     clustering_data.to_csv(outinfo)
97     
98     return clustering_data
99
if __name__ == "__main__":
    # Expose select_kmeans_clustering as a command-line interface via python-fire.
    fire.Fire(select_kmeans_clustering)

Community Data Science Collective || Want to submit a patch?