code.communitydata.science - cdsc_reddit.git/blobdiff - ngrams/tf_comments.py
Merge remote-tracking branch 'refs/remotes/origin/excise_reindex' into excise_reindex
[cdsc_reddit.git] / ngrams / tf_comments.py
index f86548a957a866b56d4dec6e9b4f813b2a4b5fa2..f472eebbb2538bb4af353fd7fed9a7c5ff3825d2 100755 (executable)
@@ -13,25 +13,30 @@ from nltk.corpus import stopwords
 from nltk.util import ngrams
 import string
 from random import random
-
-# remove urls
-# taken from https://stackoverflow.com/questions/3809401/what-is-a-good-regular-expression-to-match-a-url
-urlregex = re.compile(r"[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*)")
+from redditcleaner import clean
+from pathlib import Path
 
 # compute term frequencies for comments in each subreddit by week
-def weekly_tf(partition, mwe_pass = 'first'):
-    dataset = ds.dataset(f'/gscratch/comdata/output/reddit_comments_by_subreddit.parquet/{partition}', format='parquet')
-    if not os.path.exists("/gscratch/comdata/users/nathante/reddit_comment_ngrams_10p_sample/"):
-        os.mkdir("/gscratch/comdata/users/nathante/reddit_comment_ngrams_10p_sample/")
+def weekly_tf(partition, outputdir = '/gscratch/comdata/output/reddit_ngrams/', input_dir="/gscratch/comdata/output/reddit_comments_by_subreddit.parquet/", mwe_pass = 'first', excluded_users=None):
+
+    dataset = ds.dataset(Path(input_dir)/partition, format='parquet')
+    outputdir = Path(outputdir)
+    samppath = outputdir / "reddit_comment_ngrams_10p_sample"
 
-    if not os.path.exists("/gscratch/comdata/users/nathante/reddit_tfidf_test_authors.parquet_temp/"):
-        os.mkdir("/gscratch/comdata/users/nathante/reddit_tfidf_test_authors.parquet_temp/")
+    if not samppath.exists():
+        samppath.mkdir(parents=True, exist_ok=True)
 
     ngram_output = partition.replace("parquet","txt")
 
+    if excluded_users is not None:
+        excluded_users = set(map(str.strip,open(excluded_users)))
+        df = df.filter(~ (f.col("author").isin(excluded_users)))
+
+
+    ngram_path = samppath / ngram_output
     if mwe_pass == 'first':
-        if os.path.exists(f"/gscratch/comdata/output/reddit_ngrams/comment_ngrams_10p_sample/{ngram_output}"):
-            os.remove(f"/gscratch/comdata/output/reddit_ngrams/comment_ngrams_10p_sample/{ngram_output}")
+        if ngram_path.exists():
+            ngram_path.unlink()
     
     batches = dataset.to_batches(columns=['CreatedAt','subreddit','body','author'])
 
@@ -65,8 +70,10 @@ def weekly_tf(partition, mwe_pass = 'first'):
 
     subreddit_weeks = groupby(rows, lambda r: (r.subreddit, r.week))
 
+    mwe_path = outputdir / "multiword_expressions.feather"
+
     if mwe_pass != 'first':
-        mwe_dataset = pd.read_feather(f'/gscratch/comdata/output/reddit_ngrams/multiword_expressions.feather')
+        mwe_dataset = pd.read_feather(mwe_path)
         mwe_dataset = mwe_dataset.sort_values(['phrasePWMI'],ascending=False)
         mwe_phrases = list(mwe_dataset.phrase)
         mwe_phrases = [tuple(s.split(' ')) for s in mwe_phrases]
@@ -95,8 +102,8 @@ def weekly_tf(partition, mwe_pass = 'first'):
         # lowercase        
         text = text.lower()
 
-        # remove urls
-        text = urlregex.sub("", text)
+        # redditcleaner removes reddit markdown (newlines, quotes, bullet points, links, strikethrough, spoiler, code, superscript, table, headings)
+        text = clean(text)
 
         # sentence tokenize
         sentences = sent_tokenize(text)
@@ -107,19 +114,18 @@ def weekly_tf(partition, mwe_pass = 'first'):
         # remove punctuation
                         
         sentences = map(remove_punct, sentences)
-
-        # remove sentences with less than 2 words
-        sentences = filter(lambda sentence: len(sentence) > 2, sentences)
-
         # datta et al. select relatively common phrases from the reddit corpus, but they don't really explain how. We'll try that in a second phase.
         # they say that they extract 1-4 grams from 10% of the sentences and then find phrases that appear often relative to the original terms
         # here we take a 10 percent sample of sentences 
         if mwe_pass == 'first':
+
+            # remove sentences with less than 2 words
+            sentences = filter(lambda sentence: len(sentence) > 2, sentences)
             sentences = list(sentences)
             for sentence in sentences:
                 if random() <= 0.1:
                     grams = list(chain(*map(lambda i : ngrams(sentence,i),range(4))))
-                    with open(f'/gscratch/comdata/output/reddit_ngrams/comment_ngrams_10p_sample/{ngram_output}','a') as gram_file:
+                    with open(ngram_path,'a') as gram_file:
                         for ng in grams:
                             gram_file.write(' '.join(ng) + '\n')
                 for token in sentence:
@@ -153,8 +159,15 @@ def weekly_tf(partition, mwe_pass = 'first'):
     outrows = tf_comments(subreddit_weeks)
 
     outchunksize = 10000
-
-    with pq.ParquetWriter(f"/gscratch/comdata/output/reddit_ngrams/comment_terms.parquet/{partition}",schema=schema,compression='snappy',flavor='spark') as writer, pq.ParquetWriter(f"/gscratch/comdata/output/reddit_ngrams/comment_authors.parquet/{partition}",schema=author_schema,compression='snappy',flavor='spark') as author_writer:
+    
+    termtf_outputdir = (outputdir / "comment_terms")
+    termtf_outputdir.mkdir(parents=True, exist_ok=True)
+    authortf_outputdir = (outputdir / "comment_authors")
+    authortf_outputdir.mkdir(parents=True, exist_ok=True)    
+    termtf_path = termtf_outputdir / partition
+    authortf_path = authortf_outputdir / partition
+    with pq.ParquetWriter(termtf_path, schema=schema, compression='snappy', flavor='spark') as writer, \
+         pq.ParquetWriter(authortf_path, schema=author_schema, compression='snappy', flavor='spark') as author_writer:
     
         while True:
 
@@ -183,12 +196,12 @@ def weekly_tf(partition, mwe_pass = 'first'):
         author_writer.close()
 
 
-def gen_task_list(mwe_pass='first'):
+def gen_task_list(mwe_pass='first', outputdir='/gscratch/comdata/output/reddit_ngrams/', tf_task_list='tf_task_list', excluded_users_file=None):
     files = os.listdir("/gscratch/comdata/output/reddit_comments_by_subreddit.parquet/")
-    with open("tf_task_list",'w') as outfile:
+    with open(tf_task_list,'w') as outfile:
         for f in files:
             if f.endswith(".parquet"):
-                outfile.write(f"./tf_comments.py weekly_tf --mwe-pass {mwe_pass} {f}\n")
+                outfile.write(f"./tf_comments.py weekly_tf --mwe-pass {mwe_pass} --outputdir {outputdir} --excluded_users {excluded_users_file} {f}\n")
 
 if __name__ == "__main__":
     fire.Fire({"gen_task_list":gen_task_list,

Community Data Science Collective || Want to submit a patch?