2 from dataclasses import dataclass
3 from kmeans_clustering import kmeans_job, kmeans_clustering_result, kmeans_grid_sweep
4 from lsi_base import lsi_mixin, lsi_result_mixin, lsi_grid_sweep
5 from grid_sweep import grid_sweep
8 class kmeans_clustering_result_lsi(kmeans_clustering_result, lsi_result_mixin):
11 class kmeans_lsi_job(kmeans_job, lsi_mixin):
12 def __init__(self, infile, outpath, name, lsi_dims, *args, **kwargs):
13 super().__init__(infile,
18 super().set_lsi_dims(lsi_dims)
21 result = super().get_info()
22 self.result = kmeans_clustering_result_lsi(**result.__dict__,
23 lsi_dimensions=self.lsi_dims)
26 class _kmeans_lsi_grid_sweep(grid_sweep):
35 self.lsi_dim = lsi_dim
36 self.jobtype = kmeans_lsi_job
37 super().__init__(self.jobtype, inpath, outpath, self.namer, [self.lsi_dim], *args, **kwargs)
39 def namer(self, *args, **kwargs):
40 s = kmeans_grid_sweep.namer(self, *args[1:], **kwargs)
41 s += f"_lsi-{self.lsi_dim}"
44 class kmeans_lsi_grid_sweep(lsi_grid_sweep):
55 super().__init__(kmeans_lsi_job,
56 _kmeans_lsi_grid_sweep,
64 def run_kmeans_lsi_grid_sweep(savefile, inpath, outpath, n_clusters=[500], n_inits=[1], max_iters=[3000], lsi_dimensions="all"):
65 """Run kmeans clustering once or more with different parameters.
68 kmeans_clustering_lsi.py --savefile=SAVEFILE --inpath=INPATH --outpath=OUTPATH --lsi_dimensions=<"all"|csv number of LSI dimensions to use> --n_clusters=<csv number of clusters> --n_inits=<csv> --max_iters=<csv>
71 savefile: path to save the metadata and diagnostics
72 inpath: path to folder containing feather files with LSI similarity labeled matrices of subreddit similarities.
73 outpath: path to output fit kmeans clusterings.
74 lsi_dimensions: either "all" or one or more available lsi similarity dimensions at INPATH.
75 n_clusters: one or more numbers of kmeans clusters to select.
76 n_inits: one or more numbers of different initializations to use for each clustering.
77 max_iters: one or more numbers of different maximum iterations.
80 obj = kmeans_lsi_grid_sweep(inpath,
83 list(map(int,n_clusters)),
84 list(map(int,n_inits)),
85 list(map(int,max_iters))
# Script entry point. NOTE(review): `fire` is not imported in the visible
# lines; presumably an `import fire` sits at the top of the file — confirm.
# fire.Fire turns run_kmeans_lsi_grid_sweep into a command-line interface whose
# parameters are passed as flags, matching the usage line in that function's
# docstring.
92 if __name__ == "__main__":
93 fire.Fire(run_kmeans_lsi_grid_sweep)