X-Git-Url: https://code.communitydata.science/cdsc_reddit.git/blobdiff_plain/4cb7eeec80c5a9c8f49339acd378c515e290ed81..87ffaa6858919bd830694d60dd4fc7b1857b462a:/clustering/clustering_base.py diff --git a/clustering/clustering_base.py b/clustering/clustering_base.py index 1d24533..3778fc3 100644 --- a/clustering/clustering_base.py +++ b/clustering/clustering_base.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd from dataclasses import dataclass from sklearn.metrics import silhouette_score, silhouette_samples +from collections import Counter # this is meant to be an interface, not created directly class clustering_job: @@ -38,9 +39,11 @@ class clustering_job: return self.result def silhouette(self): - isolates = self.clustering.labels_ == -1 + counts = Counter(self.clustering.labels_) + singletons = [key for key, value in counts.items() if value == 1] + isolates = (self.clustering.labels_ == -1) | (np.isin(self.clustering.labels_,np.array(singletons))) scoremat = self.mat[~isolates][:,~isolates] - if scoremat.shape[0] > 0: + if self.n_clusters > 1: score = silhouette_score(scoremat, self.clustering.labels_[~isolates], metric='precomputed') silhouette_samp = silhouette_samples(self.mat, self.clustering.labels_, metric='precomputed') silhouette_samp = pd.DataFrame({'subreddit':self.subreddits,'score':silhouette_samp}) @@ -80,8 +83,9 @@ class clustering_job: print(f"{n_isolates1} clusters have 1 member") - n_isolates2 = (cluster_sizes.loc[cluster_sizes.cluster==-1,['subreddit']]) - + n_isolates2 = cluster_sizes.loc[cluster_sizes.cluster==-1,:]['subreddit'].to_list() + if len(n_isolates2) > 0: + n_isloates2 = n_isolates2[0] print(f"{n_isolates2} subreddits are in cluster -1",flush=True) if n_isolates1 == 0: