Generalized K-Means Clustering (github.com/derrickburns)
192 points by derrickrburns on Jan 13, 2024 | 80 comments


I built a pipeline to automatically cluster and visualize large amounts of text documents in a completely unsupervised manner:

- Embed all the text documents.

- Project to 2D using UMAP which also creates its own emergent "clusters".

- Use k-means clustering with a high cluster count depending on dataset size.

- Feed the ChatGPT API ~10 examples from each cluster and ask it to provide a concise label for the cluster.

- Bonus: Use DBSCAN to identify arbitrary subclusters within each cluster.

It is extremely effective, and I have a theoretical implementation of a more practical use case that uses the same UMAP dimensionality reduction for better inference. There is evidence that current popular text embedding models (e.g. OpenAI ada, which outputs 1536D embeddings) are far too large for most use cases and, as a result, could be giving poorly specified embedding-similarity results, in addition to raising costs for the entire pipeline.
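
In rough Python terms, the pipeline looks something like this (a simplified sketch: the model name, the cluster count, and load_corpus are placeholders, not exact details of my setup):

  from sentence_transformers import SentenceTransformer
  from sklearn.cluster import KMeans
  import umap

  docs = load_corpus()  # placeholder: your list of raw text strings

  # 1. embed all documents
  embeddings = SentenceTransformer("all-MiniLM-L6-v2").encode(docs)

  # 2. project to 2D with UMAP (its structure already hints at clusters)
  proj_2d = umap.UMAP(n_components=2).fit_transform(embeddings)

  # 3. k-means with a high cluster count, scaled to dataset size
  labels = KMeans(n_clusters=50, n_init="auto").fit_predict(proj_2d)

  # 4. sample ~10 docs per cluster and ask an LLM for a concise label
  # 5. optionally run DBSCAN within each cluster to find subclusters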


Funny, I did almost the exact same thing: https://github.com/colehaus/hammock-public. Though I project to 3D and then put them in an interactive 3D plot. The other fun little thing the interactive plotting enables is stepping through a variety of clustering granularities.


Thanks for sharing. I'd like to know what the (re)compute time might be when adding, say, another million documents with this pipeline. The cluster-embedding approach, in my view, while streamlined, still adds a (sometimes significant) time penalty when high throughput is required.

I see significant speedups can be achieved by discretising dimensions into buckets and doing a simple frequency count of the associated buckets -- keeping only the most related buckets per document. These 'signatures' can then be indexed LSH-style and a graph constructed from documents with similar hashes.

When the input set is sufficiently large, this graph contains 'natural' clusters, without any UMAP or k-means parameter tuning required. When implemented in BQ, I achieve sub-minute performance for 5-10 million documents, from indexing to clustering.
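
In rough Python terms, the idea has this shape (a heavily simplified sketch, not the actual BQ implementation; the placeholder embeddings and the rule for which buckets to keep are just illustrative):

  import numpy as np
  from collections import defaultdict

  embeddings = np.random.randn(10_000, 384)  # placeholder document vectors

  n_bins, top_m = 8, 4
  # discretise each dimension into equal-frequency buckets
  edges = np.quantile(embeddings, np.linspace(0, 1, n_bins + 1)[1:-1], axis=0)
  buckets = np.stack([np.digitize(embeddings[:, d], edges[:, d])
                      for d in range(embeddings.shape[1])], axis=1)

  # per document, keep only a few salient (dimension, bucket) pairs as a signature
  index = defaultdict(list)
  for i, (emb, row) in enumerate(zip(embeddings, buckets)):
      top = np.argsort(-np.abs(emb))[:top_m]
      sig = hash(frozenset((int(d), int(row[d])) for d in top))
      index[sig].append(i)

  # documents sharing a signature become edges of a graph;
  # its connected components act as the 'natural' clusters
  graph_edges = [(ids[0], j) for ids in index.values() for j in ids[1:]]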


You can also look at Bertopic which has this functionality as an open source library:

https://maartengr.github.io/BERTopic/index.html


I did something similar (though not for documents), but I'm struggling with selecting the optimal number of clusters.


Cluster stability is a good heuristic that should be better known:

For a given k:

  for n = 30 or 100 or 300 trials:
    subsample 80% of the points
    cluster them
    compute the Fowlkes-Mallows score (available in sklearn) of the subset against the original, restricted to the instances in the subset (otherwise you can't compute it)
  output the average F-M score
This essentially measures how "stable" the clusters are. The Fowlkes-Mallows score decreases when instances hop over to other clusters in the subset.

If you do this and plot the average score versus k, you'll see a sharp dropoff at some point. That's the maximal plausible k.

edit: Here's code

  import numpy as np
  from sklearn.cluster import KMeans
  from sklearn.metrics import fowlkes_mallows_score

  def stability(Z, k):
      kmeans = KMeans(n_clusters=k, n_init="auto")
      kmeans.fit(Z)
      scores = []
      for i in range(100):
          # Randomly select 80% of the data, without replacement
          idx = np.random.choice(Z.shape[0], int(Z.shape[0] * 0.8), replace=False)
          kmeans2 = KMeans(n_clusters=k, n_init="auto")
          kmeans2.fit(Z[idx])

          # Compare the two clusterings on the subsampled instances
          score = fowlkes_mallows_score(kmeans.labels_[idx], kmeans2.labels_)
          scores.append(score)
      return np.mean(scores), np.std(scores)
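
To find the dropoff, sweep over k and plot the result (a sketch; Z is your data matrix, and matplotlib is assumed):

  import matplotlib.pyplot as plt

  ks = range(2, 21)
  means, stds = zip(*(stability(Z, k) for k in ks))
  plt.errorbar(ks, means, yerr=stds)
  plt.xlabel("k"); plt.ylabel("mean Fowlkes-Mallows score")
  plt.show()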


A simple metric for that is the Silhouette

https://en.m.wikipedia.org/wiki/Silhouette_(clustering)

Another elegant method is the Calinski-Harabasz index

https://en.m.wikipedia.org/wiki/Calinski%E2%80%93Harabasz_in...
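
Both are one-liners in scikit-learn; roughly (a sketch, assuming X is your feature matrix):

  from sklearn.cluster import KMeans
  from sklearn.metrics import silhouette_score, calinski_harabasz_score

  for k in range(2, 15):
      labels = KMeans(n_clusters=k, n_init="auto").fit_predict(X)
      print(k, silhouette_score(X, labels), calinski_harabasz_score(X, labels))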


Check out hdbscan


When doing DBSCAN on the subclusters, do you cluster on the 2-D projected space? Do you use the original 2-D projection you used prior to k-means, or does each subcluster get its own UMAP projection?


I DBSCAN in the 2D projected space.

These aren't visualized: I look through the identified clusters manually to find trends.


Is it possible to dbscan on the unprojected space or does that lead to poor effectiveness? Also what led you to choose dbscan vs another technique?


Poor effectiveness. (Again, another hint that working in high-dimensional space may not be ideal.)

I wasn't aware of a robust clustering technique that's better than, or as easy to use as, DBSCAN.


Any reason to pick DBSCAN instead of HDBSCAN*?


Interesting. What do you use the visualization for? Looking at trends in the documents?


Let's say you want to look at a large dataset of user-submitted reviews for your app. User reviews are written extremely idiosyncratically, so traditional NLP methods will likely fail.

With the pipeline mentioned, it's much easier to look at cluster density to identify patterns and high-level trends.


Why not just use DBSCAN, though?


You can use DBSCAN instead of k-means, but DBSCAN has a worst-case memory complexity of O(n^2), so things can get spicy with large datasets, which is why I opt to use it only for subclusters. k-means also fixes the number of clusters, which is good for visualization sanity.

https://scikit-learn.org/stable/modules/generated/sklearn.cl...
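
The subcluster step is roughly this (a minimal sketch, assuming proj_2d is the 2D UMAP output, kmeans_labels are the k-means assignments, and eps needs tuning per dataset):

  import numpy as np
  from sklearn.cluster import DBSCAN

  subclusters = {}
  for c in np.unique(kmeans_labels):
      pts = proj_2d[kmeans_labels == c]
      # label -1 marks points DBSCAN considers noise
      subclusters[c] = DBSCAN(eps=0.3, min_samples=10).fit_predict(pts)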


Isn’t the embedding step much slower than clustering? How many documents are you dealing with?

For a news aggregator I worked on, I disregarded k-means because you have to know the number of clusters in advance, and I think it will cluster every document, which is bad for the actual outliers in a dataset.

Agglomerative clustering yielded the best results for us. HDBSCAN was promising but did weird things with some docs.


The embedding step is certainly slower than clustering, but the memory requirements blow up pretty fast when you're doing density-based clustering on a dataset of even, say, 100k embeddings.


Which libraries are you using, in particular for the first step?


Embeddings is just SentenceTransformers: https://www.sbert.net/

I used the bge-large-en-v1.5 model (https://huggingface.co/BAAI/bge-large-en-v1.5) because I could, but the common all-MiniLM-L6-v2 model is sufficient. The trick is to batch generate the embeddings on a GPU, which SentenceTransformers mostly does by default.

Other libraries are the typical ones (umap for UMAP, scikit-learn for k-means/DBSCAN, chatgpt-python for ChatGPT interfacing, plotly for viz, pandas for some ETL). You don't need to use a bespoke AI/ML package for these workflows and they aren't too complicated.
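
For reference, the batched embedding call is roughly this (a sketch; docs is your list of strings, and the batch size is something to tune for your GPU):

  from sentence_transformers import SentenceTransformer

  model = SentenceTransformer("BAAI/bge-large-en-v1.5", device="cuda")
  embeddings = model.encode(docs, batch_size=256, show_progress_bar=True,
                            normalize_embeddings=True)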


It's just SentenceTransformers, but picking the wrong model is common because no one reads the SentenceTransformers docs. MiniLM-L6-v2 is for symmetric search (the target document has the same wording as the source document); MiniLM-L6-v3 is for asymmetric search (the target document is likely to contain material matching a query in the source document).


Can you share your ChatGPT prompt, please? I'm doing something similar at the moment and am trying out BERTopic, but ChatGPT also seems worth a try.


Why 2D? (edit: just for the viz, or is there some other reason?)


Both the viz, and that the 2D UMAP projection is actually enough to get accurately delineated topics.

Hence why I think the typical embedding dimensionality is way way too high.


Do you think 1D could work? Maybe topic-space is some sort of tree-shaped structure where documents live in the thin strands.


1D could work on certain datasets but it wouldn't be ideal.


Why not just embed directly to 2d? Does it give worse results than UMAP?


cluster naming was still an open problem pre-LLM


AI has sparked new interest in high-dimensional embeddings for approximate nearest neighbor search. Here is a highly scalable implementation of a companion technique, k-means clustering, written in Scala on Spark 1.1.

Please let me know if you fork this library and update it to later versions of Spark.


Just curious, have you actually profiled this against running on a single large-memory machine?


The Twitch streamer Tsoding recently posted a video of himself implementing k-means clustering in C [1]. He also did a follow-up 3D visualization of the algorithm in progress using raylib [2].

1. https://www.youtube.com/watch?v=kH-hqG34ylA&t=4788s&ab_chann...

2. https://www.youtube.com/watch?v=K7hWqxC_7Mw&ab_channel=Tsodi...


Here’s a very simple toy demonstration of how K-Means works that I made for fun years ago while studying machine learning: https://k-means.stackblitz.io/

Essentially K-Means is a way of “learning” categories or other kinds of groupings within an unlabeled dataset, without any fancy deep learning. It’s handy for its simplicity and speed.

The demo works with simple 2D coordinates for illustrative purposes but the technique works with any number of dimensions.

Note that there may be some things I got wrong in the implementation, and there are surely other variations of the algorithm, but it still captures the basic idea well enough for an intro.
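
For anyone curious, the core loop is only a handful of lines of NumPy (a bare-bones sketch, without k-means++ initialisation or empty-cluster handling):

  import numpy as np

  def kmeans(X, k, iters=100, seed=0):
      rng = np.random.default_rng(seed)
      centers = X[rng.choice(len(X), k, replace=False)]
      for _ in range(iters):
          # assign each point to its nearest center
          labels = np.argmin(((X[:, None] - centers) ** 2).sum(-1), axis=1)
          # recompute each center as the mean of its assigned points
          new = np.array([X[labels == j].mean(axis=0) for j in range(k)])
          if np.allclose(new, centers):
              break
          centers = new
      return labels, centers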


I applied to a certain scraping fintech in the Bay Area around 5 years ago and was asked to open the Wikipedia page to k-means squared clustering and implement the algorithm with tests from scratch. I was applying for an android position. I still laugh thinking about how they paid to fly me out and ask such a stupid interview question.


I see how it might not have anything to do with usual Android development, but why do you consider it a stupid question?

K-means is not that complicated, and a naive implementation with, e.g., Euclidean distance is a couple dozen lines of code, so it should be practical enough for an interview.


What are people using k-means for? I can count on one hand the number of times I’ve had a good a priori rationale for the value of k.


Polis (and Twitter's community notes, I believe)

Participation At Scale Can Repair The Public Square https://www.noemamag.com/participation-at-scale-can-repair-t...

Polis: Scaling deliberation by mapping high dimensional opinion spaces https://scholar.google.com/scholar?q=Polis:+Scaling+delibera...

Restricting clustering to 2-5 groups impacts group aware/informed consensus and comment routing https://github.com/compdemocracy/polis/issues/1289


The k-means metric is exactly the metric you would want to optimize for the performance of an algorithm like Bolt (https://arxiv.org/abs/1706.10283). In that and other discretization routines, the value of k is a parameter related to compression ratio, efficiency, and other metrics more predictable than some ethereal notion of how many clusters the data "naturally" has.


I would recommend checking out DBSCAN, as it's similar but doesn't require you to provide a number k: https://en.m.wikipedia.org/wiki/DBSCAN


I've used something similar for tissue segmentation from hyperspectral images of animals where I know there should be K different tissue types I care about.


We used k-means clustering on a project to track fruit fly memory and learning behaviors:

http://git.ceux.org/FlyTracking.git/


Sometimes you can use a heuristic to estimate K, or use a variant that terminates at some distance threshold.

That said, something like hdbscan doesn’t suffer from this problem.


Used it in college to downscale an X-color image to Y colors. Sure, Photoshop does it, but it was informative to do it manually.
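
These days the same exercise is only a few lines with Pillow and scikit-learn (a sketch; the file name and the 16-color palette are arbitrary choices):

  import numpy as np
  from PIL import Image
  from sklearn.cluster import KMeans

  img = np.asarray(Image.open("photo.png").convert("RGB"))
  pixels = img.reshape(-1, 3)
  km = KMeans(n_clusters=16, n_init="auto").fit(pixels)  # 16-color palette
  quantized = km.cluster_centers_[km.labels_].reshape(img.shape).astype(np.uint8)
  Image.fromarray(quantized).save("quantized.png")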


Google's Material You uses this to initiate color theming

(n.b. Celebi's variant; note the usage of Lab, not HSL; respect the Cartesian/polar nature of inputs and outputs; and ask for a high K. 128 is what we went with, but it's arbitrary. You can get away with as few as 32 if you're, e.g., doing brand color from a favicon.)


Ooh this is a much nicer approach than the kind of brute force approach we took at work for theme gen


I did a modified version of this once for a map of auto dealerships, although rather than working with a fixed k, I used a fixed threshold for cluster distance. The algorithm I was working with had O(n³) complexity, so to keep the pregeneration of clusters manageable, I partitioned the data by state. The other fun part was finding the right metric for measuring distances. Because clusters needed to correspond to the rectangular view window on the map, rather than a standard Euclidean distance I used d = max(|Δx|, |Δy|), which gives square neighborhoods rather than round ones.


I implemented an algorithm which used k-means to reduce noise in a path tracer.

For each pixel instead of a single color value it generated k mean color values, using an online algorithm. These were then combined to produce the final pixel color.

The idea was that a pixel might have several distinct contributions (from different light sources, for example), but due to the random sampling used in path tracing, the variance of sample values is usually large.

The value k then was chosen based on scene complexity. There was also a memory trade-off of course, as memory usage was linear in k.


> For each pixel instead of a single color value it generated k mean color values, using an online algorithm.

What does online mean here?


Online means it processes the items as they come[1]. This means the algorithm can't consider all the items at once, and has to try to be clever on the spot.

In my case, the algorithm used the first k samples as the initial means, and would then find which of the k means was closest to the current sample and update that mean [2].

Given that in path tracing one would typically use a fairly large number of samples per pixel relative to k, this approach did a reasonable job of approximating the k means.

[1]: https://en.wikipedia.org/wiki/Online_algorithm

[2]: https://yokolet.com/2017/05/29/online-algorithm.html
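
A minimal sketch of that update rule (reconstructed from memory, not the original renderer code; the class name is just illustrative):

  import numpy as np

  class OnlineKMeans:
      def __init__(self, k):
          self.k, self.means, self.counts = k, [], []

      def add(self, x):
          if len(self.means) < self.k:  # first k samples seed the means
              self.means.append(np.asarray(x, float))
              self.counts.append(1)
              return
          # find the nearest mean and move it toward the new sample
          j = int(np.argmin([np.sum((x - m) ** 2) for m in self.means]))
          self.counts[j] += 1
          self.means[j] += (x - self.means[j]) / self.counts[j]  # running mean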


k-means is good for fast unsupervised clustering on an unknown low-dimensional dataset. It's helpful for EDA.

If you want accuracy at an order of magnitude more compute, you can use something like DBSCAN.


Real estate price estimates would be the classic, and frankly still common, example


Measuring multiple physical objects with the same sensor, you can use k-means to separate the measurements from each object, given that you know how many objects are being sensed. I can't get more specific than that.


Sounds like you're describing my carpet-color classification project! [1]

Built as part of a larger carpet based localisation project [2]

1: https://nbviewer.org/github/tim-fan/carpet_color_classificat...

2: https://github.com/tim-fan/carpet_localisation/wiki/Carpet-L...


Baseball pitch types based on physics profiles.


Constantly in remote sensing and GIS work.


Frequently used in e-commerce, such as RFM clustering for targeted marketing.


Additional small molecule pharmaceutical candidates via molecular descriptors


I've used it for identifying dominant colors in images.


Data segmentation (e.g. Voronoi), grouping (like search query results), anomaly detection... lots of different things.


So after re-reading your comment a few times, I am left with this thought: either you don't understand what k-means clustering is, or I don't understand what k-means clustering is. I wouldn't describe myself as a machine learning expert, but I have taken some grad level classes in statistics, analytics, and methods like this/related to this.

So my question is... could you elaborate?


Not GP, but I understood their question as follows.

Assume you gather some kindergartners and top NBA players in a room and collect their heights. Now say you pass these to two hapless grad students and ask them to perform k-means clustering.

Suppose one of the grad students knows the composition of the people you measured and can guess these heights should clump into 2 nice clusters. The other student doesn't know the composition of the group - what should they guess K to be?

I understood the GP's comment to refer to the state of the second grad student. How useful is K-means clustering without knowing K in advance?


>I understood the GP's comment to refer to the state of the second grad student. How useful is K-means clustering without knowing K in advance?

There are several heuristics for this. Googling, I see that the elbow method, the average silhouette method, and the gap statistic method are the most used.

I think you could play around with your own heuristics as well: simple KDE plots showing the number of peaks, or maybe requiring that the variance between clusters be greater than the variance inside any cluster. (Edit: this seems to be the main point of the average silhouette method.)
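
The elbow method, for instance, is just plotting the within-cluster sum of squares (sklearn's inertia_) against k and eyeballing the bend; roughly (a sketch, with X as your data):

  from sklearn.cluster import KMeans

  inertias = [KMeans(n_clusters=k, n_init="auto").fit(X).inertia_
              for k in range(2, 20)]
  # plot inertias vs. k and look for the "elbow" where the curve flattens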


The first problem is picking k. The second problem is the definition of distance or, equivalently, the uniformity of the space. Naively using the Euclidean distance in an embedding where similarity is non-uniform leads to bad outcomes. This problem is solved by learning a uniform embedding, and this is much harder than running k-means.

k-means assumes these hard parts of the problem are taken care of and offers a trivial solution to the rest. Thanks for the help, I'll cluster it myself.


You have to choose the number of clusters before using k-means.

Imagine you have a dataset where you think there are likely meaningful clusters, but you don't know how many, especially when it has many dimensions.

If you pick a k that is too small, you lump unrelated points together.

If k is too large, your meaningful clusters will be fragmented/overfitted.

To make up for this, there are algorithms that try to estimate the number of clusters, or to find the k that best fits the data.


Couldn’t you make some educated guesses and then stop when you arrive at a K that gives you meaningful clusters that are neither too high-level nor too atomized?


Probably not the best in terms of efficiency.

Easier to just deliberately overshoot (with too high a k) and then merge any clusters with too much overlap.


K is 3.

It’s honestly fine for just finding key differences like a principal component for light storytelling. They don’t need to be distinct clusters


SMOTE


knowledge graphs


Although K-means clustering is often the correct approach given time-crunch and code-complexity constraints, I don't like how it's hard to extend and how it's not principled. By not principled, I mean that it feels more like an algorithm (that happens to optimize) than an explicit optimization with an explicit loss function. And I've found that in practice, modifying the distance function to anything more interesting doesn't work.


K-means clustering is actually very well principled, as an instance of the expectation-maximization algorithm with "hard" cluster assignment. It turns out it's just good old maximum likelihood:

https://alliance.seas.upenn.edu/~cis520/dynamic/2022/wiki/in...
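
Concretely (my paraphrase of the standard derivation): with hard assignments and equal-variance isotropic Gaussians, maximizing the likelihood reduces, up to constants, to minimizing the familiar k-means objective

  \min_{\mu_1, \dots, \mu_K} \sum_{i=1}^{n} \min_{k \in \{1, \dots, K\}} \lVert x_i - \mu_k \rVert^2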


There are two issues I had in mind. One is that the link between argmin and the algorithm (k-means in this case) feels too "tied to the algorithm" and less explicit than in other algorithms.

The other is that in practice, you typically want to bring your true optimization objective as close as possible to what the algorithm is optimizing, and what k-means optimizes for is usually pretty far removed. Even small tweaks (let's say augmenting the data with some sparse labels, or modifying the loss weight based on some aspect of the embedding values) are difficult to do with k-means.


Isn’t it also formally equivalent to a Gaussian mixture model?

https://timydaley.github.io/kmeans_gmm/gmm_vs_kmeans.html


I remember when I first learned k-means, it opened the door to so many projects. Two that are on my GitHub to this day are a Python script that groups your images by similarity (histogram) and one that classifies your expenses based on previous data. I had so much fun working on those.


Check out sampling with lightweight coresets if your data is big - it's a principled approach with theoretical guarantees, and it's only a couple of lines of numpy. Do check if the assumptions hold for your data though, as they are stronger than with regular coresets.
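
The sampling step is short enough to sketch (from memory, so double-check the formulas against the paper before relying on this):

  import numpy as np

  def lightweight_coreset(X, m, rng=None):
      if rng is None:
          rng = np.random.default_rng(0)
      mu = X.mean(axis=0)
      d2 = ((X - mu) ** 2).sum(axis=1)
      q = 0.5 / len(X) + 0.5 * d2 / d2.sum()    # sampling distribution
      idx = rng.choice(len(X), size=m, replace=True, p=q)
      w = 1.0 / (m * q[idx])                    # importance weights
      return X[idx], w

The weighted sample (call it X_c, w) can then go into, e.g., KMeans(...).fit(X_c, sample_weight=w).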


Do you have a link to any implementations for this?



Fun fact: K-Means is the least interesting clustering algorithm known to humans, but is quite fast and therefore useful in certain applications


It's boring, in the sense that it always gives reasonable results and is easy to implement. It also scales well in N, D, and K, and in my experience it converges in just a few iterations from anything better than a purely random initialisation strategy.

IMO it is very good as a final clustering algorithm once you've already applied some more complex transformations on your data to linearise it and account for deeper knowledge of the problem. This might be a spectral space transformation (you care about connectedness), or an embedding (you care about whatever the network was trained on) or descriptor (you care about the algorithm's similarity).

But once you've applied the transform, you then have a minimalist, fast, scalable clustering step that just does clustering and doesn't need to know anything more about the problem being solved. Very unix-y feeling.


Does the “curse of dimensionality” affect the usefulness of k-means?



