# Recovered from the gitweb "blobdiff_plain" view of clustering/grid_sweep.py:
# https://code.communitydata.science/cdsc_reddit.git/blobdiff_plain/f05cb962e0388feaf38aaf84f222696ab8f5f398..4cb7eeec80c5a9c8f49339acd378c515e290ed81:/clustering/grid_sweep.py
from pathlib import Path
from multiprocessing import Pool, cpu_count
from itertools import product, chain
import pandas as pd


class grid_sweep:
    """Build one job per point of a parameter grid and collect their infos.

    The grid is the Cartesian product of the iterables given in ``*args``.
    For every grid point ``g`` a job is constructed as
    ``jobtype(inpath, outpath, namer(*g), *g)``.  Running the sweep calls
    ``jobtype.get_info`` on every job and gathers the results in a
    ``pandas.DataFrame`` stored on ``self.infos``.
    """

    def __init__(self, jobtype, inpath, outpath, namer, *args):
        """Construct every job of the sweep.

        Args:
            jobtype: Callable producing a job from
                ``(inpath, outpath, name, *params)``; it must also expose
                ``get_info``, which is applied to each job by :meth:`run`.
            inpath: Input location, converted to ``pathlib.Path``.
            outpath: Output location, converted to ``pathlib.Path``.
            namer: Callable invoked as ``namer(*params)`` to produce the
                third positional argument of each job.
            *args: One iterable of candidate values per swept parameter.
        """
        self.jobtype = jobtype
        self.namer = namer
        self.hasrun = False

        inpath = Path(inpath)
        outpath = Path(outpath)

        # Each grid entry is the full positional-argument tuple for one job.
        self.grid = [
            (inpath, outpath, namer(*params)) + params
            for params in product(*args)
        ]
        self.jobs = [jobtype(*job_args) for job_args in self.grid]

    def run(self, cores=20):
        """Call ``jobtype.get_info`` on every job and store the results.

        Args:
            cores: Number of worker processes to use.  ``None`` or any value
                not greater than 1 runs the jobs sequentially in the current
                process.

        The results are stored as a DataFrame on ``self.infos`` and
        ``self.hasrun`` is set to ``True``.
        """
        if cores is not None and cores > 1:
            # Never start more workers than there are jobs; Pool requires at
            # least one process, hence the floor of 1 for an empty grid.
            n_workers = min(cores, max(len(self.jobs), 1))
            with Pool(n_workers) as pool:
                infos = pool.map(self.jobtype.get_info, self.jobs)
        else:
            # Materialise eagerly: a lazy ``map`` object would defer the work
            # into the DataFrame constructor, and older pandas versions
            # refuse iterator input outright.
            infos = [self.jobtype.get_info(job) for job in self.jobs]

        self.infos = pd.DataFrame(infos)
        self.hasrun = True

    def save(self, outcsv):
        """Write the collected infos to ``outcsv`` as CSV.

        Runs the sweep first (with the default ``cores``) if it has not been
        run yet, and creates the parent directory of ``outcsv`` when missing.
        """
        if not self.hasrun:
            self.run()
        outcsv = Path(outcsv)
        outcsv.parent.mkdir(parents=True, exist_ok=True)
        self.infos.to_csv(outcsv)