class twoway_grid_sweep(grid_sweep):
    """Grid sweep over the cross product of two separate parameter grids.

    Each of ``args1`` and ``args2`` is expanded into every combination of its
    values; every pairing of one combination from each side becomes one job.
    """

    def __init__(self, jobtype, inpath, outpath, namer, args1, args2, *args, **kwargs):
        """Build the job specifications and instantiate one job per grid point.

        Args:
            jobtype: Callable invoked as
                ``jobtype(inpath, outpath, name, settings1, settings2, *args)``.
            inpath: Input location; converted to ``Path``.
            outpath: Output location; converted to ``Path``.
            namer: Callable that receives the merged settings of both sides
                as keyword arguments; its result is passed on as the job name.
            args1: Mapping of parameter name -> iterable of candidate values.
            args2: Mapping of parameter name -> iterable of candidate values.
            *args: Extra positional arguments appended to every job's call.
            **kwargs: Accepted but not used by this constructor.

        Note:
            The base-class constructor is not called; the attributes
            ``jobtype``, ``namer``, ``hasrun``, ``grid`` and ``jobs`` are set
            here directly.
        """
        self.jobtype = jobtype
        self.namer = namer

        # Expand each side independently into a list of {name: value} dicts.
        combos1 = product(*args1.values())
        combos2 = product(*args2.values())
        settings1 = [dict(zip(args1.keys(), combo)) for combo in combos1]
        settings2 = [dict(zip(args2.keys(), combo)) for combo in combos2]

        inpath = Path(inpath)
        outpath = Path(outpath)
        self.hasrun = False

        specs = []
        for first, second in product(settings1, settings2):
            # On a key present in both sides, the value from ``second`` wins
            # when building the name; both dicts are still passed unmerged.
            name = namer(**{**first, **second})
            specs.append((inpath, outpath, name, first, second, *args))

        self.grid = specs
        self.jobs = [jobtype(*spec) for spec in self.grid]