O algoritmo de agrupamento K-Means em Java
1. Visão geral
Clustering é um termo abrangente para uma classe de algoritmos não supervisionados paradiscover groups of things, people, or ideas that are closely related to each other.
Nessa definição aparentemente simples de uma linha, vimos algumas chavões. O que exatamente é o agrupamento? O que é um algoritmo não supervisionado?
Neste tutorial, vamos, primeiro, lançar algumas luzes sobre esses conceitos. Então, veremos como eles podem se manifestar em Java.
2. Algoritmos não supervisionados
Antes de usar a maioria dos algoritmos de aprendizado, devemos fornecer alguns dados de amostra para eles e permitir que o algoritmo aprenda com esses dados. Na terminologia do aprendizado de máquina,we call that sample dataset training data. Além disso,the whole process is known as the training process.
De qualquer forma,we can classify learning algorithms based on the amount of supervision they need during the training process. Os dois principais tipos de algoritmos de aprendizagem nesta categoria são:
-
Supervised Learning: Em algoritmos supervisionados, os dados de treinamento devem incluir a solução real para cada ponto. Por exemplo, se estamos prestes a treinar nosso algoritmo de filtragem de spam, alimentamos os e-mails de amostra e seus marcadores, ou seja, spam ou não spam, para o algoritmo. Matematicamente falando, vamos inferirf(x) de um conjunto de treinamento incluindoxs areiays.
-
Unsupervised Learning: quando não há rótulos nos dados de treinamento, o algoritmo não é supervisionado. Por exemplo, temos muitos dados sobre músicos e vamos descobrir grupos de músicos semelhantes nos dados.
3. Agrupamento
O agrupamento é um algoritmo não supervisionado para descobrir grupos de coisas, idéias ou pessoas semelhantes. Ao contrário dos algoritmos supervisionados, não estamos treinando algoritmos de agrupamento com exemplos de rótulos conhecidos. Em vez disso, o armazenamento em cluster tenta encontrar estruturas dentro de um conjunto de treinamento em que nenhum ponto dos dados seja o rótulo.
3.1. K-Means Clustering
K-Means é um algoritmo de agrupamento com uma propriedade fundamental:the number of clusters is defined in advance. Além do K-Means, existem outros tipos de algoritmos de clustering, como Hierarchical Clustering, Affinity Propagation ouSpectral Clustering.
3.2. Como funciona o K-Means
Suponha que nosso objetivo seja encontrar alguns grupos semelhantes em um conjunto de dados como:
K-Means começa com k centróides colocados aleatoriamente. Centroids, as their name suggests, are the center points of the clusters. Por exemplo, aqui estamos adicionando quatro centróides aleatórios:
Em seguida, atribuímos cada ponto de dados existente ao centróide mais próximo:
Após a atribuição, movemos os centróides para o local médio dos pontos atribuídos a ele. Lembre-se, os centróides devem ser os pontos centrais dos clusters:
A iteração atual termina sempre que terminamos de realocar os centróides. We repeat these iterations until the assignment between multiple consecutive iterations stops changing:
Quando o algoritmo termina, esses quatro clusters são encontrados conforme o esperado. Agora que sabemos como funciona o K-Means, vamos implementá-lo em Java.
3.3. Representação de recursos
Ao modelar diferentes conjuntos de dados de treinamento, precisamos de uma estrutura de dados para representar os atributos do modelo e seus valores correspondentes. Por exemplo, um músico pode ter um atributo de gênero com um valor como Rock.We usually use the term feature to refer to the combination of an attribute and its value.
Para preparar um conjunto de dados para um algoritmo de aprendizado específico, geralmente usamos um conjunto comum de atributos numéricos que podem ser usados para comparar itens diferentes. Por exemplo, se permitirmos que nossos usuários rotulem cada artista com um gênero, no final do dia, poderemos contar quantas vezes cada artista é marcado com um gênero específico:
O vetor de recursos para um artista como o Linkin Park é[rock → 7890, nu-metal → 700, alternative → 520, pop → 3]. Então, se pudéssemos encontrar uma maneira de representar atributos como valores numéricos, podemos simplesmente comparar dois itens diferentes, por exemplo, artistas, comparando suas entradas vetoriais correspondentes.
Uma vez que os vetores numéricos são estruturas de dados versáteis, vamos representar recursos usando-os. Aqui está como implementamos vetores de recursos em Java:
public class Record {
private final String description;
private final Map features;
// constructor, getter, toString, equals and hashcode
}
3.4. Encontrar itens semelhantes
Em cada iteração do K-Means, precisamos encontrar uma maneira de encontrar o centróide mais próximo de cada item no conjunto de dados. Uma das maneiras mais simples de calcular a distância entre dois vetores de recursos é usarEuclidean Distance. A distância euclidiana entre dois vetores como[p1, q1]e[p2, q2] é igual a:
Vamos implementar essa função em Java. Primeiro, a abstração:
public interface Distance {
double calculate(Map f1, Map f2);
}
Além da distância euclidiana,there are other approaches to compute the distance or similarity between different items like the Pearson Correlation Coefficient. Essa abstração facilita a alternância entre diferentes métricas de distância.
Vamos ver a implementação para distância euclidiana:
public class EuclideanDistance implements Distance {
@Override
public double calculate(Map f1, Map f2) {
double sum = 0;
for (String key : f1.keySet()) {
Double v1 = f1.get(key);
Double v2 = f2.get(key);
if (v1 != null && v2 != null) {
sum += Math.pow(v1 - v2, 2);
}
}
return Math.sqrt(sum);
}
}
Primeiro, calculamos a soma das diferenças ao quadrado entre as entradas correspondentes. Então, aplicando a funçãosqrt , calculamos a distância euclidiana real.
3.5. Representação do Centroid
Os centróides estão no mesmo espaço que os recursos normais, para que possamos representá-los semelhantes aos recursos:
public class Centroid {
private final Map coordinates;
// constructors, getter, toString, equals and hashcode
}
Agora que temos algumas abstrações necessárias no lugar, é hora de escrever nossa implementação de K-Means. Aqui está uma rápida olhada em nossa assinatura de método:
public class KMeans {
private static final Random random = new Random();
public static Map> fit(List records,
int k,
Distance distance,
int maxIterations) {
// omitted
}
}
Vamos decompor essa assinatura de método:
-
Odataset é um conjunto de vetores de recursos. Como cada vetor de característica é umRecord, , então o tipo de conjunto de dados éList<Record>
-
O parâmetrok determina o número de clusters, que devemos fornecer com antecedência
-
distance encapsula a maneira como vamos calcular a diferença entre dois recursos
-
K-Means termina quando a atribuição para de ser alterada por algumas iterações consecutivas. Além dessa condição de finalização, também podemos colocar um limite superior para o número de iterações. O sargumentomaxIterations determina esse limite superior
-
Quando o K-Means termina, cada centróide deve ter alguns recursos atribuídos, portanto, estamos usando umMap<Centroid, List<Record>> como o tipo de retorno. Basicamente, cada entrada do mapa corresponde a um cluster
3.6. Geração Centroid
A primeira etapa é gerark centróides colocados aleatoriamente.
Embora cada centróide possa conter coordenadas totalmente aleatórias, é uma boa práticagenerate random coordinates between the minimum and maximum possible values for each attribute. Gerar centróides aleatórios sem considerar a faixa de valores possíveis faria com que o algoritmo convergisse mais lentamente.
Primeiro, devemos calcular o valor mínimo e máximo para cada atributo e, em seguida, gerar os valores aleatórios entre cada par deles:
private static List randomCentroids(List records, int k) {
List centroids = new ArrayList<>();
Map maxs = new HashMap<>();
Map mins = new HashMap<>();
for (Record record : records) {
record.getFeatures().forEach((key, value) -> {
// compares the value with the current max and choose the bigger value between them
maxs.compute(key, (k1, max) -> max == null || value > max ? value : max);
// compare the value with the current min and choose the smaller value between them
mins.compute(key, (k1, min) -> min == null || value < min ? value : min);
});
}
Set attributes = records.stream()
.flatMap(e -> e.getFeatures().keySet().stream())
.collect(toSet());
for (int i = 0; i < k; i++) {
Map coordinates = new HashMap<>();
for (String attribute : attributes) {
double max = maxs.get(attribute);
double min = mins.get(attribute);
coordinates.put(attribute, random.nextDouble() * (max - min) + min);
}
centroids.add(new Centroid(coordinates));
}
return centroids;
}
Agora, podemos atribuir cada registro a um desses centróides aleatórios.
3.7. Tarefa
Primeiro, dado umRecord, devemos encontrar o centróide mais próximo a ele:
private static Centroid nearestCentroid(Record record, List centroids, Distance distance) {
double minimumDistance = Double.MAX_VALUE;
Centroid nearest = null;
for (Centroid centroid : centroids) {
double currentDistance = distance.calculate(record.getFeatures(), centroid.getCoordinates());
if (currentDistance < minimumDistance) {
minimumDistance = currentDistance;
nearest = centroid;
}
}
return nearest;
}
Cada registro pertence ao seu cluster centróide mais próximo:
private static void assignToCluster(Map> clusters,
Record record,
Centroid centroid) {
clusters.compute(centroid, (key, list) -> {
if (list == null) {
list = new ArrayList<>();
}
list.add(record);
return list;
});
}
3.8. Realocação de Centroid
Se, após uma iteração, um centróide não contiver nenhuma atribuição, não o realocaremos. Caso contrário, devemos realocar a coordenada do centróide para cada atributo para o local médio de todos os registros atribuídos:
private static Centroid average(Centroid centroid, List records) {
if (records == null || records.isEmpty()) {
return centroid;
}
Map average = centroid.getCoordinates();
records.stream().flatMap(e -> e.getFeatures().keySet().stream())
.forEach(k -> average.put(k, 0.0));
for (Record record : records) {
record.getFeatures().forEach(
(k, v) -> average.compute(k, (k1, currentValue) -> v + currentValue)
);
}
average.forEach((k, v) -> average.put(k, v / records.size()));
return new Centroid(average);
}
Como podemos realocar um único centróide, agora é possível implementar o métodorelocateCentroids :
private static List relocateCentroids(Map> clusters) {
return clusters.entrySet().stream().map(e -> average(e.getKey(), e.getValue())).collect(toList());
}
Essa linha única simples interage com todos os centróides, realoca-os e retorna os novos centróides.
3.9. Juntando tudo
Em cada iteração, depois de atribuir todos os registros ao centróide mais próximo, primeiro devemos comparar as atribuições atuais com a última iteração.
Se as atribuições forem idênticas, o algoritmo será encerrado. Caso contrário, antes de pular para a próxima iteração, devemos realocar os centróides:
public static Map> fit(List records,
int k,
Distance distance,
int maxIterations) {
List centroids = randomCentroids(records, k);
Map> clusters = new HashMap<>();
Map> lastState = new HashMap<>();
// iterate for a pre-defined number of times
for (int i = 0; i < maxIterations; i++) {
boolean isLastIteration = i == maxIterations - 1;
// in each iteration we should find the nearest centroid for each record
for (Record record : records) {
Centroid centroid = nearestCentroid(record, centroids, distance);
assignToCluster(clusters, record, centroid);
}
// if the assignments do not change, then the algorithm terminates
boolean shouldTerminate = isLastIteration || clusters.equals(lastState);
lastState = clusters;
if (shouldTerminate) {
break;
}
// at the end of each iteration we should relocate the centroids
centroids = relocateCentroids(clusters);
clusters = new HashMap<>();
}
return lastState;
}
4. Exemplo: Descobrindo artistas semelhantes na Last.fm
Last.fm builds a detailed profile of each user’s musical taste by recording details of what the user listens to. Nesta seção, vamos encontrar grupos de artistas semelhantes. Para construir um conjunto de dados apropriado para esta tarefa, usaremos três APIs do Last.fm:
-
API para obter umcollection of top artists na Last.fm.
-
Outra API para encontrarpopular tags. Cada usuário pode marcar um artista com algo, por exemplo rock. Assim, o Last.fm mantém um banco de dados dessas marcas e suas frequências.
-
E uma API paraget the top tags for an artist, ordenada por popularidade. Como existem muitas dessas tags, manteremos apenas as tags que estão entre as principais tags globais.
4.1. Last.fm’s API
Para usar essas APIs, devemos obter umAPI Key from Last.fme enviá-lo em cada solicitação HTTP. Vamos usar o seguinte serviçoRetrofit para chamar essas APIs:
public interface LastFmService {
@GET("/2.0/?method=chart.gettopartists&format=json&limit=50")
Call topArtists(@Query("page") int page);
@GET("/2.0/?method=artist.gettoptags&format=json&limit=20&autocorrect=1")
Call topTagsFor(@Query("artist") String artist);
@GET("/2.0/?method=chart.gettoptags&format=json&limit=100")
Call topTags();
// A few DTOs and one interceptor
}
Então, vamos encontrar os artistas mais populares na Last.fm:
// setting up the Retrofit service
private static List getTop100Artists() throws IOException {
List artists = new ArrayList<>();
// Fetching the first two pages, each containing 50 records.
for (int i = 1; i <= 2; i++) {
artists.addAll(lastFm.topArtists(i).execute().body().all());
}
return artists;
}
Da mesma forma, podemos buscar as principais tags:
private static Set getTop100Tags() throws IOException {
return lastFm.topTags().execute().body().all();
}
Por fim, podemos criar um conjunto de dados de artistas junto com suas frequências de tags:
private static List datasetWithTaggedArtists(List artists,
Set topTags) throws IOException {
List records = new ArrayList<>();
for (String artist : artists) {
Map tags = lastFm.topTagsFor(artist).execute().body().all();
// Only keep popular tags.
tags.entrySet().removeIf(e -> !topTags.contains(e.getKey()));
records.add(new Record(artist, tags));
}
return records;
}
4.2. Formação de grupos de artistas
Agora, podemos alimentar o conjunto de dados preparado para nossa implementação do K-Means:
List artists = getTop100Artists();
Set topTags = getTop100Tags();
List records = datasetWithTaggedArtists(artists, topTags);
Map> clusters = KMeans.fit(records, 7, new EuclideanDistance(), 1000);
// Printing the cluster configuration
clusters.forEach((key, value) -> {
System.out.println("-------------------------- CLUSTER ----------------------------");
// Sorting the coordinates to see the most significant tags first.
System.out.println(sortedCentroid(key));
String members = String.join(", ", value.stream().map(Record::getDescription).collect(toSet()));
System.out.print(members);
System.out.println();
System.out.println();
});
Se rodarmos esse código, ele visualizaria os clusters como saída de texto:
------------------------------ CLUSTER -----------------------------------
Centroid {classic rock=65.58333333333333, rock=64.41666666666667, british=20.333333333333332, ... }
David Bowie, Led Zeppelin, Pink Floyd, System of a Down, Queen, blink-182, The Rolling Stones, Metallica,
Fleetwood Mac, The Beatles, Elton John, The Clash
------------------------------ CLUSTER -----------------------------------
Centroid {Hip-Hop=97.21428571428571, rap=64.85714285714286, hip hop=29.285714285714285, ... }
Kanye West, Post Malone, Childish Gambino, Lil Nas X, A$AP Rocky, Lizzo, xxxtentacion,
Travi$ Scott, Tyler, the Creator, Eminem, Frank Ocean, Kendrick Lamar, Nicki Minaj, Drake
------------------------------ CLUSTER -----------------------------------
Centroid {indie rock=54.0, rock=52.0, Psychedelic Rock=51.0, psychedelic=47.0, ... }
Tame Impala, The Black Keys
------------------------------ CLUSTER -----------------------------------
Centroid {pop=81.96428571428571, female vocalists=41.285714285714285, indie=22.785714285714285, ... }
Ed Sheeran, Taylor Swift, Rihanna, Miley Cyrus, Billie Eilish, Lorde, Ellie Goulding, Bruno Mars,
Katy Perry, Khalid, Ariana Grande, Bon Iver, Dua Lipa, Beyoncé, Sia, P!nk, Sam Smith, Shawn Mendes,
Mark Ronson, Michael Jackson, Halsey, Lana Del Rey, Carly Rae Jepsen, Britney Spears, Madonna,
Adele, Lady Gaga, Jonas Brothers
------------------------------ CLUSTER -----------------------------------
Centroid {indie=95.23076923076923, alternative=70.61538461538461, indie rock=64.46153846153847, ... }
Twenty One Pilots, The Smiths, Florence + the Machine, Two Door Cinema Club, The 1975, Imagine Dragons,
The Killers, Vampire Weekend, Foster the People, The Strokes, Cage the Elephant, Arcade Fire,
Arctic Monkeys
------------------------------ CLUSTER -----------------------------------
Centroid {electronic=91.6923076923077, House=39.46153846153846, dance=38.0, ... }
Charli XCX, The Weeknd, Daft Punk, Calvin Harris, MGMT, Martin Garrix, Depeche Mode, The Chainsmokers,
Avicii, Kygo, Marshmello, David Guetta, Major Lazer
------------------------------ CLUSTER -----------------------------------
Centroid {rock=87.38888888888889, alternative=72.11111111111111, alternative rock=49.16666666, ... }
Weezer, The White Stripes, Nirvana, Foo Fighters, Maroon 5, Oasis, Panic! at the Disco, Gorillaz,
Green Day, The Cure, Fall Out Boy, OneRepublic, Paramore, Coldplay, Radiohead, Linkin Park,
Red Hot Chili Peppers, Muse
Como as coordenadas do centróide são classificadas pela frequência média de tags, podemos identificar facilmente o gênero dominante em cada cluster. Por exemplo, o último cluster é um cluster de boas e antigas bandas de rock, ou o segundo está cheio de estrelas do rap.
Embora esse clustering faça sentido, na maior parte, não é perfeito, pois os dados são apenas coletados a partir do comportamento do usuário.
5. Visualização
Alguns momentos atrás, nosso algoritmo visualizava o agrupamento de artistas de maneira amigável ao terminal. Se convertermos nossa configuração de cluster para JSON e alimentá-lo para D3.js, então, com algumas linhas de JavaScript, teremos um bomRadial Tidy-Tree amigável para humanos:
Temos que converter nossoMap<Centroid, List<Record>> em um JSON com um esquema semelhante comothis d3.js example.
6. Número de Clusters
Uma das propriedades fundamentais do K-Means é o fato de que devemos definir o número de clusters com antecedência. Até agora, usamos um valor estático parak, mas determinar esse valor pode ser um problema desafiador. There are two common ways to calculate the number of clusters:
-
Conhecimento do domínio
-
Heurística Matemática
Se tivermos a sorte de saber tanto sobre o domínio, poderemos simplesmente adivinhar o número certo. Caso contrário, podemos aplicar algumas heurísticas, como Elbow Method ou Silhouette Method, para entender o número de clusters.
Antes de prosseguir, devemos saber que essas heurísticas, embora úteis, são apenas heurísticas e podem não fornecer respostas claras.
6.1. Método do cotovelo
Para usar o método cotovelo, devemos primeiro calcular a diferença entre cada centróide do cluster e todos os seus membros. As we group more unrelated members in a cluster, the distance between the centroid and its members goes up, hence the cluster quality decreases.
Uma maneira de realizar este cálculo de distância é usar a Soma dos Erros Quadrados. * A soma dos erros quadráticos ou SSE é igual à soma das diferenças quadradas entre um centróide e todos os seus membros *:
public static double sse(Map> clustered, Distance distance) {
double sum = 0;
for (Map.Entry> entry : clustered.entrySet()) {
Centroid centroid = entry.getKey();
for (Record record : entry.getValue()) {
double d = distance.calculate(centroid.getCoordinates(), record.getFeatures());
sum += Math.pow(d, 2);
}
}
return sum;
}
Então,we can run the K-Means algorithm for different values of ke calcule o SSE para cada um deles:
List records = // the dataset;
Distance distance = new EuclideanDistance();
List sumOfSquaredErrors = new ArrayList<>();
for (int k = 2; k <= 16; k++) {
Map> clusters = KMeans.fit(records, k, distance, 1000);
double sse = Errors.sse(clusters, distance);
sumOfSquaredErrors.add(sse);
}
No final do dia, é possível encontrar umak apropriada plotando o número de clusters em relação ao SSE:
Geralmente, à medida que o número de clusters aumenta, a distância entre os membros do cluster diminui. No entanto, não podemos escolher quaisquer valores arbitrários grandes parak, s, visto que ter vários clusters com apenas um membro anula todo o propósito do cluster.
The idea behind the elbow method is to find an appropriate value for k in a way that the SSE decreases dramatically around that value. Por exemplo,k=9 pode ser um bom candidato aqui.
7. Conclusão
Neste tutorial, primeiro, abordamos alguns conceitos importantes no Machine Learning. Em seguida, ficamos com a mecânica do algoritmo de agrupamento K-Means. Por fim, escrevemos uma implementação simples para o K-Means, testamos nosso algoritmo com um conjunto de dados do mundo real da Last.fm e visualizamos o resultado do clustering de maneira gráfica agradável.
Como de costume, o código de amostra está disponível em nosso projetoGitHub, portanto, certifique-se de dar uma olhada!