X-Git-Url: https://code.communitydata.science/covid19.git/blobdiff_plain/f548eeedd59a1d7d99deb8864c7d11947271e426..6493361fbd95f44a3b27131f4f79329d40e61c90:/transliterations/src/wikidata_transliterations.py diff --git a/transliterations/src/wikidata_transliterations.py b/transliterations/src/wikidata_transliterations.py index d878354..1ac956c 100644 --- a/transliterations/src/wikidata_transliterations.py +++ b/transliterations/src/wikidata_transliterations.py @@ -2,6 +2,7 @@ from wikidata_api_calls import run_sparql_query from itertools import chain, islice import csv from json import JSONDecodeError +from os import path class LabelData: __slots__ = ['entityid','label','langcode','is_alt'] @@ -23,7 +24,7 @@ def GetAllLabels(in_csvs, outfile, topNs): def load_entity_ids(in_csv, topN=5): with open(in_csv,'r',newline='') as infile: - reader = csv.DictReader(infile) + reader = list(csv.DictReader(infile)) for row in reader: if int(row['search_position']) < topN: yield row["entityid"] @@ -84,6 +85,14 @@ def GetEntityLabels(entityids): return chain(*calls) +def find_new_output_file(output, i = 1): + if path.exists(output): + name, ext = path.splitext(output) + + return find_new_output_file(f"{name}_{i}.{ext}", i+1) + else: + return output + if __name__ == "__main__": import argparse parser = argparse.ArgumentParser("Use wikidata to find transliterations of terms") @@ -93,4 +102,6 @@ if __name__ == "__main__": args = parser.parse_args() - GetAllLabels(args.inputs, args.output, topNs=args.topN) + output = find_new_output_file(args.output) + + GetAllLabels(args.inputs, output, topNs=args.topN)