1 from sklearn.cluster import KMeans
3 from pathlib import Path
4 from dataclasses import dataclass
5 from clustering_base import clustering_result, clustering_job
6 from grid_sweep import grid_sweep
# Result record for a k-means run; extends clustering_result with k-means
# hyperparameters.
# NOTE(review): the class body (original lines 10-13) is missing from this
# listing. kmeans_job builds it from the base result's __dict__ plus at least
# max_iter=..., so it presumably declares max_iter (and likely n_clusters /
# n_init) fields -- TODO confirm against the full file.
9 class kmeans_clustering_result(clustering_result):
# Clustering job that runs sklearn KMeans through the clustering_job base class.
# The base class is handed kmeans_job._kmeans_clustering as its `call` and the
# k-means hyperparameters as keyword arguments.
# NOTE(review): several original lines of this class are absent from this
# listing (remaining super().__init__ arguments, the rest of
# _kmeans_clustering, the def line of the result-building method). The
# comments here describe only the visible lines.
14 class kmeans_job(clustering_job):
# Defaults: n_init=10, max_iter=100000, random_state=1968, verbose=True.
# n_clusters and max_iter are also kept on the instance for building the result.
15 def __init__(self, infile, outpath, name, n_clusters, n_init=10, max_iter=100000, random_state=1968, verbose=True):
16 super().__init__(infile,
19 call=kmeans_job._kmeans_clustering,
20 n_clusters=n_clusters,
23 random_state=random_state,
26 self.n_clusters=n_clusters
28 self.max_iter=max_iter
# Builds a KMeans estimator from the forwarded *args (rest of call not visible).
# NOTE(review): no `self` parameter and referenced as
# kmeans_job._kmeans_clustering -- presumably decorated @staticmethod on an
# omitted line; TODO confirm.
30 def _kmeans_clustering(mat, *args, **kwargs):
32 clustering = KMeans(*args,
# Presumably the body of an overridden get_info(): takes the base class's
# result and re-wraps its fields in kmeans_clustering_result together with
# this job's k-means hyperparameters.
40 result = super().get_info()
41 self.result = kmeans_clustering_result(**result.__dict__,
43 max_iter=self.max_iter)
# Grid sweep over k-means hyperparameters: delegates to grid_sweep with
# kmeans_job as the job type and self.namer to name each run.
# NOTE(review): the __init__ signature and the namer's def line are missing
# from this listing.
47 class kmeans_grid_sweep(grid_sweep):
54 super().__init__(kmeans_job, inpath, outpath, self.namer, *args, **kwargs)
# Run name, e.g. n_clusters=200, n_init=10, max_iter=3000 gives
# "nclusters-200_nit-10_maxit-3000".
60 return f"nclusters-{n_clusters}_nit-{n_init}_maxit-{max_iter}"
# Ad-hoc smoke test that runs a sweep against a fixed LSI similarity directory.
# NOTE(review): this function uses kmeans_lsi_grid_sweep,
# hdbscan_lsi_grid_sweep, min_cluster_sizes and min_samples, none of which
# are defined or imported in the visible lines. It also saves to a
# "test_hdbscan/..." path, so it looks partly copied from an HDBSCAN test --
# verify it against the omitted lines before relying on it.
# It also depends on a hard-coded cluster path under /gscratch, so it will
# not run elsewhere.
62 def test_select_kmeans_clustering():
63 inpath = "/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_10k_LSI/"
64 outpath = "test_kmeans";
65 n_clusters=[200,300,400];
69 gs = kmeans_lsi_grid_sweep(inpath, 'all', outpath, n_clusters, n_init, max_iter)
72 cluster_selection_epsilons=[0,0.1,0.3,0.5];
73 cluster_selection_methods=['eom'];
75 gs = hdbscan_lsi_grid_sweep(inpath, "all", outpath, min_cluster_sizes, min_samples, cluster_selection_epsilons, cluster_selection_methods)
77 gs.save("test_hdbscan/lsi_sweep.csv")
# CLI-facing entry point: builds a kmeans_grid_sweep over the given parameter
# lists (remaining arguments and follow-up calls are on omitted lines).
# NOTE(review): n_clusters / n_inits / max_iters use mutable list defaults.
# This is safe only if the sweep never mutates them; consider tuples or
# None-sentinels once the full body is visible.
79 def run_kmeans_grid_sweep(savefile, inpath, outpath, n_clusters=[500], n_inits=[1], max_iters=[3000]):
80 """Run kmeans clustering once or more with different parameters.
83 kmeans_clustering.py --savefile=SAVEFILE --inpath=INPATH --outpath=OUTPATH --n_clusters=<csv number of clusters> --n_inits=<csv> --max_iters=<csv>
86 savefile: path to save the metadata and diagnostics
87 inpath: path to feather data containing a labeled matrix of subreddit similarities.
88 outpath: path to output fit kmeans clusterings.
89 n_clusters: one or more numbers of kmeans clusters to select.
90 n_inits: one or more numbers of different initializations to use for each clustering.
91 max_iters: one or more numbers of different maximum iterations.
94 obj = kmeans_grid_sweep(inpath,
# Command-line entry point: python-fire exposes run_kmeans_grid_sweep's
# parameters as --flags (see the usage line in its docstring).
# NOTE(review): `fire` is not among the visible imports; confirm that an
# `import fire` exists on an omitted line (e.g. original line 2).
104 if __name__ == "__main__":
105 fire.Fire(run_kmeans_grid_sweep)