code.communitydata.science - cdsc_reddit.git/blobdiff - clustering/clustering_base.py
Updates to similarities code for smap project.
[cdsc_reddit.git] / clustering / clustering_base.py
index 1d24533b520865d8e3f8bd53bad8a344178d8741..3778fc3fa91259f49a1b1470e7d0901e1f3ee6ba 100644 (file)
@@ -3,6 +3,7 @@ import numpy as np
 import pandas as pd
 from dataclasses import dataclass
 from sklearn.metrics import silhouette_score, silhouette_samples
 import pandas as pd
 from dataclasses import dataclass
 from sklearn.metrics import silhouette_score, silhouette_samples
+from collections import Counter
 
 # this is meant to be an interface, not created directly
 class clustering_job:
 
 # this is meant to be an interface, not created directly
 class clustering_job:
@@ -38,9 +39,11 @@ class clustering_job:
         return self.result
 
     def silhouette(self):
         return self.result
 
     def silhouette(self):
-        isolates = self.clustering.labels_ == -1
+        counts = Counter(self.clustering.labels_)
+        singletons = [key for key, value in counts.items() if value == 1]
+        isolates = (self.clustering.labels_ == -1) | (np.isin(self.clustering.labels_,np.array(singletons)))
         scoremat = self.mat[~isolates][:,~isolates]
         scoremat = self.mat[~isolates][:,~isolates]
-        if scoremat.shape[0] > 0:
+        if self.n_clusters > 1:
             score = silhouette_score(scoremat, self.clustering.labels_[~isolates], metric='precomputed')
             silhouette_samp = silhouette_samples(self.mat, self.clustering.labels_, metric='precomputed')
             silhouette_samp = pd.DataFrame({'subreddit':self.subreddits,'score':silhouette_samp})
             score = silhouette_score(scoremat, self.clustering.labels_[~isolates], metric='precomputed')
             silhouette_samp = silhouette_samples(self.mat, self.clustering.labels_, metric='precomputed')
             silhouette_samp = pd.DataFrame({'subreddit':self.subreddits,'score':silhouette_samp})
@@ -80,8 +83,9 @@ class clustering_job:
 
         print(f"{n_isolates1} clusters have 1 member")
 
 
         print(f"{n_isolates1} clusters have 1 member")
 
-        n_isolates2 = (cluster_sizes.loc[cluster_sizes.cluster==-1,['subreddit']])
-
+        n_isolates2 = cluster_sizes.loc[cluster_sizes.cluster==-1,:]['subreddit'].to_list()
+        if len(n_isolates2) > 0:
+            n_isloates2 = n_isolates2[0]
         print(f"{n_isolates2} subreddits are in cluster -1",flush=True)
 
         if n_isolates1 == 0:
         print(f"{n_isolates2} subreddits are in cluster -1",flush=True)
 
         if n_isolates1 == 0:

Community Data Science Collective || Want to submit a patch?