X-Git-Url: https://code.communitydata.science/cdsc_reddit.git/blobdiff_plain/f05cb962e0388feaf38aaf84f222696ab8f5f398..4cb7eeec80c5a9c8f49339acd378c515e290ed81:/clustering/kmeans_clustering_lsi.py?ds=sidebyside diff --git a/clustering/kmeans_clustering_lsi.py b/clustering/kmeans_clustering_lsi.py new file mode 100644 index 0000000..20d582b --- /dev/null +++ b/clustering/kmeans_clustering_lsi.py @@ -0,0 +1,93 @@ +import fire +from dataclasses import dataclass +from kmeans_clustering import kmeans_job, kmeans_clustering_result, kmeans_grid_sweep +from lsi_base import lsi_mixin, lsi_result_mixin, lsi_grid_sweep +from grid_sweep import grid_sweep + +@dataclass +class kmeans_clustering_result_lsi(kmeans_clustering_result, lsi_result_mixin): + pass + +class kmeans_lsi_job(kmeans_job, lsi_mixin): + def __init__(self, infile, outpath, name, lsi_dims, *args, **kwargs): + super().__init__(infile, + outpath, + name, + *args, + **kwargs) + super().set_lsi_dims(lsi_dims) + + def get_info(self): + result = super().get_info() + self.result = kmeans_clustering_result_lsi(**result.__dict__, + lsi_dimensions=self.lsi_dims) + return self.result + +class _kmeans_lsi_grid_sweep(grid_sweep): + def __init__(self, + inpath, + outpath, + lsi_dim, + *args, + **kwargs): + print(args) + print(kwargs) + self.lsi_dim = lsi_dim + self.jobtype = kmeans_lsi_job + super().__init__(self.jobtype, inpath, outpath, self.namer, self.lsi_dim, *args, **kwargs) + + def namer(self, *args, **kwargs): + s = kmeans_grid_sweep.namer(self, *args[1:], **kwargs) + s += f"_lsi-{self.lsi_dim}" + return s + +class kmeans_lsi_grid_sweep(lsi_grid_sweep): + + def __init__(self, + inpath, + lsi_dims, + outpath, + n_clusters, + n_inits, + max_iters + ): + + super().__init__(kmeans_lsi_job, + _kmeans_lsi_grid_sweep, + inpath, + lsi_dims, + outpath, + n_clusters, + n_inits, + max_iters) + +def run_kmeans_lsi_grid_sweep(savefile, inpath, outpath, n_clusters=[500], n_inits=[1], max_iters=[3000], lsi_dimensions="all"): + """Run kmeans clustering once or more with different parameters. + + Usage: + kmeans_clustering_lsi.py --savefile=SAVEFILE --inpath=INPATH --outpath=OUTPATH d--lsi_dimensions=<"all"|csv number of LSI dimensions to use> --n_clusters= --n_inits= --max_iters= + + Keword arguments: + savefile: path to save the metadata and diagnostics + inpath: path to folder containing feather files with LSI similarity labeled matrices of subreddit similarities. + outpath: path to output fit kmeans clusterings. + lsi_dimensions: either "all" or one or more available lsi similarity dimensions at INPATH. + n_clusters: one or more numbers of kmeans clusters to select. + n_inits: one or more numbers of different initializations to use for each clustering. + max_iters: one or more numbers of different maximum interations. + """ + + obj = kmeans_lsi_grid_sweep(inpath, + lsi_dimensions, + outpath, + list(map(int,n_clusters)), + list(map(int,n_inits)), + list(map(int,max_iters)) + ) + + obj.run(1) + obj.save(savefile) + + +if __name__ == "__main__": + fire.Fire(run_kmeans_lsi_grid_sweep)