"""Interactive Altair charts of subreddits placed in 2-D and colored by cluster.

Source path: visualization/tsne_vis.py (cdsc_reddit).

Run as a script (arguments are handled by ``fire``)::

    python tsne_vis.py TSNE_FEATHER CLUSTERS_FEATHER OUTPUT_HTML
"""
import os

# pyarrow is not referenced directly below; presumably it is imported so that
# a missing feather backend fails at import time -- TODO confirm.
import pyarrow
import altair as alt
alt.data_transformers.disable_max_rows()
alt.data_transformers.enable('default')
from sklearn.neighbors import NearestNeighbors
import pandas as pd
from numpy import random
import fire
import numpy as np


def base_plot(plot_data):
    """Return the shared text-mark chart: one subreddit name per point.

    Args:
        plot_data: DataFrame providing the columns ``x``, ``y``,
            ``subreddit``, ``cluster`` and ``color`` (the encodings and the
            selection below reference exactly these fields).

    Returns:
        An Altair chart with a single-cluster selection attached. Points in
        the selected cluster use the ``category10`` palette keyed on
        ``color``; all other points are drawn in light gray.
    """
    # NOTE(review): the dropdown options are strings while `cluster` may be
    # numeric in the data -- confirm the bound selection matches as intended.
    cluster_dropdown = alt.binding_select(
        options=[str(c) for c in sorted(set(plot_data.cluster))])

    # The selection can be driven either by clicking a point or by the
    # dropdown. The selection name is a single space, as in the original.
    cluster_click_select = alt.selection_single(
        on='click', fields=['cluster'], bind=cluster_dropdown, name=' ')

    color = alt.condition(
        cluster_click_select,
        alt.Color(field='color', type='nominal',
                  scale=alt.Scale(scheme='category10')),
        alt.value("lightgray"))

    base = alt.Chart(plot_data).mark_text().encode(
        alt.X('x', axis=alt.Axis(grid=False), scale=alt.Scale(domain=(-65, 65))),
        alt.Y('y', axis=alt.Axis(grid=False), scale=alt.Scale(domain=(-65, 65))),
        color=color,
        text='subreddit')

    return base.add_selection(cluster_click_select)


def zoom_plot(plot_data):
    """Return the base chart made pannable/zoomable at 1275x1000 pixels."""
    chart = base_plot(plot_data)
    chart = chart.interactive()
    return chart.properties(width=1275, height=1000)


def viewport_plot(plot_data):
    """Return a three-panel chart: overview, mid-level view and detail view.

    The top row holds two point-mark panels. An interval brush on the first
    sets the x/y domain of the second; separate x and y brushes on the second
    set the domain of the large text panel below.
    """
    selector1 = alt.selection_interval(
        encodings=['x', 'y'], init={'x': (-65, 65), 'y': (-65, 65)})
    selectorx2 = alt.selection_interval(encodings=['x'], init={'x': (30, 40)})
    selectory2 = alt.selection_interval(encodings=['y'], init={'y': (-20, 0)})

    base = base_plot(plot_data)

    # Small panels use faint points instead of text, and re-declare x/y
    # without the fixed scale domain used by the base chart.
    viewport = base.mark_point(fillOpacity=0.2, opacity=0.2).encode(
        alt.X('x', axis=alt.Axis(grid=False)),
        alt.Y('y', axis=alt.Axis(grid=False)),
    )
    viewport = viewport.properties(width=600, height=400)

    viewport1 = viewport.add_selection(selector1)

    viewport2 = viewport.encode(
        alt.X('x', axis=alt.Axis(grid=False), scale=alt.Scale(domain=selector1)),
        alt.Y('y', axis=alt.Axis(grid=False), scale=alt.Scale(domain=selector1)),
    )
    viewport2 = viewport2.add_selection(selectorx2)
    viewport2 = viewport2.add_selection(selectory2)

    sr = base.encode(
        alt.X('x', axis=alt.Axis(grid=False), scale=alt.Scale(domain=selectorx2)),
        alt.Y('y', axis=alt.Axis(grid=False), scale=alt.Scale(domain=selectory2)),
    )
    sr = sr.properties(width=1275, height=600)

    return (viewport1 | viewport2) & sr


def assign_cluster_colors(tsne_data, clusters, n_colors, n_neighbors = 4):
    """Give each cluster a color id that differs from its nearby clusters.

    Args:
        tsne_data: DataFrame with columns ``subreddit``, ``x`` and ``y``.
        clusters: DataFrame with columns ``subreddit`` and ``cluster``.
        n_colors: Number of color ids available (``0 .. n_colors - 1``).
        n_neighbors: How many nearest clusters are consulted per cluster.
            Values larger than the number of clusters are clamped.

    Returns:
        ``tsne_data`` merged with ``clusters`` plus an integer ``color``
        column that is constant within each cluster.

    Raises:
        ValueError: If some cluster has no color left that is unused by its
            already-colored neighbors.
    """
    tsne_data = tsne_data.merge(clusters, on='subreddit')

    groups = tsne_data.groupby('cluster')
    centroids = groups[['x', 'y']].mean()
    n_clusters = len(centroids)

    if n_clusters == 0:
        # Nothing to color; keep the output schema consistent.
        return tsne_data.assign(color=np.empty(0, dtype=int))

    points = tsne_data[['x', 'y']].to_numpy()
    centers = centroids[['x', 'y']].to_numpy()

    # point x centroid
    point_center_distances = np.linalg.norm(
        points[:, None, :] - centers[None, :, :], axis=-1)

    # distances[g, c]: smallest distance from any point of cluster g to the
    # centroid of cluster c. Rows are positional (same order as `centroids`),
    # so cluster labels need not be the integers 0..k-1.
    distances = np.empty((n_clusters, n_clusters))
    member_rows = groups.indices  # label -> positional row numbers
    for row, label in enumerate(centroids.index):
        distances[row] = point_center_distances[member_rows[label]].min(axis=0)

    # For every centroid (column), the clusters (rows) closest to it.
    # kth must be a valid row index, hence the clamp.
    kth = min(n_neighbors, n_clusters - 1)
    nearest = distances.argpartition(kth, axis=0)
    indices = nearest[:n_neighbors, :].T

    color_ids = np.arange(n_colors)
    color_assignments = np.full(n_clusters, -1)  # -1 means "not yet colored"

    # Greedy pass: take the lowest color id not used by any neighbor that has
    # already been colored.
    for i, knn in enumerate(indices):
        available_colors = np.setdiff1d(color_ids, color_assignments[knn])
        if len(available_colors) == 0:
            raise ValueError("Can't color this many neighbors with this many colors")
        color_assignments[i] = available_colors[0]

    colors = centroids.reset_index()[['cluster']].assign(color=color_assignments)

    return tsne_data.merge(colors, on='cluster')


def build_visualization(tsne_data, clusters, output, n_colors=10, n_neighbors=8):
    """Read the inputs, color the clusters and write both charts.

    Args:
        tsne_data: Path to a feather file with the 2-D coordinates.
        clusters: Path to a feather file with the cluster assignments.
        output: Path for the zoomable chart. The viewport chart is written
            next to it with ``_viewport`` inserted before the extension.
        n_colors: Number of color ids handed to ``assign_cluster_colors``.
            The charts use the 10-color ``category10`` scheme.
        n_neighbors: Neighbor count handed to ``assign_cluster_colors``.
    """
    tsne_data = pd.read_feather(tsne_data)
    clusters = pd.read_feather(clusters)

    tsne_data = assign_cluster_colors(tsne_data, clusters, n_colors, n_neighbors)

    term_zoom_plot = zoom_plot(tsne_data)
    term_zoom_plot.save(output)

    term_viewport_plot = viewport_plot(tsne_data)
    # Split off the real extension so the second file can never collide with
    # the first and only the suffix of the path is touched.
    root, ext = os.path.splitext(output)
    term_viewport_plot.save(root + "_viewport" + ext)


if __name__ == "__main__":
    fire.Fire(build_visualization)

# commenter_data = pd.read_feather("tsne_author_fit.feather")
# clusters = pd.read_feather('author_3000_clusters.feather')
# commenter_data = assign_cluster_colors(commenter_data,clusters,10,8)
# commenter_zoom_plot = zoom_plot(commenter_data)
# commenter_viewport_plot = viewport_plot(commenter_data)
# commenter_zoom_plot.save("subreddit_commenters_tsne_3000.html")
# commenter_viewport_plot.save("subreddit_commenters_tsne_3000_viewport.html")

# chart = chart.properties(width=10000,height=10000)
# chart.save("test_tsne_whole.svg")