code.communitydata.science - cdsc_reddit.git/blobdiff - similarities/similarities_helper.py
Merge remote-tracking branch 'refs/remotes/origin/excise_reindex' into excise_reindex
[cdsc_reddit.git] / similarities / similarities_helper.py
index 13845d155200d04cb270308c6f61ef924900bdc2..03c10b2310d3984e120eefcc23a6b3d4878bf113 100644 (file)
@@ -97,6 +97,7 @@ def _pull_or_reindex_tfidf(infile, term_colname, min_df=None, max_df=None, inclu
             'relative_tf':ds.field('relative_tf').cast('float32'),
             'tf_idf':ds.field('tf_idf').cast('float32')}
 
             'relative_tf':ds.field('relative_tf').cast('float32'),
             'tf_idf':ds.field('tf_idf').cast('float32')}
 
+        print(projection)
 
     df = tfidf_ds.to_table(filter=ds_filter,columns=projection)
 
 
     df = tfidf_ds.to_table(filter=ds_filter,columns=projection)
 
@@ -240,7 +241,6 @@ def test_lsi_sims():
 def lsi_column_similarities(tfidfmat,n_components=300,n_iter=10,random_state=1968,algorithm='randomized',lsi_model_save=None,lsi_model_load=None):
     # first compute the lsi of the matrix
     # then take the column similarities
 def lsi_column_similarities(tfidfmat,n_components=300,n_iter=10,random_state=1968,algorithm='randomized',lsi_model_save=None,lsi_model_load=None):
     # first compute the lsi of the matrix
     # then take the column similarities
-    print("running LSI",flush=True)
 
     if type(n_components) is int:
         n_components = [n_components]
 
     if type(n_components) is int:
         n_components = [n_components]
@@ -249,15 +249,20 @@ def lsi_column_similarities(tfidfmat,n_components=300,n_iter=10,random_state=196
     
     svd_components = n_components[0]
     
     
     svd_components = n_components[0]
     
-    if lsi_model_load is not None:
+    if lsi_model_load is not None and Path(lsi_model_load).exists():
+        print("loading LSI")
         mod = pickle.load(open(lsi_model_load ,'rb'))
         mod = pickle.load(open(lsi_model_load ,'rb'))
+        lsi_model_save = lsi_model_load
 
     else:
 
     else:
+        print("running LSI",flush=True)
+
         svd = TruncatedSVD(n_components=svd_components,random_state=random_state,algorithm=algorithm,n_iter=n_iter)
         mod = svd.fit(tfidfmat.T)
 
     lsimat = mod.transform(tfidfmat.T)
     if lsi_model_save is not None:
         svd = TruncatedSVD(n_components=svd_components,random_state=random_state,algorithm=algorithm,n_iter=n_iter)
         mod = svd.fit(tfidfmat.T)
 
     lsimat = mod.transform(tfidfmat.T)
     if lsi_model_save is not None:
+        Path(lsi_model_save).parent.mkdir(exist_ok=True, parents=True)
         pickle.dump(mod, open(lsi_model_save,'wb'))
 
     sims_list = []
         pickle.dump(mod, open(lsi_model_save,'wb'))
 
     sims_list = []

Community Data Science Collective || Want to submit a patch?