X-Git-Url: https://code.communitydata.science/cdsc_reddit.git/blobdiff_plain/55b75ea6fcf421e95f4fe6b180dcec6e64676619..5a40465a629a1d7d95dbec9730d3950842bcb4f5:/clustering/grid_sweep.py diff --git a/clustering/grid_sweep.py b/clustering/grid_sweep.py index c0365d0..f021515 100644 --- a/clustering/grid_sweep.py +++ b/clustering/grid_sweep.py @@ -31,3 +31,19 @@ class grid_sweep: outcsv = Path(outcsv) outcsv.parent.mkdir(parents=True, exist_ok=True) self.infos.to_csv(outcsv) + + +class twoway_grid_sweep(grid_sweep): + def __init__(self, jobtype, inpath, outpath, namer, args1, args2, *args, **kwargs): + self.jobtype = jobtype + self.namer = namer + prod1 = product(* args1.values()) + prod2 = product(* args2.values()) + grid1 = [dict(zip(args1.keys(), pargs)) for pargs in prod1] + grid2 = [dict(zip(args2.keys(), pargs)) for pargs in prod2] + grid = product(grid1, grid2) + inpath = Path(inpath) + outpath = Path(outpath) + self.hasrun = False + self.grid = [(inpath,outpath,namer(**(g[0] | g[1])), g[0], g[1], *args) for g in grid] + self.jobs = [jobtype(*g) for g in self.grid]