import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq
from itertools import groupby, islice
import fire
from collections import Counter
import pandas as pd
import os

# compute term frequencies for comments in each subreddit by week
def weekly_tf(partition):
    dataset = ds.dataset(f'/gscratch/comdata/output/reddit_comments_by_subreddit.parquet/{partition}', format='parquet')
    batches = dataset.to_batches(columns=['CreatedAt', 'subreddit', 'body'])

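    # output schema: one row per (subreddit, term, week) triple with its term frequency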
    schema = pa.schema([pa.field('subreddit', pa.string(), nullable=False),
                        pa.field('term', pa.string(), nullable=False),
                        pa.field('week', pa.date32(), nullable=False),
                        pa.field('tf', pa.int64(), nullable=False)])

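    # lazily convert each Arrow record batch to a pandas DataFrame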
    dfs = (b.to_pandas() for b in batches)

    # truncate each comment's timestamp to the Monday that starts its week
    def add_week(df):
        df['week'] = (df.CreatedAt - pd.to_timedelta(df.CreatedAt.dt.dayofweek, unit='d')).dt.date
        return df

    dfs = (add_week(df) for df in dfs)

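    # flatten the stream of DataFrames into a single stream of row tuples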
    def iterate_rows(dfs):
        for df in dfs:
            for row in df.itertuples():
                yield row

    rows = iterate_rows(dfs)

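    # NB: itertools.groupby only merges *consecutive* rows with equal keys, so
    # this relies on each partition already being ordered by subreddit and week;
    # out-of-order input would split one group into several smaller ones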
    subreddit_weeks = groupby(rows, lambda r: (r.subreddit, r.week))

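    # count term frequencies within each (subreddit, week) group; terms are
    # produced by simple whitespace tokenization of the comment body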
    def tf_comments(subreddit_weeks):
        for key, posts in subreddit_weeks:
            subreddit, week = key
            tfs = Counter()

            for post in posts:
                tfs.update(post.body.split())

            for term, tf in tfs.items():
                yield [subreddit, term, week, tf]

    outrows = tf_comments(subreddit_weeks)

    outchunksize = 10000

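    # write the results out in chunks of outchunksize rows so the full
    # term-frequency table never has to be held in memory at once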
    with pq.ParquetWriter(f"/gscratch/comdata/users/nathante/reddit_tfidf_test.parquet_temp/{partition}", schema=schema, compression='snappy', flavor='spark') as writer:
        while True:
            chunk = islice(outrows, outchunksize)
            pddf = pd.DataFrame(chunk, columns=schema.names)
            table = pa.Table.from_pandas(pddf, schema=schema)
            if table.shape[0] == 0:
                break
            writer.write_table(table)


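# emit one weekly_tf command per input partition; the resulting tf_task_list
# file can be fed to a job scheduler or GNU parallel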
def gen_task_list():
    files = os.listdir("/gscratch/comdata/output/reddit_comments_by_subreddit.parquet/")
    with open("tf_task_list", 'w') as outfile:
        for f in files:
            if f.endswith(".parquet"):
                outfile.write(f"python3 tf_reddit_comments.py weekly_tf {f}\n")

if __name__ == "__main__":
    fire.Fire({"gen_task_list": gen_task_list,
               "weekly_tf": weekly_tf})
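
# Usage (via the fire CLI defined above):
#   python3 tf_reddit_comments.py gen_task_list
#   python3 tf_reddit_comments.py weekly_tf <partition-file>.parquet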
