.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/02_plot_imbalanced_ekmeans.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_02_plot_imbalanced_ekmeans.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_02_plot_imbalanced_ekmeans.py:


Imbalanced clustering comparison with EKMeans
=============================================

This example compares clustering performance on an imbalanced dataset for
several algorithms, including EKMeans and MiniBatchEKMeans. It is intended
for the gallery and requires matplotlib to render plots.

.. GENERATED FROM PYTHON SOURCE LINES 11-186



.. image-sg:: /auto_examples/images/sphx_glr_02_plot_imbalanced_ekmeans_001.png
   :alt: Ground truth, EKMeans, MiniBatchEKMeans, KMeans, MiniBatchKMeans, MeanShift, DBSCAN, OPTICS, Affinity Propagation, Birch, Spectral, Agglomerative
   :srcset: /auto_examples/images/sphx_glr_02_plot_imbalanced_ekmeans_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /opt/hostedtoolcache/Python/3.12.11/x64/lib/python3.12/site-packages/sklearn/cluster/_affinity_propagation.py:140: ConvergenceWarning: Affinity propagation did not converge, this model may return degenerate cluster centers and labels.
      warnings.warn(




|

.. code-block:: Python


    import time

    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.cluster import (
        DBSCAN,
        OPTICS,
        AffinityPropagation,
        AgglomerativeClustering,
        Birch,
        KMeans,
        MeanShift,
        MiniBatchKMeans,
        SpectralClustering,
    )
    from sklearn.datasets import make_blobs

    from sklekmeans import EKMeans, MiniBatchEKMeans


    def _make_imbalanced(
        n_samples=2000,
        weights=(0.840, 0.01, 0.05),
        centers=None,
        cluster_std=(1.0, 1.0, 1.0),
        random_state=0,
    ):
        """Generate a 2-D Gaussian-blob dataset with imbalanced cluster sizes.

        Cluster ``k`` receives ``round(weights[k] * n_samples)`` points drawn
        around ``centers[k]`` with standard deviation ``cluster_std[k]``.

        Parameters
        ----------
        n_samples : int
            Reference size; each weight is a fraction of this number.
        weights : sequence of float
            Per-cluster fraction of ``n_samples``; one entry per center.
        centers : array-like of shape (n_clusters, n_features), optional
            Cluster centers. Defaults to ``[[-3, -2], [2, -2], [2, 2]]``.
        cluster_std : sequence of float
            Per-cluster standard deviation, one entry per center.
        random_state : int
            Seed passed to ``make_blobs``.

        Returns
        -------
        X : ndarray of shape (n_total, n_features)
            Samples, grouped by cluster (cluster 0 first).
        y : ndarray of shape (n_total,)
            Ground-truth cluster labels.

        Raises
        ------
        ValueError
            If ``weights`` does not have one entry per center.
        """
        if centers is None:
            centers = np.array([[-3, -2], [2, -2], [2, 2]])
        centers = np.asarray(centers)
        if len(weights) != centers.shape[0]:
            raise ValueError(
                f"Expected one weight per center, got {len(weights)} weights "
                f"for {centers.shape[0]} centers."
            )
        # Request the exact per-cluster counts from make_blobs. Subsampling an
        # evenly split sample would cap every cluster at about
        # n_samples / n_clusters points and silently ignore large weights.
        counts = [int(round(w * n_samples)) for w in weights]
        # shuffle=False keeps the samples grouped by cluster.
        X, y = make_blobs(
            n_samples=counts,
            centers=centers,
            cluster_std=cluster_std,
            random_state=random_state,
            shuffle=False,
        )
        return X, y


    def _get_centers(estimator, X):
        """Return the cluster centers of a fitted estimator, or None.

        Uses ``cluster_centers_`` when present. Otherwise it looks up the
        exemplar rows of ``X`` named by ``cluster_centers_indices_``.
        """
        centers = getattr(estimator, "cluster_centers_", None)
        if centers is not None:
            return np.asarray(centers)
        indices = getattr(estimator, "cluster_centers_indices_", None)
        if indices is None:
            return None
        try:
            return X[np.asarray(indices, dtype=int)]
        except (IndexError, TypeError, ValueError):
            # The indices cannot address rows of X, so skip the overlay.
            return None


    def _plot(ax, X, labels, title, *, estimator=None, runtime=None):
        """Scatter-plot ``X`` on ``ax``, coloured by ``labels``.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Target axes.
        X : ndarray of shape (n_samples, 2)
            Points to draw; columns 0 and 1 are used as x and y.
        labels : array-like of shape (n_samples,)
            Cluster label per point, mapped through the ``tab10`` colormap.
        title : str
            Axes title.
        estimator : fitted estimator, optional
            If it exposes ``cluster_centers_`` or ``cluster_centers_indices_``,
            those centers are overlaid as red crosses.
        runtime : float, optional
            Fit time in seconds, written in the bottom-right corner.
        """
        ax.scatter(X[:, 0], X[:, 1], c=labels, s=10, cmap="tab10")
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

        if runtime is not None:
            # lstrip("0") shortens e.g. "0.12s" to ".12s".
            ax.text(
                0.99,
                0.01,
                (f"{runtime:.2f}s").lstrip("0"),
                transform=ax.transAxes,
                fontsize=9,
                ha="right",
                va="bottom",
            )

        if estimator is not None:
            centers = _get_centers(estimator, X)
            if centers is not None and centers.ndim == 2 and centers.size > 0:
                ax.scatter(
                    centers[:, 0],
                    centers[:, 1],
                    s=80,
                    c="red",
                    marker="x",
                    linewidths=1.5,
                )


    def main():
        """Fit each clustering algorithm on the imbalanced data and plot a 3x4 grid.

        The first panel shows the ground truth. Each other panel shows one
        algorithm's labels and its fit time, plus its cluster centers when
        the third tuple field below is True.
        """
        X, y = _make_imbalanced()

        _, axes = plt.subplots(3, 4, figsize=(12, 6), constrained_layout=True)
        axes = axes.ravel()

        _plot(axes[0], X, y, "Ground truth")

        # (panel title, unfitted estimator, overlay centers?) in panel order.
        panels = [
            (
                "EKMeans",
                EKMeans(n_clusters=3, n_init=10, random_state=0, alpha="dvariance"),
                True,
            ),
            (
                "MiniBatchEKMeans",
                MiniBatchEKMeans(
                    n_clusters=3, n_init=10, random_state=0, batch_size=256
                ),
                True,
            ),
            ("KMeans", KMeans(n_clusters=3, n_init=10, random_state=0), True),
            (
                "MiniBatchKMeans",
                MiniBatchKMeans(
                    n_clusters=3, n_init=10, random_state=0, batch_size=256
                ),
                True,
            ),
            ("MeanShift", MeanShift(), True),
            ("DBSCAN", DBSCAN(min_samples=10), False),
            ("OPTICS", OPTICS(min_samples=10), False),
            ("Affinity Propagation", AffinityPropagation(random_state=0), True),
            ("Birch", Birch(n_clusters=3), True),
            (
                "Spectral",
                SpectralClustering(
                    n_clusters=3, assign_labels="kmeans", random_state=0
                ),
                False,
            ),
            ("Agglomerative", AgglomerativeClustering(n_clusters=3), False),
        ]

        for ax, (title, estimator, show_centers) in zip(axes[1:], panels):
            t0 = time.perf_counter()
            estimator = estimator.fit(X)
            runtime = time.perf_counter() - t0
            _plot(
                ax,
                X,
                estimator.labels_,
                title,
                estimator=estimator if show_centers else None,
                runtime=runtime,
            )

        plt.show()


    if __name__ == "__main__":
        main()


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 3.479 seconds)


.. _sphx_glr_download_auto_examples_02_plot_imbalanced_ekmeans.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 02_plot_imbalanced_ekmeans.ipynb <02_plot_imbalanced_ekmeans.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 02_plot_imbalanced_ekmeans.py <02_plot_imbalanced_ekmeans.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 02_plot_imbalanced_ekmeans.zip <02_plot_imbalanced_ekmeans.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_