#!/usr/bin/env python3
import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
import pyarrow.compute as pc
from itertools import groupby, islice, chain
import fire
from collections import Counter
import os
import re
from nltk import wordpunct_tokenize, MWETokenizer, sent_tokenize
from nltk.corpus import stopwords
from nltk.util import ngrams
import string
from random import random
from redditcleaner import clean
from pathlib import Path
from datetime import datetime

# compute term frequencies for comments in each subreddit by week
def weekly_tf(partition, outputdir='/gscratch/comdata/output/reddit_ngrams/', inputdir="/gscratch/comdata/output/reddit_comments_by_subreddit.parquet/", mwe_pass='first', excluded_users=None):
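    """Compute term frequencies for one parquet partition of Reddit comments.

    Writes per-(subreddit, week) term counts and per-(subreddit, week) author comment
    counts to parquet and, on the first pass, a 10 percent sample of n-grams used in
    the second pass to identify multiword expressions.
    """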

    dataset = ds.dataset(Path(inputdir)/partition, format='parquet')
    outputdir = Path(outputdir)
    samppath = outputdir / "reddit_comment_ngrams_10p_sample"
    samppath.mkdir(parents=True, exist_ok=True)

    ngram_output = partition.replace("parquet", "txt")

    if excluded_users is not None:
        with open(excluded_users) as excluded_file:
            excluded_users = set(map(str.strip, excluded_file))
        # drop comments from excluded users before scanning batches
        dataset = dataset.filter(~pc.field("author").isin(list(excluded_users)))

    ngram_path = samppath / ngram_output
    if mwe_pass == 'first':
        if ngram_path.exists():
            ngram_path.unlink()

    dataset = dataset.filter(pc.field("CreatedAt") <= pa.scalar(datetime(2020, 4, 13)))
    batches = dataset.to_batches(columns=['CreatedAt', 'subreddit', 'body', 'author'])

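    # output schemas: term frequencies and per-author comment counts, both keyed by
    # (subreddit, week); in the author table 'tf' counts comments, not tokens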
    schema = pa.schema([pa.field('subreddit', pa.string(), nullable=False),
                        pa.field('term', pa.string(), nullable=False),
                        pa.field('week', pa.date32(), nullable=False),
                        pa.field('tf', pa.int64(), nullable=False)]
    )

    author_schema = pa.schema([pa.field('subreddit', pa.string(), nullable=False),
                               pa.field('author', pa.string(), nullable=False),
                               pa.field('week', pa.date32(), nullable=False),
                               pa.field('tf', pa.int64(), nullable=False)]
    )

    dfs = (b.to_pandas() for b in batches)

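    # label each comment with its week, identified by the date of the Monday that starts it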
    def add_week(df):
        df['week'] = (df.CreatedAt - pd.to_timedelta(df.CreatedAt.dt.dayofweek, unit='d')).dt.date
        return df

    dfs = (add_week(df) for df in dfs)

    def iterate_rows(dfs):
        for df in dfs:
            for row in df.itertuples():
                yield row

    rows = iterate_rows(dfs)

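    # itertools.groupby only groups *consecutive* rows with equal keys, so this assumes
    # the partition is already ordered by subreddit and date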
    subreddit_weeks = groupby(rows, lambda r: (r.subreddit, r.week))

    mwe_path = outputdir / "multiword_expressions.feather"

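    # on the second pass, load the multiword expressions found after the first pass
    # (sorted by their phrasePWMI score) and merge each one into a single token;
    # on the first pass no MWEs are known yet, so the tokenizer is a passthrough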
    if mwe_pass != 'first':
        mwe_dataset = pd.read_feather(mwe_path)
        mwe_dataset = mwe_dataset.sort_values(['phrasePWMI'], ascending=False)
        mwe_phrases = list(mwe_dataset.phrase)
        mwe_phrases = [tuple(s.split(' ')) for s in mwe_phrases]
        mwe_tokenizer = MWETokenizer(mwe_phrases)
        mwe_tokenize = mwe_tokenizer.tokenize
    else:
        mwe_tokenize = MWETokenizer().tokenize

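    # strip punctuation characters from each token and drop tokens that become empty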
    def remove_punct(sentence):
        new_sentence = []
        for token in sentence:
            new_token = ''
            for c in token:
                if c not in string.punctuation:
                    new_token += c
            if len(new_token) > 0:
                new_sentence.append(new_token)
        return new_sentence

    stopWords = set(stopwords.words('english'))

    # we follow the approach described in Datta, Phelan, and Adar (2017)
    def my_tokenizer(text):
        # lowercase, strip Reddit markdown, tokenize, and remove punctuation and stopwords

        # lowercase
        text = text.lower()

        # redditcleaner removes Reddit markdown (newlines, quotes, bullet points, links,
        # strikethrough, spoilers, code, superscripts, tables, headings)
        text = clean(text)

        # sentence tokenize
        sentences = sent_tokenize(text)

        # word tokenize each sentence
        sentences = map(wordpunct_tokenize, sentences)

        # remove punctuation
        sentences = map(remove_punct, sentences)

        # Datta et al. select relatively common phrases from the Reddit corpus, but they
        # don't really explain how; we try that in a second phase. They say they extract
        # 1-4 grams from 10% of the sentences and then find phrases that appear often
        # relative to the original terms. Here we take a 10 percent sample of sentences.
        if mwe_pass == 'first':

            # keep only sentences with more than two words
            sentences = filter(lambda sentence: len(sentence) > 2, sentences)
            sentences = list(sentences)
            for sentence in sentences:
                if random() <= 0.1:
                    grams = list(chain(*map(lambda i: ngrams(sentence, i), range(1, 5))))
                    with open(ngram_path, 'a') as gram_file:
                        for ng in grams:
                            gram_file.write(' '.join(ng) + '\n')
                for token in sentence:
                    if token not in stopWords:
                        yield token

        else:
            # on the second pass, merge multiword expressions and then drop stopwords
            sentences = map(mwe_tokenize, sentences)
            sentences = map(lambda s: filter(lambda token: token not in stopWords, s), sentences)
            for sentence in sentences:
                for token in sentence:
                    yield token

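    # for each (subreddit, week) group, yield rows [is_token, subreddit, term_or_author, week, count]:
    # is_token is True for term-frequency rows and False for per-author comment counts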
    def tf_comments(subreddit_weeks):
        for key, posts in subreddit_weeks:
            subreddit, week = key
            tfs = Counter()
            authors = Counter()
            for post in posts:
                tokens = my_tokenizer(post.body)
                tfs.update(tokens)
                authors.update([post.author])

            for term, tf in tfs.items():
                yield [True, subreddit, term, week, tf]

            for author, tf in authors.items():
                yield [False, subreddit, author, week, tf]

    outrows = tf_comments(subreddit_weeks)

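    # stream the rows into two parquet files (terms and authors) in fixed-size chunks
    # to keep memory use bounded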
    outchunksize = 10000

    termtf_outputdir = outputdir / "comment_terms.parquet"
    termtf_outputdir.mkdir(parents=True, exist_ok=True)
    authortf_outputdir = outputdir / "comment_authors.parquet"
    authortf_outputdir.mkdir(parents=True, exist_ok=True)
    termtf_path = termtf_outputdir / partition
    authortf_path = authortf_outputdir / partition
    with pq.ParquetWriter(termtf_path, schema=schema, compression='snappy', flavor='spark') as writer, \
         pq.ParquetWriter(authortf_path, schema=author_schema, compression='snappy', flavor='spark') as author_writer:

        while True:

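            # pull the next chunk of rows and route each row to the term table or the
            # author table using the leading is_token flag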
            chunk = islice(outrows, outchunksize)
            chunk = (c for c in chunk if c[1] is not None)
            pddf = pd.DataFrame(chunk, columns=["is_token"] + schema.names)
            author_pddf = pddf.loc[pddf.is_token == False, schema.names]
            pddf = pddf.loc[pddf.is_token == True, schema.names]
            author_pddf = author_pddf.rename({'term': 'author'}, axis='columns')
            author_pddf = author_pddf.loc[:, author_schema.names]
            table = pa.Table.from_pandas(pddf, schema=schema)
            author_table = pa.Table.from_pandas(author_pddf, schema=author_schema)
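            # treat the generator as exhausted when a full chunk produces rows for
            # neither table, and stop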
            do_break = True

            if table.shape[0] != 0:
                writer.write_table(table)
                do_break = False
            if author_table.shape[0] != 0:
                author_writer.write_table(author_table)
                do_break = False

            if do_break:
                break


def gen_task_list(mwe_pass='first', inputdir="/gscratch/comdata/output/reddit_comments_by_subreddit.parquet/", outputdir='/gscratch/comdata/output/reddit_ngrams/', tf_task_list='tf_task_list', excluded_users_file=None):
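    """Write tf_task_list with one weekly_tf command per parquet partition in inputdir."""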
    files = os.listdir(inputdir)
    with open(tf_task_list, 'w') as outfile:
        for f in files:
            if f.endswith(".parquet"):
                # only pass --excluded-users when an exclusion file was actually given
                excluded_arg = f"--excluded-users {excluded_users_file} " if excluded_users_file is not None else ""
                outfile.write(f"./tf_comments.py weekly_tf --mwe-pass {mwe_pass} --inputdir {inputdir} --outputdir {outputdir} {excluded_arg}{f}\n")

if __name__ == "__main__":
    fire.Fire({"gen_task_list": gen_task_list,
               "weekly_tf": weekly_tf})
