]> code.communitydata.science - cdsc_reddit.git/blob - clustering/lsi_base.py
f07bca6f01d61f8d6338f4f4adac6a4cf9536046
[cdsc_reddit.git] / clustering / lsi_base.py
1 from clustering_base import clustering_job, clustering_result
2 from grid_sweep import grid_sweep
3 from dataclasses import dataclass
4 from itertools import chain
5 from pathlib import Path
6
class lsi_mixin:
    """Mixin providing a setter for the ``lsi_dims`` attribute."""

    def set_lsi_dims(self, lsi_dims):
        """Store ``lsi_dims`` on the instance as ``self.lsi_dims``.

        Parameters
        ----------
        lsi_dims :
            Value to record; stored as-is without validation.
        """
        self.lsi_dims = lsi_dims
10
@dataclass
class lsi_result_mixin:
    """Dataclass mixin that adds an ``lsi_dimensions`` field to a result record."""

    # Number of LSI dimensions; presumably the dimensionality of the LSI input
    # that produced the result -- NOTE(review): confirm against callers.
    lsi_dimensions: int
14
class lsi_grid_sweep(grid_sweep):
    """Grid sweep that fans out one sub-sweep per LSI dimensionality.

    For every requested LSI dimension ``d``, a ``subsweep`` is built from the
    file ``inpath / '<d>.feather'``.  The jobs of all sub-sweeps are
    concatenated into ``self.jobs``.
    """

    def __init__(self, jobtype, subsweep, inpath, lsi_dimensions, outpath, *args, **kwargs):
        """Build one sub-sweep per LSI dimension and collect their jobs.

        Parameters
        ----------
        jobtype :
            Stored on the instance as ``self.jobtype``.
        subsweep :
            Callable invoked as ``subsweep(lsi_path, outpath, lsi_dim, *args, **kwargs)``
            once per dimension; each result must expose a ``jobs`` iterable.
        inpath :
            Directory holding one ``<dim>.feather`` file per LSI dimension.
        lsi_dimensions :
            ``'all'`` to use every ``<dim>.feather`` file found in ``inpath``
            (ordered by dimension), an iterable of dimensions (kept in the
            given order), or a single dimension (int or str).
        outpath :
            Forwarded unchanged to every sub-sweep.
        *args, **kwargs :
            Forwarded unchanged to every sub-sweep.
        """
        # NOTE(review): grid_sweep.__init__ is not invoked; the attributes set
        # here appear to stand in for it -- confirm against grid_sweep.
        self.jobtype = jobtype
        self.subsweep = subsweep
        self.hasrun = False

        lsi_entries = self._resolve_lsi_paths(Path(inpath), lsi_dimensions)
        self.subgrids = [self.subsweep(lsi_path, outpath, lsi_dim, *args, **kwargs)
                         for lsi_dim, lsi_path in lsi_entries]
        self.jobs = list(chain.from_iterable(gs.jobs for gs in self.subgrids))

    @staticmethod
    def _resolve_lsi_paths(inpath, lsi_dimensions):
        """Return ``(dim, path)`` pairs for the requested LSI dimensions under ``inpath``."""
        if isinstance(lsi_dimensions, str) and lsi_dimensions == 'all':
            # Only take files following the '<dim>.feather' convention used by the
            # explicit branch below: any other entry in the directory would make
            # int(p.stem) raise. Sort by dimension because glob order is
            # filesystem-dependent and would make job order nondeterministic.
            found = [p for p in inpath.glob("*.feather") if p.stem.isdecimal()]
            return sorted(((int(p.stem), p) for p in found), key=lambda entry: entry[0])

        if isinstance(lsi_dimensions, (int, str)):
            # A single dimension: wrap it so a string such as '300' is not
            # iterated character by character (and an int is iterable at all).
            lsi_dimensions = [lsi_dimensions]

        lsi_paths = [inpath / (str(dim) + '.feather') for dim in lsi_dimensions]
        # Derive the numeric dimension from the file name so str/int inputs agree.
        return [(int(p.stem), p) for p in lsi_paths]

Community Data Science Collective || Want to submit a patch?