class twoway_lsi_grid_sweep(twoway_grid_sweep):
    """Two-way grid sweep repeated once per LSI dimensionality.

    For every selected ``<dim>.feather`` file under ``inpath`` a sub-sweep is
    created with ``subsweep``; the jobs of all sub-sweeps are then flattened
    into ``self.jobs``.
    """

    def __init__(self, jobtype, subsweep, inpath, lsi_dimensions, outpath, args1, args2, save_step1):
        """Create one sub-sweep per LSI dimension and collect their jobs.

        Args:
            jobtype: Stored as ``self.jobtype``; not otherwise used here.
            subsweep: Callable invoked as
                ``subsweep(lsi_path, outpath, lsi_dim, args1, args2, save_step1)``.
                The object it returns must expose an iterable ``jobs``.
            inpath: Directory containing files named ``<dim>.feather``.
            lsi_dimensions: The string ``'all'`` to use every ``*.feather``
                file found in ``inpath``, otherwise an iterable of dimensions
                (each is turned into ``<dim>.feather``).
            outpath: Forwarded unchanged to every ``subsweep`` call.
            args1: Forwarded unchanged to every ``subsweep`` call.
            args2: Forwarded unchanged to every ``subsweep`` call.
            save_step1: Forwarded unchanged to every ``subsweep`` call.

        Raises:
            ValueError: If the stem of a selected file is not an integer.
        """
        # NOTE(review): the base-class initializer is not invoked; every
        # attribute is assigned directly below.
        self.jobtype = jobtype
        self.subsweep = subsweep
        inpath = Path(inpath)

        if lsi_dimensions == 'all':
            # Directory listings come back in arbitrary, filesystem-dependent
            # order; sort by dimension so the job order is reproducible.
            lsi_paths = sorted(inpath.glob("*.feather"), key=lambda p: int(p.stem))
        else:
            # Explicit dimensions keep the order the caller supplied.
            lsi_paths = [inpath / f"{dim}.feather" for dim in lsi_dimensions]

        # Parse every dimension up front so a malformed file name fails
        # before any sub-sweep has been constructed.
        lsi_nums = [int(p.stem) for p in lsi_paths]

        self.hasrun = False
        self.subgrids = [
            self.subsweep(lsi_path, outpath, lsi_dim, args1, args2, save_step1)
            for lsi_dim, lsi_path in zip(lsi_nums, lsi_paths)
        ]
        # Flatten the per-dimension job lists into a single list.
        self.jobs = list(chain.from_iterable(gs.jobs for gs in self.subgrids))