]> code.communitydata.science - cdsc_reddit.git/blobdiff - clustering/lsi_base.py
add support for umap->hdbscan clustering method
[cdsc_reddit.git] / clustering / lsi_base.py
index 80b7101a3723a3910a9b10d7c1ad64fe97db00f8..14bbfc55f8b0263209c9124afefad17dda4faaae 100644 (file)
@@ -1,5 +1,5 @@
 from clustering_base import clustering_job, clustering_result
-from grid_sweep import grid_sweep
+from grid_sweep import grid_sweep, twoway_grid_sweep
 from dataclasses import dataclass
 from itertools import chain
 from pathlib import Path
@@ -27,3 +27,18 @@ class lsi_grid_sweep(grid_sweep):
         self.hasrun = False
         self.subgrids = [self.subsweep(lsi_path, outpath,  lsi_dim, *args, **kwargs) for lsi_dim, lsi_path in zip(lsi_nums, lsi_paths)]
         self.jobs = list(chain(*map(lambda gs: gs.jobs, self.subgrids)))
+
+class twoway_lsi_grid_sweep(twoway_grid_sweep):
+    def __init__(self, jobtype, subsweep, inpath, lsi_dimensions, outpath, args1, args2, save_step1):
+        self.jobtype = jobtype
+        self.subsweep = subsweep
+        inpath = Path(inpath)
+        if lsi_dimensions == 'all':
+            lsi_paths = list(inpath.glob("*.feather"))
+        else:
+            lsi_paths = [inpath / (str(dim) + '.feather') for dim in lsi_dimensions]
+
+        lsi_nums = [int(p.stem) for p in lsi_paths]
+        self.hasrun = False
+        self.subgrids = [self.subsweep(lsi_path, outpath, lsi_dim, args1, args2, save_step1) for lsi_dim, lsi_path in zip(lsi_nums, lsi_paths)]
+        self.jobs = list(chain(*map(lambda gs: gs.jobs, self.subgrids)))

Community Data Science Collective || Want to submit a patch?