]> code.communitydata.science - cdsc_reddit.git/blobdiff - similarities/similarities_helper.py
Merge branch 'excise_reindex' of code:cdsc_reddit into excise_reindex
[cdsc_reddit.git] / similarities / similarities_helper.py
index 88c830cacf7d5971e6e885882b1771d0c864183b..13845d155200d04cb270308c6f61ef924900bdc2 100644 (file)
+from pyspark.sql import SparkSession
 from pyspark.sql import Window
 from pyspark.sql import functions as f
 from enum import Enum
 from pyspark.sql import Window
 from pyspark.sql import functions as f
 from enum import Enum
+from multiprocessing import cpu_count, Pool
 from pyspark.mllib.linalg.distributed import CoordinateMatrix
 from tempfile import TemporaryDirectory
 import pyarrow
 import pyarrow.dataset as ds
 from pyspark.mllib.linalg.distributed import CoordinateMatrix
 from tempfile import TemporaryDirectory
 import pyarrow
 import pyarrow.dataset as ds
-from scipy.sparse import csr_matrix
+from sklearn.metrics import pairwise_distances
+from scipy.sparse import csr_matrix, issparse
+from sklearn.decomposition import TruncatedSVD
 import pandas as pd
 import numpy as np
 import pathlib
 import pandas as pd
 import numpy as np
 import pathlib
+from datetime import datetime
+from pathlib import Path
+import pickle
 
 class tf_weight(Enum):
     MaxTF = 1
     Norm05 = 2
 
 
 class tf_weight(Enum):
     MaxTF = 1
     Norm05 = 2
 
-def read_tfidf_matrix_weekly(path, term_colname, week):
-    term = term_colname
-    term_id = term + '_id'
-    term_id_new = term + '_id_new'
-
-    dataset = ds.dataset(path,format='parquet')
-    entries = dataset.to_table(columns=['tf_idf','subreddit_id_new',term_id_new],filter=ds.field('week')==week).to_pandas()
-    return(csr_matrix((entries.tf_idf,(entries[term_id_new]-1, entries.subreddit_id_new-1))))
-
-def write_weekly_similarities(path, sims, week, names):
-    sims['week'] = week
-    p = pathlib.Path(path)
-    if not p.is_dir():
-        p.mkdir()
-        
-    # reformat as a pairwise list
-    sims = sims.melt(id_vars=['subreddit','week'],value_vars=names.subreddit.values)
-    sims.to_parquet(p / week.isoformat())
-
+# infile = "/gscratch/comdata/output/reddit_similarity/tfidf_weekly/comment_terms.parquet"
+# cache_file = "/gscratch/comdata/users/nathante/cdsc_reddit/similarities/term_tfidf_entries_bak.parquet"
 
 
+# subreddits missing after this step don't have any terms that have a high enough idf
+# try rewriting without merges
 
 
-def read_tfidf_matrix(path,term_colname):
-    term = term_colname
-    term_id = term + '_id'
-    term_id_new = term + '_id_new'
+# does reindex_tfidf, but without reindexing.
+def reindex_tfidf(*args, **kwargs):
+    df, tfidf_ds, ds_filter = _pull_or_reindex_tfidf(*args, **kwargs, reindex=True)
 
 
-    dataset = ds.dataset(path,format='parquet')
-    entries = dataset.to_table(columns=['tf_idf','subreddit_id_new',term_id_new]).to_pandas()
-    return(csr_matrix((entries.tf_idf,(entries[term_id_new]-1, entries.subreddit_id_new-1))))
+    print("assigning names")
+    subreddit_names = tfidf_ds.to_table(filter=ds_filter,columns=['subreddit','subreddit_id'])
+    batches = subreddit_names.to_batches()
     
     
-def column_similarities(mat):
-    norm = np.matrix(np.power(mat.power(2).sum(axis=0),0.5,dtype=np.float32))
-    mat = mat.multiply(1/norm)
-    sims = mat.T @ mat
-    return(sims)
-
-
-def prep_tfidf_entries_weekly(tfidf, term_colname, min_df, included_subreddits):
-    term = term_colname
-    term_id = term + '_id'
-    term_id_new = term + '_id_new'
-
-    if min_df is None:
-        min_df = 0.1 * len(included_subreddits)
+    with Pool(cpu_count()) as pool:
+        chunks = pool.imap_unordered(pull_names,batches) 
+        subreddit_names = pd.concat(chunks,copy=False).drop_duplicates()
+        subreddit_names = subreddit_names.set_index("subreddit_id")
+
+    new_ids = df.loc[:,['subreddit_id','subreddit_id_new']].drop_duplicates()
+    new_ids = new_ids.set_index('subreddit_id')
+    subreddit_names = subreddit_names.join(new_ids,on='subreddit_id').reset_index()
+    subreddit_names = subreddit_names.drop("subreddit_id",1)
+    subreddit_names = subreddit_names.sort_values("subreddit_id_new")
+    return(df, subreddit_names)
+
+def pull_tfidf(*args, **kwargs):
+    df, _, _ =  _pull_or_reindex_tfidf(*args, **kwargs, reindex=False)
+    return df
 
 
-    tfidf = tfidf.filter(f.col("subreddit").isin(included_subreddits))
+def _pull_or_reindex_tfidf(infile, term_colname, min_df=None, max_df=None, included_subreddits=None, topN=500, week=None, from_date=None, to_date=None, rescale_idf=True, tf_family=tf_weight.MaxTF, reindex=True):
+    print(f"loading tfidf {infile}", flush=True)
+    if week is not None:
+        tfidf_ds = ds.dataset(infile, partitioning='hive')
+    else: 
+        tfidf_ds = ds.dataset(infile)
 
 
-    # we might not have the same terms or subreddits each week, so we need to make unique ids for each week.
-    sub_ids = tfidf.select(['subreddit_id','week']).distinct()
-    sub_ids = sub_ids.withColumn("subreddit_id_new",f.row_number().over(Window.partitionBy('week').orderBy("subreddit_id")))
-    tfidf = tfidf.join(sub_ids,['subreddit_id','week'])
+    if included_subreddits is None:
+        included_subreddits = select_topN_subreddits(topN)
+    else:
+        included_subreddits = set(map(str.strip,open(included_subreddits)))
 
 
-    # only use terms in at least min_df included subreddits in a given week
-    new_count = tfidf.groupBy([term_id,'week']).agg(f.count(term_id).alias('new_count'))
-    tfidf = tfidf.join(new_count,[term_id,'week'],how='inner')
+    ds_filter = ds.field("subreddit").isin(included_subreddits)
 
 
-    # reset the term ids
-    term_ids = tfidf.select([term_id,'week']).distinct()
-    term_ids = term_ids.withColumn(term_id_new,f.row_number().over(Window.partitionBy('week').orderBy(term_id)))
-    tfidf = tfidf.join(term_ids,[term_id,'week'])
+    if min_df is not None:
+        ds_filter &= ds.field("count") >= min_df
 
 
-    tfidf = tfidf.withColumnRenamed("tf_idf","tf_idf_old")
-    tfidf = tfidf.withColumn("tf_idf", (tfidf.relative_tf * tfidf.idf).cast('float'))
+    if max_df is not None:
+        ds_filter &= ds.field("count") <= max_df
 
 
-    tempdir =TemporaryDirectory(suffix='.parquet',prefix='term_tfidf_entries',dir='.')
+    if week is not None:
+        ds_filter &= ds.field("week") == week
 
 
-    tfidf = tfidf.repartition('week')
+    if from_date is not None:
+        ds_filter &= ds.field("week") >= from_date
 
 
-    tfidf.write.parquet(tempdir.name,mode='overwrite',compression='snappy')
-    return(tempdir)
-    
+    if to_date is not None:
+        ds_filter &= ds.field("week") <= to_date
 
 
-def prep_tfidf_entries(tfidf, term_colname, min_df, included_subreddits):
     term = term_colname
     term_id = term + '_id'
     term_id_new = term + '_id_new'
     term = term_colname
     term_id = term + '_id'
     term_id_new = term + '_id_new'
-
-    if min_df is None:
-        min_df = 0.1 * len(included_subreddits)
-
-    tfidf = tfidf.filter(f.col("subreddit").isin(included_subreddits))
-
-    # reset the subreddit ids
-    sub_ids = tfidf.select('subreddit_id').distinct()
-    sub_ids = sub_ids.withColumn("subreddit_id_new",f.row_number().over(Window.orderBy("subreddit_id")))
-    tfidf = tfidf.join(sub_ids,'subreddit_id')
-
-    # only use terms in at least min_df included subreddits
-    new_count = tfidf.groupBy(term_id).agg(f.count(term_id).alias('new_count'))
-    tfidf = tfidf.join(new_count,term_id,how='inner')
     
     
-    # reset the term ids
-    term_ids = tfidf.select([term_id]).distinct()
-    term_ids = term_ids.withColumn(term_id_new,f.row_number().over(Window.orderBy(term_id)))
-    tfidf = tfidf.join(term_ids,term_id)
+    projection = {
+        'subreddit_id':ds.field('subreddit_id'),
+        term_id:ds.field(term_id),
+        'relative_tf':ds.field("relative_tf").cast('float32')
+        }
+
+    if not rescale_idf:
+        projection = {
+            'subreddit_id':ds.field('subreddit_id'),
+            term_id:ds.field(term_id),
+            'relative_tf':ds.field('relative_tf').cast('float32'),
+            'tf_idf':ds.field('tf_idf').cast('float32')}
+
+
+    df = tfidf_ds.to_table(filter=ds_filter,columns=projection)
+
+    df = df.to_pandas(split_blocks=True,self_destruct=True)
+    print("assigning indexes",flush=True)
+    if reindex:
+        df['subreddit_id_new'] = df.groupby("subreddit_id").ngroup()
+    else:
+        df['subreddit_id_new'] = df['subreddit_id']
+
+    if reindex:
+        grouped = df.groupby(term_id)
+        df[term_id_new] = grouped.ngroup()
+    else:
+        df[term_id_new] = df[term_id]
+
+    if rescale_idf:
+        print("computing idf", flush=True)
+        df['new_count'] = grouped[term_id].transform('count')
+        N_docs = df.subreddit_id_new.max() + 1
+        df['idf'] = np.log(N_docs/(1+df.new_count),dtype='float32') + 1
+        if tf_family == tf_weight.MaxTF:
+            df["tf_idf"] = df.relative_tf * df.idf
+        else: # tf_fam = tf_weight.Norm05
+            df["tf_idf"] = (0.5 + 0.5 * df.relative_tf) * df.idf
+
+    return (df, tfidf_ds, ds_filter)
+
+    with Pool(cpu_count()) as pool:
+        chunks = pool.imap_unordered(pull_names,batches) 
+        subreddit_names = pd.concat(chunks,copy=False).drop_duplicates()
+
+    subreddit_names = subreddit_names.set_index("subreddit_id")
+    new_ids = df.loc[:,['subreddit_id','subreddit_id_new']].drop_duplicates()
+    new_ids = new_ids.set_index('subreddit_id')
+    subreddit_names = subreddit_names.join(new_ids,on='subreddit_id').reset_index()
+    subreddit_names = subreddit_names.drop("subreddit_id",1)
+    subreddit_names = subreddit_names.sort_values("subreddit_id_new")
+    return(df, subreddit_names)
+
+def pull_names(batch):
+    return(batch.to_pandas().drop_duplicates())
+
+def similarities(inpath, simfunc, term_colname, outfile, min_df=None, max_df=None, included_subreddits=None, topN=500, from_date=None, to_date=None, tfidf_colname='tf_idf'):
+    '''
+    tfidf_colname: set to 'relative_tf' to use normalized term frequency instead of tf-idf, which can be useful for author-based similarities.
+    '''
+
+    def proc_sims(sims, outfile):
+        if issparse(sims):
+            sims = sims.todense()
+
+        print(f"shape of sims:{sims.shape}")
+        print(f"len(subreddit_names.subreddit.values):{len(subreddit_names.subreddit.values)}",flush=True)
+        sims = pd.DataFrame(sims)
+        sims = sims.rename({i:sr for i, sr in enumerate(subreddit_names.subreddit.values)}, axis=1)
+        sims['_subreddit'] = subreddit_names.subreddit.values
+
+        p = Path(outfile)
+
+        output_feather =  Path(str(p).replace("".join(p.suffixes), ".feather"))
+        output_csv =  Path(str(p).replace("".join(p.suffixes), ".csv"))
+        output_parquet =  Path(str(p).replace("".join(p.suffixes), ".parquet"))
+        p.parent.mkdir(exist_ok=True, parents=True)
+
+        sims.to_feather(outfile)
 
 
-    tfidf = tfidf.withColumnRenamed("tf_idf","tf_idf_old")
-    tfidf = tfidf.withColumn("tf_idf", (tfidf.relative_tf * tfidf.idf).cast('float'))
-    
-    tempdir =TemporaryDirectory(suffix='.parquet',prefix='term_tfidf_entries',dir='.')
-    
-    tfidf.write.parquet(tempdir.name,mode='overwrite',compression='snappy')
-    return tempdir
-
-
-# try computing cosine similarities using spark
-def spark_cosine_similarities(tfidf, term_colname, min_df, included_subreddits, similarity_threshold):
     term = term_colname
     term_id = term + '_id'
     term_id_new = term + '_id_new'
 
     term = term_colname
     term_id = term + '_id'
     term_id_new = term + '_id_new'
 
-    if min_df is None:
-        min_df = 0.1 * len(included_subreddits)
+    entries, subreddit_names = reindex_tfidf(inpath, term_colname=term_colname, min_df=min_df, max_df=max_df, included_subreddits=included_subreddits, topN=topN,from_date=from_date,to_date=to_date)
+    mat = csr_matrix((entries[tfidf_colname],(entries[term_id_new], entries.subreddit_id_new)))
 
 
-    tfidf = tfidf.filter(f.col("subreddit").isin(included_subreddits))
-    tfidf = tfidf.cache()
+    print("loading matrix")        
 
 
-    # reset the subreddit ids
-    sub_ids = tfidf.select('subreddit_id').distinct()
-    sub_ids = sub_ids.withColumn("subreddit_id_new",f.row_number().over(Window.orderBy("subreddit_id")))
-    tfidf = tfidf.join(sub_ids,'subreddit_id')
-
-    # only use terms in at least min_df included subreddits
-    new_count = tfidf.groupBy(term_id).agg(f.count(term_id).alias('new_count'))
-    tfidf = tfidf.join(new_count,term_id,how='inner')
-    
-    # reset the term ids
-    term_ids = tfidf.select([term_id]).distinct()
-    term_ids = term_ids.withColumn(term_id_new,f.row_number().over(Window.orderBy(term_id)))
-    tfidf = tfidf.join(term_ids,term_id)
+    #    mat = read_tfidf_matrix("term_tfidf_entries7ejhvnvl.parquet", term_colname)
 
 
-    tfidf = tfidf.withColumnRenamed("tf_idf","tf_idf_old")
-    tfidf = tfidf.withColumn("tf_idf", tfidf.relative_tf * tfidf.idf)
+    print(f'computing similarities on mat. mat.shape:{mat.shape}')
+    print(f"size of mat is:{mat.data.nbytes}",flush=True)
+    sims = simfunc(mat)
+    del mat
 
 
-    # step 1 make an rdd of entires
-    # sorted by (dense) spark subreddit id
-    n_partitions = int(len(included_subreddits)*2 / 5)
+    if hasattr(sims,'__next__'):
+        for simmat, name in sims:
+            proc_sims(simmat, Path(outfile)/(str(name) + ".feather"))
+    else:
+        proc_sims(sims, outfile)
 
 
-    entries = tfidf.select(f.col(term_id_new)-1,f.col("subreddit_id_new")-1,"tf_idf").rdd.repartition(n_partitions)
-
-    # put like 10 subredis in each partition
-
-    # step 2 make it into a distributed.RowMatrix
-    coordMat = CoordinateMatrix(entries)
+def write_weekly_similarities(path, sims, week, names):
+    sims['week'] = week
+    p = pathlib.Path(path)
+    if not p.is_dir():
+        p.mkdir(exist_ok=True,parents=True)
+        
+    # reformat as a pairwise list
+    sims = sims.melt(id_vars=['_subreddit','week'],value_vars=names.subreddit.values)
+    sims.to_parquet(p / week.isoformat())
 
 
-    coordMat = CoordinateMatrix(coordMat.entries.repartition(n_partitions))
+def column_overlaps(mat):
+    non_zeros = (mat != 0).astype('double')
+    
+    intersection = non_zeros.T @ non_zeros
+    card1 = non_zeros.sum(axis=0)
+    den = np.add.outer(card1,card1) - intersection
 
 
-    # this needs to be an IndexedRowMatrix()
-    mat = coordMat.toRowMatrix()
+    return intersection / den
+    
+def test_lsi_sims():
+    term = "term"
+    term_id = term + '_id'
+    term_id_new = term + '_id_new'
 
 
-    #goal: build a matrix of subreddit columns and tf-idfs rows
-    sim_dist = mat.columnSimilarities(threshold=similarity_threshold)
+    t1 = time.perf_counter()
+    entries, subreddit_names = reindex_tfidf("/gscratch/comdata/output/reddit_similarity/tfidf/comment_terms_100k_repartitioned.parquet",
+                                             term_colname='term',
+                                             min_df=2000,
+                                             topN=10000
+                                             )
+    t2 = time.perf_counter()
+    print(f"first load took:{t2 - t1}s")
+
+    entries, subreddit_names = reindex_tfidf("/gscratch/comdata/output/reddit_similarity/tfidf/comment_terms_100k.parquet",
+                                             term_colname='term',
+                                             min_df=2000,
+                                             topN=10000
+                                             )
+    t3=time.perf_counter()
+
+    print(f"second load took:{t3 - t2}s")
+
+    mat = csr_matrix((entries['tf_idf'],(entries[term_id_new], entries.subreddit_id_new)))
+    sims = list(lsi_column_similarities(mat, [10,50]))
+    sims_og = sims
+    sims_test = list(lsi_column_similarities(mat,[10,50],algorithm='randomized',n_iter=10))
+
+# n_components is the latent dimensionality. sklearn recommends 100. More might be better
+# if n_components is a list we'll return a list of similarities with different latent dimensionalities
+# if algorithm is 'randomized' instead of 'arpack' then n_iter gives the number of iterations.
+# this function takes the svd and then the column similarities of it
+def lsi_column_similarities(tfidfmat,n_components=300,n_iter=10,random_state=1968,algorithm='randomized',lsi_model_save=None,lsi_model_load=None):
+    # first compute the lsi of the matrix
+    # then take the column similarities
+    print("running LSI",flush=True)
+
+    if type(n_components) is int:
+        n_components = [n_components]
+
+    n_components = sorted(n_components,reverse=True)
+    
+    svd_components = n_components[0]
+    
+    if lsi_model_load is not None:
+        mod = pickle.load(open(lsi_model_load ,'rb'))
+
+    else:
+        svd = TruncatedSVD(n_components=svd_components,random_state=random_state,algorithm=algorithm,n_iter=n_iter)
+        mod = svd.fit(tfidfmat.T)
+
+    lsimat = mod.transform(tfidfmat.T)
+    if lsi_model_save is not None:
+        pickle.dump(mod, open(lsi_model_save,'wb'))
+
+    sims_list = []
+    for n_dims in n_components:
+        sims = column_similarities(lsimat[:,np.arange(n_dims)])
+        if len(n_components) > 1:
+            yield (sims, n_dims)
+        else:
+            return sims
+    
 
 
-    return (sim_dist, tfidf)
+def column_similarities(mat):
+    return 1 - pairwise_distances(mat,metric='cosine')
 
 
 def build_weekly_tfidf_dataset(df, include_subs, term_colname, tf_family=tf_weight.Norm05):
 
 
 def build_weekly_tfidf_dataset(df, include_subs, term_colname, tf_family=tf_weight.Norm05):
@@ -194,20 +297,20 @@ def build_weekly_tfidf_dataset(df, include_subs, term_colname, tf_family=tf_weig
     idf = idf.withColumn('idf',f.log(idf.subreddits_in_week) / (1+f.col('count'))+1)
 
     # collect the dictionary to make a pydict of terms to indexes
     idf = idf.withColumn('idf',f.log(idf.subreddits_in_week) / (1+f.col('count'))+1)
 
     # collect the dictionary to make a pydict of terms to indexes
-    terms = idf.select([term,'week']).distinct() # terms are distinct
+    terms = idf.select([term]).distinct() # terms are distinct
 
 
-    terms = terms.withColumn(term_id,f.row_number().over(Window.partitionBy('week').orderBy(term))) # term ids are distinct
+    terms = terms.withColumn(term_id,f.row_number().over(Window.orderBy(term))) # term ids are distinct
 
     # make subreddit ids
 
     # make subreddit ids
-    subreddits = df.select(['subreddit','week']).distinct()
-    subreddits = subreddits.withColumn('subreddit_id',f.row_number().over(Window.partitionBy("week").orderBy("subreddit")))
+    subreddits = df.select(['subreddit']).distinct()
+    subreddits = subreddits.withColumn('subreddit_id',f.row_number().over(Window.orderBy("subreddit")))
 
 
-    df = df.join(subreddits,on=['subreddit','week'])
+    df = df.join(subreddits,on=['subreddit'])
 
     # map terms to indexes in the tfs and the idfs
 
     # map terms to indexes in the tfs and the idfs
-    df = df.join(terms,on=[term,'week']) # subreddit-term-id is unique
+    df = df.join(terms,on=[term]) # subreddit-term-id is unique
 
 
-    idf = idf.join(terms,on=[term,'week'])
+    idf = idf.join(terms,on=[term])
 
     # join on subreddit/term to create tf/dfs indexed by term
     df = df.join(idf, on=[term_id, term,'week'])
 
     # join on subreddit/term to create tf/dfs indexed by term
     df = df.join(idf, on=[term_id, term,'week'])
@@ -219,30 +322,24 @@ def build_weekly_tfidf_dataset(df, include_subs, term_colname, tf_family=tf_weig
     else: # tf_fam = tf_weight.Norm05
         df = df.withColumn("tf_idf",  (0.5 + 0.5 * df.relative_tf) * df.idf)
 
     else: # tf_fam = tf_weight.Norm05
         df = df.withColumn("tf_idf",  (0.5 + 0.5 * df.relative_tf) * df.idf)
 
-    return df
-
-
-
-def build_tfidf_dataset(df, include_subs, term_colname, tf_family=tf_weight.Norm05):
+    df = df.repartition(400,'subreddit','week')
+    dfwriter = df.write.partitionBy("week")
+    return dfwriter
 
 
+def _calc_tfidf(df, term_colname, tf_family):
     term = term_colname
     term_id = term + '_id'
     term = term_colname
     term_id = term + '_id'
-    # aggregate counts by week. now subreddit-term is distinct
-    df = df.filter(df.subreddit.isin(include_subs))
-    df = df.groupBy(['subreddit',term]).agg(f.sum('tf').alias('tf'))
 
     max_subreddit_terms = df.groupby(['subreddit']).max('tf') # subreddits are unique
     max_subreddit_terms = max_subreddit_terms.withColumnRenamed('max(tf)','sr_max_tf')
 
     df = df.join(max_subreddit_terms, on='subreddit')
 
 
     max_subreddit_terms = df.groupby(['subreddit']).max('tf') # subreddits are unique
     max_subreddit_terms = max_subreddit_terms.withColumnRenamed('max(tf)','sr_max_tf')
 
     df = df.join(max_subreddit_terms, on='subreddit')
 
-    df = df.withColumn("relative_tf", df.tf / df.sr_max_tf)
+    df = df.withColumn("relative_tf", (df.tf / df.sr_max_tf))
 
     # group by term. term is unique
     idf = df.groupby([term]).count()
 
     # group by term. term is unique
     idf = df.groupby([term]).count()
-
     N_docs = df.select('subreddit').distinct().count()
     N_docs = df.select('subreddit').distinct().count()
-
     # add a little smoothing to the idf
     idf = idf.withColumn('idf',f.log(N_docs/(1+f.col('count')))+1)
 
     # add a little smoothing to the idf
     idf = idf.withColumn('idf',f.log(N_docs/(1+f.col('count')))+1)
 
@@ -271,8 +368,38 @@ def build_tfidf_dataset(df, include_subs, term_colname, tf_family=tf_weight.Norm
         df = df.withColumn("tf_idf",  (0.5 + 0.5 * df.relative_tf) * df.idf)
 
     return df
         df = df.withColumn("tf_idf",  (0.5 + 0.5 * df.relative_tf) * df.idf)
 
     return df
+    
 
 
-def select_topN_subreddits(topN, path="/gscratch/comdata/output/reddit_similarity/subreddits_by_num_comments.csv"):
+def tfidf_dataset(df, include_subs, term_colname, tf_family=tf_weight.Norm05):
+    term = term_colname
+    term_id = term + '_id'
+    # aggregate counts by week. now subreddit-term is distinct
+    df = df.filter(df.subreddit.isin(include_subs))
+    df = df.groupBy(['subreddit',term]).agg(f.sum('tf').alias('tf'))
+
+    df = _calc_tfidf(df, term_colname, tf_family)
+    df = df.repartition('subreddit')
+    dfwriter = df.write
+    return dfwriter
+
+def select_topN_subreddits(topN, path="/gscratch/comdata/output/reddit_similarity/subreddits_by_num_comments_nonsfw.csv"):
     rankdf = pd.read_csv(path)
     included_subreddits = set(rankdf.loc[rankdf.comments_rank <= topN,'subreddit'].values)
     return included_subreddits
     rankdf = pd.read_csv(path)
     included_subreddits = set(rankdf.loc[rankdf.comments_rank <= topN,'subreddit'].values)
     return included_subreddits
+
+
+def repartition_tfidf(inpath="/gscratch/comdata/output/reddit_similarity/tfidf/comment_terms_100k.parquet",
+                      outpath="/gscratch/comdata/output/reddit_similarity/tfidf/comment_terms_100k_repartitioned.parquet"):
+    spark = SparkSession.builder.getOrCreate()
+    df = spark.read.parquet(inpath)
+    df = df.repartition(400,'subreddit')
+    df.write.parquet(outpath,mode='overwrite')
+
+    
+def repartition_tfidf_weekly(inpath="/gscratch/comdata/output/reddit_similarity/tfidf_weekly/comment_terms.parquet",
+                      outpath="/gscratch/comdata/output/reddit_similarity/tfidf/comment_terms_repartitioned.parquet"):
+    spark = SparkSession.builder.getOrCreate()
+    df = spark.read.parquet(inpath)
+    df = df.repartition(400,'subreddit','week')
+    dfwriter = df.write.partitionBy("week")
+    dfwriter.parquet(outpath,mode='overwrite')

Community Data Science Collective || Want to submit a patch?