code.communitydata.science - cdsc_reddit.git/blobdiff - clustering/kmeans_clustering_lsi.py
Refactor to make a decent api.
[cdsc_reddit.git] / clustering / kmeans_clustering_lsi.py
diff --git a/clustering/kmeans_clustering_lsi.py b/clustering/kmeans_clustering_lsi.py
new file mode 100644 (file)
index 0000000..20d582b
--- /dev/null
@@ -0,0 +1,93 @@
+import fire
+from dataclasses import dataclass
+from kmeans_clustering import kmeans_job, kmeans_clustering_result, kmeans_grid_sweep
+from lsi_base import lsi_mixin, lsi_result_mixin, lsi_grid_sweep
+from grid_sweep import grid_sweep
+
@dataclass
class kmeans_clustering_result_lsi(kmeans_clustering_result, lsi_result_mixin):
    """Result record for a k-means clustering fit on an LSI similarity matrix.

    Adds no fields of its own: it merges the fields declared by
    ``kmeans_clustering_result`` and ``lsi_result_mixin``.
    ``kmeans_lsi_job.get_info`` builds it by passing the parent result's
    fields plus ``lsi_dimensions``.
    """
+
class kmeans_lsi_job(kmeans_job, lsi_mixin):
    """A ``kmeans_job`` that also records how many LSI dimensions it used.

    Constructor arguments are those of ``kmeans_job`` with an extra
    ``lsi_dims`` inserted after ``name``. All remaining positional and
    keyword arguments are forwarded to ``kmeans_job`` unchanged.
    """

    def __init__(self, infile, outpath, name, lsi_dims, *args, **kwargs):
        super().__init__(infile,
                         outpath,
                         name,
                         *args,
                         **kwargs)
        # set_lsi_dims is inherited (presumably from lsi_mixin) and is not
        # overridden in this class. Dispatch through ``self`` rather than
        # ``super()`` so an override on a subclass is not silently skipped.
        self.set_lsi_dims(lsi_dims)

    def get_info(self):
        """Return the parent's result record extended with the LSI dimensionality.

        The record is also stored on ``self.result``.
        """
        result = super().get_info()
        # NOTE(review): assumes set_lsi_dims stores its argument on
        # self.lsi_dims (lsi_mixin is not visible here) -- confirm.
        self.result = kmeans_clustering_result_lsi(**result.__dict__,
                                                   lsi_dimensions=self.lsi_dims)
        return self.result
+
class _kmeans_lsi_grid_sweep(grid_sweep):
    """Grid sweep of ``kmeans_lsi_job`` for one fixed LSI dimensionality.

    ``lsi_dim`` is a single value. The remaining positional arguments are
    the k-means parameter axes, each a sequence of candidate values.
    """

    def __init__(self,
                 inpath,
                 outpath,
                 lsi_dim,
                 *args,
                 **kwargs):
        self.lsi_dim = lsi_dim
        self.jobtype = kmeans_lsi_job
        # The LSI dimension is the first grid parameter (``namer`` drops
        # args[0] for it), so it must be given as a one-element axis like
        # the other candidate lists. Passed bare, a string such as "1000"
        # would presumably be expanded character by character by
        # grid_sweep -- confirm against grid_sweep's handling of its axes.
        super().__init__(self.jobtype, inpath, outpath, self.namer,
                         [self.lsi_dim], *args, **kwargs)

    def namer(self, *args, **kwargs):
        """Name one grid point: the k-means name plus an ``_lsi-<dim>`` suffix.

        ``args[0]`` is the LSI-dimension value of the grid point. It is
        dropped before delegating to ``kmeans_grid_sweep.namer``, and
        ``self.lsi_dim`` is appended as a suffix instead.
        """
        s = kmeans_grid_sweep.namer(self, *args[1:], **kwargs)
        s += f"_lsi-{self.lsi_dim}"
        return s
+
class kmeans_lsi_grid_sweep(lsi_grid_sweep):
    """Sweep k-means clustering across LSI dimensionalities and k-means settings.

    A fixed configuration of ``lsi_grid_sweep``. Jobs are ``kmeans_lsi_job``
    instances, and ``_kmeans_lsi_grid_sweep`` is supplied as the sub-sweep
    class (presumably instantiated once per LSI dimensionality by
    ``lsi_grid_sweep`` -- not visible here).
    """

    def __init__(self,
                 inpath,
                 lsi_dims,
                 outpath,
                 n_clusters,
                 n_inits,
                 max_iters
                 ):
        # Forwarded positionally, in exactly this order, after the job and
        # sub-sweep classes.
        sweep_params = (inpath, lsi_dims, outpath,
                        n_clusters, n_inits, max_iters)
        super().__init__(kmeans_lsi_job, _kmeans_lsi_grid_sweep, *sweep_params)
+
def _as_int_list(values):
    """Coerce one command-line grid argument into a list of ints.

    ``fire`` delivers a single value (``--n_clusters=500``) as a bare
    number, which cannot be iterated. Several values arrive as a sequence.
    A plain string is split on commas, so "100,200" gives [100, 200]
    rather than being iterated per character.
    """
    if isinstance(values, str):
        values = [v for v in values.split(",") if v.strip()]
    elif isinstance(values, (int, float)):
        values = [values]
    return [int(v) for v in values]


def run_kmeans_lsi_grid_sweep(savefile, inpath, outpath, n_clusters=(500,),
                              n_inits=(1,), max_iters=(3000,),
                              lsi_dimensions="all"):
    """Run kmeans clustering once or more with different parameters.

    Usage:
    kmeans_clustering_lsi.py --savefile=SAVEFILE --inpath=INPATH --outpath=OUTPATH --lsi_dimensions=<"all"|csv number of LSI dimensions to use> --n_clusters=<csv number of clusters> --n_inits=<csv> --max_iters=<csv>

    Keyword arguments:
    savefile: path to save the metadata and diagnostics.
    inpath: path to folder containing feather files with LSI similarity labeled matrices of subreddit similarities.
    outpath: path to output fit kmeans clusterings.
    lsi_dimensions: either "all" or one or more available lsi similarity dimensions at INPATH.
    n_clusters: one or more numbers of kmeans clusters to select.
    n_inits: one or more numbers of different initializations to use for each clustering.
    max_iters: one or more numbers of different maximum iterations.

    n_clusters, n_inits and max_iters each accept a single integer, a
    comma-separated string, or a sequence of integers.
    """
    # NOTE(review): lsi_dimensions is forwarded unchanged. A single value
    # given on the command line arrives as a scalar -- confirm that
    # lsi_grid_sweep accepts that form.
    obj = kmeans_lsi_grid_sweep(inpath,
                                lsi_dimensions,
                                outpath,
                                _as_int_list(n_clusters),
                                _as_int_list(n_inits),
                                _as_int_list(max_iters)
                                )

    obj.run(1)
    obj.save(savefile)
+
+
if __name__ == "__main__":
    # Expose run_kmeans_lsi_grid_sweep as the command-line interface;
    # fire maps --flag=value options onto its parameters.
    fire.Fire(component=run_kmeans_lsi_grid_sweep)

Community Data Science Collective || Want to submit a patch?