#!/usr/bin/env python3
"""Expand a parameter grid into one shell command per combination.

Example (arguments are parsed by python-fire)::

    grid_sweep.py --command "python sim.py" \
        --arg_dict '{"n": [100, 1000], "seed": [1, 2]}' --outfile cmds.txt

writes one line per element of the Cartesian product of the value lists,
e.g. ``python sim.py --n 100 --seed 1``.
"""

import shlex
from itertools import product


def _as_choices(value):
    """Return *value* as a list of grid choices.

    Strings/bytes and non-iterable scalars are treated as a single choice
    instead of being iterated (``itertools.product`` would otherwise split a
    bare string into its characters, or fail on a bare number).
    """
    if isinstance(value, (str, bytes)):
        return [value]
    try:
        return list(value)
    except TypeError:
        return [value]


def main(command, arg_dict, outfile):
    """Write every combination of an argument grid as a command line.

    Args:
        command: Command prefix, written verbatim at the start of every line.
        arg_dict: Mapping from option name to the candidate values for that
            option. Each output line gets ``--<name> <value>`` for every key,
            in the mapping's iteration order. A non-list value (e.g. a single
            string or number) is used as the only choice for that option.
        outfile: Path of the file to (over)write, one command per line.

    If ``arg_dict`` is empty, ``command`` is written once. If any option has
    an empty list of values, no lines are written.
    """
    keys = list(arg_dict)
    choices = [_as_choices(arg_dict[k]) for k in keys]

    with open(outfile, "w", encoding="utf-8") as of:
        for combo in product(*choices):
            # Build each line directly rather than through str.format, so
            # braces in ``command`` or in keys are not treated as format
            # fields. Values are shell-quoted so one containing spaces or
            # metacharacters stays a single argument.
            parts = [command]
            for key, value in zip(keys, combo):
                parts.append(f"--{key} {shlex.quote(str(value))}")
            of.write(" ".join(parts) + "\n")


if __name__ == "__main__":
    # Imported here so the module can be imported (e.g. for testing)
    # without python-fire installed.
    import fire

    fire.Fire(main)