]> code.communitydata.science - cdsc_reddit.git/blob - clustering/grid_sweep.py
Merge branch 'master' of code:cdsc_reddit into excise_reindex
[cdsc_reddit.git] / clustering / grid_sweep.py
1 from pathlib import Path
2 from multiprocessing import Pool, cpu_count
3 from itertools import product, chain
4 import pandas as pd
5
class grid_sweep:
    """Run one job per point of a parameter grid and tabulate the results.

    The grid is the Cartesian product of the iterables given in ``*args``.
    For every grid point ``g`` a job is created as
    ``jobtype(inpath, outpath, namer(*g), *g)``.  Running the sweep calls
    ``jobtype.get_info(job)`` on each job and stores the collected results
    in ``self.infos`` as a ``pandas.DataFrame``.

    Attributes set by the constructor:
        jobtype, namer: the callables passed in.
        grid: list of argument tuples, one per job.
        jobs: list of ``jobtype`` instances, in grid order.
        hasrun: ``False`` until ``run`` completes.
        infos: ``None`` until ``run`` completes, then a ``DataFrame``.
    """

    def __init__(self, jobtype, inpath, outpath, namer, *args):
        """Build the parameter grid and instantiate every job.

        Args:
            jobtype: Called as ``jobtype(inpath, outpath, name, *params)``;
                must provide ``get_info`` accepting the job as its only
                argument.
            inpath: Input location; converted to ``pathlib.Path``.
            outpath: Output location; converted to ``pathlib.Path``.
            namer: Called as ``namer(*params)``; its result is passed to
                ``jobtype`` as the third positional argument.
            *args: One iterable of candidate values per swept parameter.
        """
        self.jobtype = jobtype
        self.namer = namer
        # Prints the raw parameter lists to stdout (kept for backward
        # compatibility; presumably there so the sweep shows up in logs).
        print(*args)
        inpath = Path(inpath)
        outpath = Path(outpath)
        self.hasrun = False
        # Defined up front so reading ``infos`` before ``run`` yields None
        # instead of raising AttributeError.
        self.infos = None
        self.grid = [
            (inpath, outpath, namer(*point)) + point
            for point in product(*args)
        ]
        self.jobs = [jobtype(*spec) for spec in self.grid]

    def run(self, cores=20):
        """Execute every job and store the results in ``self.infos``.

        Args:
            cores: Maximum number of worker processes.  ``None`` or a value
                of 1 or less runs the jobs serially in this process.  The
                pool never has more workers than there are jobs.
        """
        jobs = self.jobs
        # Never start more workers than there is work for; with zero or one
        # job this also avoids creating a pool at all.
        workers = 1 if cores is None else min(cores, len(jobs))

        if workers > 1:
            with Pool(workers) as pool:
                infos = pool.map(self.jobtype.get_info, jobs)
        else:
            # Materialise eagerly so the jobs run here rather than lazily
            # inside the DataFrame constructor.
            infos = [self.jobtype.get_info(job) for job in jobs]

        self.infos = pd.DataFrame(infos)
        self.hasrun = True

    def save(self, outcsv, cores=20):
        """Write the sweep results to ``outcsv`` as CSV.

        Runs the sweep first if it has not been run yet.  Missing parent
        directories of ``outcsv`` are created.

        Args:
            outcsv: Destination file path.
            cores: Forwarded to ``run`` when the sweep still has to be
                executed; ignored otherwise.
        """
        if not self.hasrun:
            self.run(cores=cores)
        outcsv = Path(outcsv)
        outcsv.parent.mkdir(parents=True, exist_ok=True)
        self.infos.to_csv(outcsv)

Community Data Science Collective || Want to submit a patch?