]> code.communitydata.science - cdsc_reddit.git/blobdiff - clustering/grid_sweep.py
make pass keyword arg to dataframe.drop
[cdsc_reddit.git] / clustering / grid_sweep.py
index c0365d041480394b8cd95d258ea1279c6580c2a9..f021515e9ef3296cde2a1ae8d867a807958b50db 100644 (file)
@@ -31,3 +31,19 @@ class grid_sweep:
         outcsv = Path(outcsv)
         outcsv.parent.mkdir(parents=True, exist_ok=True)
         self.infos.to_csv(outcsv)
+
+
class twoway_grid_sweep(grid_sweep):
    """Sweep over every pairing of points from two independent parameter grids.

    ``args1`` and ``args2`` each map a parameter name to an iterable of
    candidate values. Each mapping is expanded to the Cartesian product of
    its values (one dict per combination). A job is then built for every
    (first-grid point, second-grid point) pair.
    """

    def __init__(self, jobtype, inpath, outpath, namer, args1, args2, *args, **kwargs):
        """Build one ``jobtype`` instance per pair of grid points.

        Args:
            jobtype: Callable invoked as
                ``jobtype(inpath, outpath, name, params1, params2, *args)``
                once per pair of grid points.
            inpath: Input location; wrapped in ``pathlib.Path``.
            outpath: Output location; wrapped in ``pathlib.Path``.
            namer: Called with the merged parameters of a pair as keyword
                arguments. Its return value is passed to ``jobtype`` as the
                third positional argument.
            args1: Mapping of parameter name -> iterable of values (first grid).
            args2: Mapping of parameter name -> iterable of values (second grid).
            *args: Extra positional arguments appended to every ``jobtype`` call.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        self.jobtype = jobtype
        self.namer = namer

        # Expand each mapping into a list of dicts, one per combination of values.
        combos1 = product(*args1.values())
        combos2 = product(*args2.values())
        points1 = [dict(zip(args1.keys(), combo)) for combo in combos1]
        points2 = [dict(zip(args2.keys(), combo)) for combo in combos2]

        in_path = Path(inpath)
        out_path = Path(outpath)
        self.hasrun = False

        # Name every pair first, then build all jobs in a second pass.
        # When both grids define the same key, the second grid's value is
        # the one passed to ``namer``.
        specs = []
        for first, second in product(points1, points2):
            label = namer(**{**first, **second})
            specs.append((in_path, out_path, label, first, second, *args))
        self.grid = specs
        self.jobs = [jobtype(*spec) for spec in self.grid]

Community Data Science Collective || Want to submit a patch?