1 from sklearn.metrics import silhouette_score
2 from sklearn.cluster import AffinityPropagation
3 from functools import partial
4 from clustering import _kmeans_clustering, read_similarity_mat, sim_to_dist, process_clustering_result, clustering_result
5 from dataclasses import dataclass
6 from multiprocessing import Pool, cpu_count, Array, Process
7 from pathlib import Path
8 from itertools import product, starmap
# Result record for one k-means configuration in the model-selection grid;
# extends clustering_result, which is imported from the `clustering` module.
# do_clustering constructs it with keyword arguments that include outpath,
# n_clusters, silhouette_score and alt_silhouette_score.
# NOTE(review): the class body is not part of this excerpt; confirm which of
# those fields are declared here versus inherited from clustering_result.
15 class kmeans_clustering_result(clustering_result):
# Silhouette is the only cluster-quality score used for this selection because it
# can be computed from a precomputed distance matrix without the feature matrix.
def _silhouette_or_none(distances, labels):
    """Return the silhouette score of ``labels`` on a precomputed distance matrix.

    scikit-learn raises ValueError when the number of distinct labels is not
    between 2 and n_samples - 1; such degenerate runs are reported as ``None``
    so that one bad configuration does not abort the whole grid search.
    """
    try:
        return silhouette_score(distances, labels, metric='precomputed')
    except ValueError:
        return None


def do_clustering(n_clusters, n_init, name, mat, subreddits, max_iter, outdir: Path, random_state, verbose, alt_mat, overwrite=False):
    """Fit one k-means configuration, save its cluster assignments, and score it.

    Args:
        n_clusters: number of clusters, forwarded to ``_kmeans_clustering``.
        n_init: number of k-means initialisations, forwarded to ``_kmeans_clustering``.
        name: stem of the output ``.feather`` file. When ``None``, a name is
            derived from ``n_clusters`` and ``n_init``.
        mat: similarity matrix; converted to distances with ``sim_to_dist``.
        subreddits: row labels for ``mat``, passed to ``process_clustering_result``.
        max_iter: forwarded to ``_kmeans_clustering`` and recorded in the result.
        outdir: directory that receives ``<name>.feather``; created if missing.
        random_state: forwarded to ``_kmeans_clustering``.
        verbose: forwarded to ``_kmeans_clustering``.
        alt_mat: optional alternative similarity matrix. It is converted to
            distances and scored against the same cluster labels. Pass ``None``
            to skip.
        overwrite: accepted for interface compatibility; not used by this function.

    Returns:
        kmeans_clustering_result for this run. Either silhouette score is
        ``None`` when scikit-learn rejects the labeling (e.g. a single cluster),
        and ``alt_silhouette_score`` is ``None`` when ``alt_mat`` is ``None``.
    """
    if name is None:
        # The previous default interpolated affinity-propagation variables
        # (damping, convergence_iter, preference_quantile) that do not exist
        # here, so it raised NameError.
        name = f"nclusters-{n_clusters}_ninit-{n_init}"

    outpath = outdir / (str(name) + ".feather")

    distances = sim_to_dist(mat)
    clustering = _kmeans_clustering(distances, outpath, n_clusters, n_init, max_iter, random_state, verbose)

    # Build the frame before writing it; the old code wrote cluster_data
    # before it had been assigned.
    cluster_data = process_clustering_result(clustering, subreddits)
    outpath.parent.mkdir(parents=True, exist_ok=True)
    cluster_data.to_feather(outpath)

    score = _silhouette_or_none(distances, clustering.labels_)

    alt_score = None
    if alt_mat is not None:
        # metric='precomputed' expects distances, so score the converted matrix,
        # not the raw similarities.
        alt_score = _silhouette_or_none(sim_to_dist(alt_mat), clustering.labels_)

    # NOTE(review): the tail of the original constructor call was not visible;
    # max_iter / n_init / name are reconstructed — confirm against the fields of
    # kmeans_clustering_result / clustering_result.
    return kmeans_clustering_result(outpath=outpath,
                                    max_iter=max_iter,
                                    n_clusters=n_clusters,
                                    n_init=n_init,
                                    silhouette_score=score,
                                    alt_silhouette_score=alt_score,
                                    name=str(name))
def _as_int_list(value):
    """Coerce a single number or an iterable of numbers into a list of ints.

    The defaults mix shapes (``n_clusters`` is a sequence, ``n_init`` is a
    scalar), and command-line callers may pass either shape. Strings are
    treated as a single value rather than iterated character by character.
    """
    if isinstance(value, (str, bytes)):
        return [int(value)]
    try:
        items = iter(value)
    except TypeError:
        return [int(value)]
    return [int(v) for v in items]


# alt_similarities is for checking the silhouette coefficient of an alternative
# measure of similarity (e.g., topic similarities for user clustering).
def select_kmeans_clustering(similarities, outdir, outinfo, n_clusters=(1000,), max_iter=100000, n_init=10, random_state=1968, verbose=True, alt_similarities=None):
    """Run k-means for every (n_clusters, n_init) combination and tabulate the results.

    Args:
        similarities: source passed to ``read_similarity_mat`` for the
            similarity matrix being clustered.
        outdir: directory (str or path) for per-run outputs; created if missing.
        outinfo: destination of the CSV summary (one row per configuration).
        n_clusters: a cluster count or an iterable of cluster counts to try.
        max_iter: k-means iteration limit, shared by all runs.
        n_init: an initialisation count or an iterable of counts to try.
        random_state: seed shared by all runs.
        verbose: forwarded to each run.
        alt_similarities: optional source for an alternative similarity
            matrix, used only to compute an alternative silhouette score.

    Returns:
        pandas.DataFrame built from the per-run results returned by
        ``do_clustering``. The same data is written to ``outinfo``.
    """
    n_clusters = _as_int_list(n_clusters)
    n_init = _as_int_list(n_init)

    # Path() accepts both str and Path, so no type check is needed.
    outdir = Path(outdir)
    outdir.mkdir(parents=True, exist_ok=True)

    subreddits, mat = read_similarity_mat(similarities, use_threads=True)

    # Must be bound even when no alternative matrix is given; the partial
    # below always forwards it.
    alt_mat = None
    if alt_similarities is not None:
        # read_similarity_mat returns (row labels, matrix); keep only the matrix.
        # NOTE(review): assumes the alternative matrix's rows are in the same
        # order as `subreddits` — confirm.
        _, alt_mat = read_similarity_mat(alt_similarities, use_threads=True)

    # One (n_clusters, n_init, name) tuple per combination of hyperparameters;
    # the run's index in the grid serves as its output name.
    hyper_grid = ((k, init, str(i)) for i, (k, init) in enumerate(product(n_clusters, n_init)))

    _do_clustering = partial(do_clustering, mat=mat, subreddits=subreddits, outdir=outdir,
                             max_iter=max_iter, random_state=random_state,
                             verbose=verbose, alt_mat=alt_mat)

    print("running clustering selection")
    clustering_data = pd.DataFrame(list(starmap(_do_clustering, hyper_grid)))
    clustering_data.to_csv(outinfo)
    return clustering_data
if __name__ == "__main__":
    # Command-line entry point: python-fire maps command-line arguments onto
    # the parameters of select_kmeans_clustering and invokes it; the returned
    # value is kept in the module-level name ``x``.
    x = fire.Fire(select_kmeans_clustering)