Voltar para o blog
Pesquisa em IA

Flash-KMeans: Implementação Exata de K-Means Roda 200× Mais Rápido que FAISS em GPUs

15 de junho de 2026
07:48
GPUmachine learningk-meansFAISSclusteringTritonFlashAttention
Flash-KMeans: Implementação Exata de K-Means Roda 200× Mais Rápido que FAISS em GPUs

Pesquisadores acabam de lançar o Flash-KMeans, uma implementação exata do algoritmo k-means de Lloyd que roda mais de 200× mais rápido que o FAISS em GPUs — sem aproximações, sem atalhos algorítmicos.

O projeto, open-source sob licença Apache 2.0, é escrito em kernels Triton para GPU e instalável via pip install flash-kmeans. A grande inovação está em reestruturar como os dados se movem na GPU, não em mudar a matemática.

Os dois gargalos que o Flash-KMeans resolve

1. FlashAssign: eliminação da matriz de distância N×K

Em implementações convencionais, a etapa de atribuição materializa a matriz completa de distâncias N×K na memória de alta largura de banda (HBM) e depois a lê de volta para encontrar o argmin. Essa operação O(NK) domina o tempo de execução.

O FlashAssign transmite tiles de pontos e centroides da HBM para a SRAM on-chip, fundindo o cálculo de distância com um argmin online. A matriz N×K jamais é instanciada, reduzindo a E/S dominante de O(NK) para O(Nd + Kd).

  • Aceleração do kernel: até 21,2×
  • Inspirado no padrão de tiling e recomputação do FlashAttention

2. Sort-Inverse Update: atualização de centroides sem contenção

Atualizações convencionais com scatter e adições atômicas sofrem contenção quando muitas threads miram o mesmo cluster "quente". A equipe mediu apenas 50 GB/s de largura de banda efetiva nesse cenário em uma H200.

O Sort-Inverse Update ordena o vetor de atribuição por ID de cluster (argsort), criando segmentos contíguos de IDs idênticos. Cada bloco de threads reduz um segmento on-chip e emite uma adição atômica por segmento — não por ponto.

  • Aceleração do kernel: até 6,3×

Benchmarks (NVIDIA H200, FP16, d=128)

Comparação Aceleração Contexto
End-to-end vs melhor baseline até 17,9× N=8M, K=1024
vs NVIDIA cuML 33× Comparação completa de biblioteca
vs FAISS mais de 200× Padrão da indústria
FlashAssign kernel isolado até 21,2× N=1M, K=8192
Sort-Inverse Update isolado até 6,3× N=33M, K=4096
Out-of-core, 1 bilhão de pontos 41,4s/iter Baseline: 261,8s

Casos de uso transformadores

Com essa aceleração, o clustering online dentro de loops de treinamento e inferência se torna viável:

  • Indexação de busca vetorial — reindexar frequentemente conforme os dados mudam
  • Roteamento de atenção esparsa (Routing Transformers, Tactic) — agrupar tokens em tempo de inferência
  • Compressão de KV-cache (ClusterKV) — clustering por camada e por passo
  • Quantização de KV de baixo bit — construção rápida de codebooks durante compressão
  • Diffusion Transformers (Sparse VideoGen2) — k-means em lote dentro do forward pass

API familiar

O Flash-KMeans oferece uma API que espelha FAISS/scikit-learn:

import torch
from flash_kmeans import batch_kmeans_Euclid

x = torch.randn(32, 75600, 128, device="cuda", dtype=torch.float16)
cluster_ids, centers, _ = batch_kmeans_Euclid(
    x, n_clusters=1000, tol=1e-4, verbose=True
)

Com dispatch automático para múltiplas GPUs e modo out-of-core que esconde transferências PCIe atrás de computação.


O Flash-KMeans demonstra que ainda há ganhos massivos de performance a serem extraídos puramente da reestruturação de fluxo de dados em GPU — sem comprometer a exatidão dos resultados.