code.communitydata.science - cdsc_reddit.git/blobdiff - fit_tsne.py
refactor visualization code.
[cdsc_reddit.git] / fit_tsne.py
index 06e949a925bd7b585d33fd9db0adabc108e48f7a..28b0fd30630e4a666d619974f321cffb4ba37470 100644 (file)
@@ -1,35 +1,48 @@
+import fire
 import pyarrow
 import pandas as pd
 from numpy import random
 import numpy as np
 from sklearn.manifold import TSNE
 
 import pyarrow
 import pandas as pd
 from numpy import random
 import numpy as np
 from sklearn.manifold import TSNE
 
-df = pd.read_feather("reddit_term_similarity_3000.feather")
-df = df.sort_values(['i','j'])
+similarities = "term_similarities_10000.feather"
 
 
-n = max(df.i.max(),df.j.max())
+def fit_tsne(similarities, output, learning_rate=750, perplexity=50, n_iter=10000, early_exaggeration=20):
+    '''
+    Embed a square similarity matrix in 2-D with t-SNE and write the coordinates to a feather file.
+
+    similarities: feather file with a dataframe of similarity scores: a 'subreddit' column plus
+        one similarity column per subreddit. Row i and similarity column i are assumed to be the
+        same subreddit (the diagonal is overwritten with 1).
+    output: path of the feather file to write; it has columns x, y and subreddit.
+    learning_rate: parameter controlling how fast the model converges. Too low and you get outliers. Too high and you get a ball.
+    perplexity: number of neighbors to use. the default of 50 is often good.
+    n_iter: maximum number of t-SNE optimization iterations.
+    early_exaggeration: passed to TSNE; controls how tight natural clusters are, and how far apart, early in the optimization.
 
 
-def zero_pad(grp):
-    p = grp.shape[0]
-    grp = grp.sort_values('j')
-    return np.concatenate([np.zeros(n-p),np.ones(1),np.array(grp.value)])
+    Similarities s are converted to angular distances 2*arccos(s)/pi before fitting.
+    '''
+    df = pd.read_feather(similarities)
 
 
-col_names = df.sort_values('j').loc[:,['subreddit_j']].drop_duplicates()
-first_name = list(set(df.subreddit_i) - set(df.subreddit_j))[0]
-col_names = [first_name] + list(col_names.subreddit_j)
-mat = df.groupby('i').apply(zero_pad)
-mat.loc[n] = np.concatenate([np.zeros(n),np.ones(1)])
-mat = np.stack(mat)
+    mat = np.array(df.drop(columns='subreddit'), dtype=np.float64)
+    # each subreddit is perfectly similar to itself
+    np.fill_diagonal(mat, 1)
+    # clamp rounding error back into arccos's domain [-1, 1]; values outside it would become NaN
+    np.clip(mat, -1, 1, out=mat)
+    # angular distance: similarity 1 -> 0, 0 -> 1, -1 -> 2
+    dist = 2*np.arccos(mat)/np.pi
+    # init='random': PCA initialization is rejected for precomputed distances (the default
+    # became 'pca' in scikit-learn 1.2). NOTE: scikit-learn >= 1.5 renames n_iter to max_iter.
+    tsne_model = TSNE(2, learning_rate=learning_rate, perplexity=perplexity, n_iter=n_iter,
+                      metric='precomputed', early_exaggeration=early_exaggeration,
+                      init='random', n_jobs=-1)
 
 
-mat = mat + np.tril(mat.transpose(),k=-1)
-dist = 2*np.arccos(mat)/np.pi
+    # fit_transform fits and embeds in one pass; calling fit() first would run the optimization twice
+    tsne_fit_whole = tsne_model.fit_transform(dist)
 
 
-tsne_model = TSNE(2,learning_rate=500,perplexity=50,n_iter=10000,metric='precomputed',early_exaggeration=20,n_jobs=-1)
+    plot_data = pd.DataFrame({'x': tsne_fit_whole[:, 0], 'y': tsne_fit_whole[:, 1], 'subreddit': df.subreddit})
 
 
-tsne_fit_model = tsne_model.fit(dist)
+    plot_data.to_feather(output)
 
 
-tsne_fit_whole = tsne_fit_model.fit_transform(dist)
-
-plot_data = pd.DataFrame({'x':tsne_fit_whole[:,0],'y':tsne_fit_whole[:,1], 'subreddit':col_names})
-
-plot_data.to_feather("tsne_subreddit_fit.feather")
+if __name__ == "__main__":
+    fire.Fire(fit_tsne)

Community Data Science Collective || Want to submit a patch?