1 from clustering_base import clustering_job, clustering_result
 
   2 from grid_sweep import grid_sweep
 
   3 from dataclasses import dataclass
 
   4 from itertools import chain
 
   5 from pathlib import Path
 
   8     def set_lsi_dims(self, lsi_dims):
 
   9         self.lsi_dims = lsi_dims
 
class lsi_result_mixin:
    """Presumably a mixin for clustering-result classes built from LSI sweeps.

    NOTE(review): no members of this class are visible here, and the file
    imports ``dataclass`` without any visible use. This was probably meant
    to be a ``@dataclass`` declaring LSI-related result fields (such as
    the dimension count). Confirm against version history and restore.
    """
 
  15 class lsi_grid_sweep(grid_sweep):
 
  16     def __init__(self, jobtype, subsweep, inpath, lsi_dimensions, outpath, *args, **kwargs):
 
  17         self.jobtype = jobtype
 
  18         self.subsweep = subsweep
 
  20         if lsi_dimensions == 'all':
 
  21             lsi_paths = list(inpath.glob("*.feather"))
 
  23             lsi_paths = [inpath / (str(dim) + '.feather') for dim in lsi_dimensions]
 
  26         lsi_nums = [int(p.stem) for p in lsi_paths]
 
  28         self.subgrids = [self.subsweep(lsi_path, outpath,  lsi_dim, *args, **kwargs) for lsi_dim, lsi_path in zip(lsi_nums, lsi_paths)]
 
  29         self.jobs = list(chain(*map(lambda gs: gs.jobs, self.subgrids)))