code/prediction/00_ngram_extraction.py (social-media-chapter.git)
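"""Extract n-gram counts from a corpus of abstracts.

Reads a tab-separated file of abstracts (one row per document, with 'eid'
and 'abstract' columns), counts unigrams up to n-grams with scikit-learn's
CountVectorizer, and writes a long-format CSV with one
(document_id, term, frequency) row per retained term per document.
"""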
import argparse
import csv
from time import time

from sklearn.feature_extraction.text import CountVectorizer

n_features = 100000  # Keep only the top n_features terms
n_samples = None  # Enter an integer here for testing, so it doesn't take so long
def main():

    parser = argparse.ArgumentParser(description='Take in abstracts, output CSV of n-gram counts')
    parser.add_argument('-i', help='Location of the abstracts file',
            default='processed_data/abstracts.tsv')
    parser.add_argument('-o', help='Location of the output file',
            default='processed_data/ngram_table.csv')
    parser.add_argument('-n', type=int, help='Count n-grams of length 1 up to n',
            default=3)

    args = parser.parse_args()

    print("Loading dataset...")
    t0 = time()
    doc_ids, data_samples = get_ids_and_abstracts(args.i, n_samples)
    print("done in %0.3fs." % (time() - t0))

    # Write the header row of the output CSV
    write_header(args.o)

    bags_o_words = get_counts(data_samples, n_features, args.n)
    write_output(doc_ids, bags_o_words, args.o)

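# get_counts builds the document-term matrix with scikit-learn's
# CountVectorizer and returns one {term: count} dictionary per abstract.
# Note that tf.toarray() materialises the full dense matrix, which can be
# memory-intensive when n_features is large.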
def get_counts(abstracts, n_features, ngram_max):
    # Drop terms that appear in more than 95% of documents (max_df) or in
    # fewer than two documents (min_df), remove English stop words, and
    # count n-grams of length 1 through ngram_max, keeping at most
    # n_features terms.
    tf_vectorizer = CountVectorizer(max_df=0.95, min_df=2,
                                    max_features=n_features,
                                    stop_words='english',
                                    ngram_range=(1, ngram_max))
    t0 = time()
    tf = tf_vectorizer.fit_transform(abstracts)
    print("done in %0.3fs." % (time() - t0))

    # On newer scikit-learn versions this is spelled get_feature_names_out()
    terms = tf_vectorizer.get_feature_names()
    freqs = tf.toarray()
    bags_o_words = to_bags_o_words(terms, freqs)
    return bags_o_words

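# The output CSV is built in two passes: write_header creates (or truncates)
# the file and writes the column names, then write_output appends one row
# per (document, term) pair.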
def write_header(out_file):
    with open(out_file, 'w') as o_f:
        out = csv.writer(o_f)
        out.writerow(['document_id', 'term', 'frequency'])

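# Illustration of what to_bags_o_words produces (hypothetical values):
#   terms = ['media', 'social', 'social media']
#   freqs = [[2, 1, 1], [0, 3, 0]]
#   returns [{'media': 2, 'social': 1, 'social media': 1}, {'social': 3}]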
def to_bags_o_words(terms, freqs):
    '''Takes the vectorizer output and returns a list of dictionaries, one for
    each document. Each dictionary maps term -> count within that document.
    '''
    result = []
    for d in freqs:
        curr_result = {terms[i]: val for i, val in enumerate(d) if val > 0}
        result.append(curr_result)
    return result

def write_output(ids, bags_o_words, out_file):
    with open(out_file, 'a') as o_f:
        out = csv.writer(o_f)
        for i, doc in enumerate(bags_o_words):
            for k, v in doc.items():
                # For each term and count, output a row, together with the document id
                out.writerow([ids[i], k, v])

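# The input is expected to be a tab-separated file with an 'eid' column
# (the document identifier) and an 'abstract' column (the text to count);
# rows missing either field are printed and skipped.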
def get_ids_and_abstracts(fn, length_limit):
    with open(fn, 'r') as f:
        in_csv = csv.DictReader(f, delimiter='\t')
        abstracts = []
        ids = []
        for i, r in enumerate(in_csv, start=1):
            try:
                # Read both fields before appending so the two lists stay aligned
                abstract = r['abstract']
                doc_id = r['eid']
            except KeyError:
                print(r)
                continue
            abstracts.append(abstract)
            ids.append(doc_id)
            if length_limit and i >= length_limit:
                break
    return ids, abstracts

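# Example invocation (these are just the argparse defaults; adjust paths as
# needed for your layout):
#
#   python 00_ngram_extraction.py -i processed_data/abstracts.tsv \
#       -o processed_data/ngram_table.csv -n 3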
if __name__ == '__main__':
    main()
