Flash-KMeans: Implementação Exata de K-Means Roda 200× Mais Rápido que FAISS em GPUs

Pesquisadores acabam de lançar o Flash-KMeans, uma implementação exata do algoritmo k-means de Lloyd que roda mais de 200× mais rápido que o FAISS em GPUs — sem aproximações, sem atalhos algorítmicos.
O projeto, open-source sob licença Apache 2.0, é escrito em kernels Triton para GPU e instalável via pip install flash-kmeans. A grande inovação está em reestruturar como os dados se movem na GPU, não em mudar a matemática.
Os dois gargalos que o Flash-KMeans resolve
1. FlashAssign: eliminação da matriz de distância N×K
Em implementações convencionais, a etapa de atribuição materializa a matriz completa de distâncias N×K na memória de alta largura de banda (HBM) e depois a lê de volta para encontrar o argmin. Essa operação O(NK) domina o tempo de execução.
O FlashAssign transmite tiles de pontos e centroides da HBM para a SRAM on-chip, fundindo o cálculo de distância com um argmin online. A matriz N×K jamais é instanciada, reduzindo a E/S dominante de O(NK) para O(Nd + Kd).
- Aceleração do kernel: até 21,2×
- Inspirado no padrão de tiling e recomputação do FlashAttention
2. Sort-Inverse Update: atualização de centroides sem contenção
Atualizações convencionais com scatter e adições atômicas sofrem contenção quando muitas threads miram o mesmo cluster "quente". A equipe mediu apenas 50 GB/s de largura de banda efetiva nesse cenário em uma H200.
O Sort-Inverse Update ordena o vetor de atribuição por ID de cluster (argsort), criando segmentos contíguos de IDs idênticos. Cada bloco de threads reduz um segmento on-chip e emite uma adição atômica por segmento — não por ponto.
- Aceleração do kernel: até 6,3×
Benchmarks (NVIDIA H200, FP16, d=128)
| Comparação | Aceleração | Contexto |
|---|---|---|
| End-to-end vs melhor baseline | até 17,9× | N=8M, K=1024 |
| vs NVIDIA cuML | 33× | Comparação completa de biblioteca |
| vs FAISS | mais de 200× | Padrão da indústria |
| FlashAssign kernel isolado | até 21,2× | N=1M, K=8192 |
| Sort-Inverse Update isolado | até 6,3× | N=33M, K=4096 |
| Out-of-core, 1 bilhão de pontos | 41,4s/iter | Baseline: 261,8s |
Casos de uso transformadores
Com essa aceleração, o clustering online dentro de loops de treinamento e inferência se torna viável:
- Indexação de busca vetorial — reindexar frequentemente conforme os dados mudam
- Roteamento de atenção esparsa (Routing Transformers, Tactic) — agrupar tokens em tempo de inferência
- Compressão de KV-cache (ClusterKV) — clustering por camada e por passo
- Quantização de KV de baixo bit — construção rápida de codebooks durante compressão
- Diffusion Transformers (Sparse VideoGen2) — k-means em lote dentro do forward pass
API familiar
O Flash-KMeans oferece uma API que espelha FAISS/scikit-learn:
import torch
from flash_kmeans import batch_kmeans_Euclid
x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)
cluster_ids, centers, _ = batch_kmeans_Euclid(
x, n_clusters=1000, tol=1e-4, verbose=True
)
Com dispatch automático para múltiplas GPUs e modo out-of-core que esconde transferências PCIe atrás de computação.
O Flash-KMeans demonstra que ainda há ganhos massivos de performance a serem extraídos puramente da reestruturação de fluxo de dados em GPU — sem comprometer a exatidão dos resultados.