similarities/cosine_similarities.py
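"""Compute pairwise cosine similarities between subreddits from a precomputed
TF-IDF matrix.

The script loads a TF-IDF parquet produced elsewhere in the pipeline,
restricts it to a set of subreddits, computes column-wise cosine
similarities, and writes the result to a feather file. It is exposed as a
command-line tool via python-fire.
"""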
from pyspark.sql import functions as f
from pyspark.sql import SparkSession
import pandas as pd
import fire
from pathlib import Path
from similarities_helper import prep_tfidf_entries, read_tfidf_matrix, select_topN_subreddits, column_similarities


def cosine_similarities(infile, term_colname, outfile, min_df=None, included_subreddits=None, topN=500, exclude_phrases=False):
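    """Compute pairwise cosine similarities between subreddits.

    infile: path to a TF-IDF parquet.
    term_colname: name of the term column (e.g. 'term' or 'author').
    outfile: destination path; the result is written as a feather file.
    min_df: minimum document frequency, passed through to prep_tfidf_entries.
    included_subreddits: optional path to a file listing subreddits to keep,
        one per line; defaults to the topN most active subreddits.
    exclude_phrases: if True, drop multi-word terms.
    """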
    spark = SparkSession.builder.getOrCreate()
    conf = spark.sparkContext.getConf()
    print(outfile)
    print(exclude_phrases)

    tfidf = spark.read.parquet(infile)

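    # Restrict to an explicit subreddit list if one was given; otherwise fall
    # back to the topN most active subreddits.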
    if included_subreddits is None:
        included_subreddits = select_topN_subreddits(topN)
    else:
        # one subreddit per line; strip trailing newlines
        included_subreddits = set(map(str.strip, open(included_subreddits)))

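    # Phrases are stored as underscore-joined tokens, so excluding terms that
    # contain "_" drops multi-word terms.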
    if exclude_phrases:
        tfidf = tfidf.filter(~f.col(term_colname).contains("_"))

    print("creating temporary parquet with matrix indices")
    tempdir = prep_tfidf_entries(tfidf, term_colname, min_df, included_subreddits)
    tfidf = spark.read.parquet(tempdir.name)
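    # Map matrix indices back to subreddit names. subreddit_id_new is
    # 1-indexed upstream, so shift it down to match 0-indexed matrix rows.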
    subreddit_names = tfidf.select(['subreddit', 'subreddit_id_new']).distinct().toPandas()
    subreddit_names = subreddit_names.sort_values("subreddit_id_new")
    subreddit_names['subreddit_id_new'] = subreddit_names['subreddit_id_new'] - 1
    spark.stop()

    print("loading matrix")
    mat = read_tfidf_matrix(tempdir.name, term_colname)
    print('computing similarities')
    sims = column_similarities(mat)
    del mat

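    # Densify the sparse similarity matrix and label it with subreddit names:
    # columns are renamed positionally, and a 'subreddit' column labels rows.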
    sims = pd.DataFrame(sims.todense())
    sims = sims.rename({i: sr for i, sr in enumerate(subreddit_names.subreddit.values)}, axis=1)
    sims['subreddit'] = subreddit_names.subreddit.values

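    # Derive sibling output paths by swapping outfile's suffix; only the
    # feather output is written below.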
    p = Path(outfile)

    output_feather = Path(str(p).replace("".join(p.suffixes), ".feather"))
    output_csv = Path(str(p).replace("".join(p.suffixes), ".csv"))
    output_parquet = Path(str(p).replace("".join(p.suffixes), ".parquet"))

    sims.to_feather(output_feather)
    tempdir.cleanup()

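# Thin CLI wrappers: each binds the TF-IDF input parquet and term column for
# one similarity type.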
def term_cosine_similarities(outfile, min_df=None, included_subreddits=None, topN=500, exclude_phrases=False):
    return cosine_similarities('/gscratch/comdata/output/reddit_similarity/tfidf/comment_terms.parquet',
                               'term',
                               outfile,
                               min_df,
                               included_subreddits,
                               topN,
                               exclude_phrases)

def author_cosine_similarities(outfile, min_df=2, included_subreddits=None, topN=10000):
    return cosine_similarities('/gscratch/comdata/output/reddit_similarity/tfidf/comment_authors.parquet',
                               'author',
                               outfile,
                               min_df,
                               included_subreddits,
                               topN,
                               exclude_phrases=False)

if __name__ == "__main__":
    fire.Fire({'term': term_cosine_similarities,
               'author': author_cosine_similarities})
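# Example invocations via fire (output paths are illustrative):
#   python3 cosine_similarities.py term /path/to/term_similarities.feather --topN=500
#   python3 cosine_similarities.py author /path/to/author_similarities.feather --topN=10000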
