code.communitydata.science - cdsc_reddit.git/blob - clustering/grid_sweep.py
Merge branch 'excise_reindex' of code:cdsc_reddit into excise_reindex
[cdsc_reddit.git] / clustering / grid_sweep.py
1 from pathlib import Path
2 from multiprocessing import Pool, cpu_count
3 from itertools import product, chain
4 import pandas as pd
5
class grid_sweep:
    """Run one job per point of a parameter grid and collect their info.

    The grid is the Cartesian product of the iterables passed as ``*args``.
    For every grid point ``g`` a job is constructed as
    ``jobtype(inpath, outpath, namer(*g), *g)``; ``run`` then calls
    ``jobtype.get_info`` on each job and gathers the results into a
    ``pandas.DataFrame``.

    Attributes:
        jobtype: the job type; must expose a ``get_info`` callable that
            accepts a job instance.
        namer: callable mapping a grid point's values to a name.
        grid: list of constructor-argument tuples, one per job, in
            ``itertools.product`` order.
        jobs: list of job instances built from ``grid``.
        hasrun: True once ``run`` has populated ``infos``.
        infos: DataFrame built from the ``get_info`` results, or ``None``
            before ``run`` has been called.
    """

    def __init__(self, jobtype, inpath, outpath, namer, *args):
        """Build the parameter grid and instantiate one job per grid point.

        Args:
            jobtype: callable (typically a class) invoked as
                ``jobtype(inpath, outpath, name, *values)`` for each point.
            inpath: input location; converted to ``pathlib.Path``.
            outpath: output location; converted to ``pathlib.Path``.
            namer: callable receiving one grid point's values (unpacked)
                and returning that point's name.
            *args: iterables whose Cartesian product defines the grid.
        """
        self.jobtype = jobtype
        self.namer = namer
        grid = list(product(*args))
        inpath = Path(inpath)
        outpath = Path(outpath)
        self.hasrun = False
        # Defined up front so the attribute exists before run() is called.
        self.infos = None
        self.grid = [(inpath, outpath, namer(*g)) + g for g in grid]
        self.jobs = [jobtype(*g) for g in self.grid]

    def run(self, cores=20):
        """Call ``get_info`` on every job and store the results in ``self.infos``.

        Args:
            cores: number of worker processes. When ``None`` or ``<= 1`` the
                jobs are evaluated serially in this process. Otherwise a
                ``multiprocessing.Pool`` is used (so jobs and their results
                must be picklable), sized to at most one worker per job.
        """
        if cores is not None and cores > 1:
            # Don't spawn more workers than there are jobs, but keep at least
            # one so an empty sweep never asks Pool for 0 processes.
            processes = min(cores, max(len(self.jobs), 1))
            with Pool(processes) as pool:
                infos = pool.map(self.jobtype.get_info, self.jobs)
        else:
            # Materialize explicitly: this matches the list Pool.map returns
            # and doesn't rely on the DataFrame constructor accepting a lazy
            # iterator (older pandas versions reject one).
            infos = list(map(self.jobtype.get_info, self.jobs))

        self.infos = pd.DataFrame(infos)
        self.hasrun = True

    def save(self, outcsv, cores=20):
        """Write ``self.infos`` to ``outcsv`` as CSV, running the sweep first if needed.

        Args:
            outcsv: destination file path; missing parent directories are
                created.
            cores: forwarded to ``run`` when the sweep has not run yet;
                ignored otherwise. The default matches ``run``'s default.
        """
        if not self.hasrun:
            self.run(cores=cores)
        outcsv = Path(outcsv)
        outcsv.parent.mkdir(parents=True, exist_ok=True)
        self.infos.to_csv(outcsv)

Community Data Science Collective || Want to submit a patch?