X-Git-Url: https://code.communitydata.science/cdsc_reddit.git/blobdiff_plain/8a2248fae1ee5818576b9a8f2849d1ad0efd8187..87ffaa6858919bd830694d60dd4fc7b1857b462a:/clustering/kmeans_clustering.py?ds=sidebyside diff --git a/clustering/kmeans_clustering.py b/clustering/kmeans_clustering.py index e41b88b..211b666 100644 --- a/clustering/kmeans_clustering.py +++ b/clustering/kmeans_clustering.py @@ -1,11 +1,9 @@ from sklearn.cluster import KMeans import fire from pathlib import Path -from multiprocessing import cpu_count from dataclasses import dataclass -from clustering_base import sim_to_dist, process_clustering_result, clustering_result, read_similarity_mat -from clustering_base import lsi_result_mixin, lsi_mixin, clustering_job, grid_sweep, lsi_grid_sweep - +from clustering_base import clustering_result, clustering_job +from grid_sweep import grid_sweep @dataclass class kmeans_clustering_result(clustering_result): @@ -13,10 +11,6 @@ class kmeans_clustering_result(clustering_result): n_init:int max_iter:int -@dataclass -class kmeans_clustering_result_lsi(kmeans_clustering_result, lsi_result_mixin): - pass - class kmeans_job(clustering_job): def __init__(self, infile, outpath, name, n_clusters, n_init=10, max_iter=100000, random_state=1968, verbose=True): super().__init__(infile, @@ -45,28 +39,13 @@ class kmeans_job(clustering_job): def get_info(self): result = super().get_info() self.result = kmeans_clustering_result(**result.__dict__, - n_init=n_init, - max_iter=max_iter) + n_init=self.n_init, + max_iter=self.max_iter) return self.result -class kmeans_lsi_job(kmeans_job, lsi_mixin): - def __init__(self, infile, outpath, name, lsi_dims, *args, **kwargs): - super().__init__(infile, - outpath, - name, - *args, - **kwargs) - super().set_lsi_dims(lsi_dims) - - def get_info(self): - result = super().get_info() - self.result = kmeans_clustering_result_lsi(**result.__dict__, - lsi_dimensions=self.lsi_dims) - return self.result - - class kmeans_grid_sweep(grid_sweep): + def 
__init__(self, inpath, outpath, @@ -80,49 +59,7 @@ class kmeans_grid_sweep(grid_sweep): max_iter): return f"nclusters-{n_clusters}_nit-{n_init}_maxit-{max_iter}" -class _kmeans_lsi_grid_sweep(grid_sweep): - def __init__(self, - inpath, - outpath, - lsi_dim, - *args, - **kwargs): - self.lsi_dim = lsi_dim - self.jobtype = kmeans_lsi_job - super().__init__(self.jobtype, inpath, outpath, self.namer, self.lsi_dim, *args, **kwargs) - - def namer(self, *args, **kwargs): - s = kmeans_grid_sweep.namer(self, *args[1:], **kwargs) - s += f"_lsi-{self.lsi_dim}" - return s - -class kmeans_lsi_grid_sweep(lsi_grid_sweep): - def __init__(self, - inpath, - lsi_dims, - outpath, - n_clusters, - n_inits, - max_iters - ): - - super().__init__(kmeans_lsi_job, - _kmeans_lsi_grid_sweep, - inpath, - lsi_dims, - outpath, - n_clusters, - n_inits, - max_iters) - def test_select_kmeans_clustering(): - # select_hdbscan_clustering("/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_30k_LSI", - # "test_hdbscan_author30k", - # min_cluster_sizes=[2], - # min_samples=[1,2], - # cluster_selection_epsilons=[0,0.05,0.1,0.15], - # cluster_selection_methods=['eom','leaf'], - # lsi_dimensions='all') inpath = "/gscratch/comdata/output/reddit_similarity/subreddit_comment_authors-tf_10k_LSI/" outpath = "test_kmeans"; n_clusters=[200,300,400]; @@ -139,10 +76,30 @@ def test_select_kmeans_clustering(): gs.run(20) gs.save("test_hdbscan/lsi_sweep.csv") +def run_kmeans_grid_sweep(savefile, inpath, outpath, n_clusters=[500], n_inits=[1], max_iters=[3000]): + """Run kmeans clustering once or more with different parameters. + + Usage: + kmeans_clustering.py --savefile=SAVEFILE --inpath=INPATH --outpath=OUTPATH --n_clusters= --n_inits= --max_iters= + + Keyword arguments: + savefile: path to save the metadata and diagnostics + inpath: path to feather data containing a labeled matrix of subreddit similarities. + outpath: path to output fit kmeans clusterings. 
+ n_clusters: one or more numbers of kmeans clusters to select. + n_inits: one or more numbers of different initializations to use for each clustering. + max_iters: one or more numbers of different maximum iterations. + """ -if __name__ == "__main__": + obj = kmeans_grid_sweep(inpath, + outpath, + map(int,n_clusters), + map(int,n_inits), + map(int,max_iters)) - fire.Fire{'grid_sweep':kmeans_grid_sweep, - 'grid_sweep_lsi':kmeans_lsi_grid_sweep - 'cluster':kmeans_job, - 'cluster_lsi':kmeans_lsi_job} + + obj.run(1) + obj.save(savefile) + +if __name__ == "__main__": + fire.Fire(run_kmeans_grid_sweep)