Regressão logística em Java
1. Introdução
A regressão logística é um importante instrumento na caixa de ferramentas do praticante de aprendizado de máquina (ML).
Neste tutorial, exploraremos a idéia principal por trás da regressão logística .
Primeiro, vamos começar com uma breve visão geral dos paradigmas e algoritmos de ML.
2. Visão geral
O ML nos permite resolver problemas que podemos formular em termos humanos. No entanto, esse fato pode representar um desafio para nós, desenvolvedores de software. Acostumamo-nos a resolver os problemas que podemos formular em termos amigáveis ao computador. Por exemplo, como seres humanos, podemos detectar facilmente os objetos em uma foto ou estabelecer o humor de uma frase. Como poderíamos formular esse problema para um computador?
Para encontrar uma solução *, no ML, há um estágio especial chamado training *. Durante esse estágio, alimentamos os dados de entrada em nosso algoritmo para que ele tente criar um conjunto ideal de parâmetros (os chamados pesos). Quanto mais dados de entrada pudermos fornecer ao algoritmo, mais previsões precisas poderemos esperar dele.
O treinamento faz parte de um fluxo de trabalho iterativo de ML:
https://www..com/wp-content/uploads/2019/09/ml1.png [imagem: https://www..com/wp-content/uploads/2019/09/ml1.png [imagem] ] + Começamos com a aquisição de dados. Freqüentemente, os dados vêm de diferentes fontes. Portanto, temos que torná-lo no mesmo formato. Também devemos controlar que o conjunto de dados represente de forma justa o domínio do estudo. Se o modelo nunca foi treinado em maçãs vermelhas, dificilmente poderá prever.
Em seguida, devemos construir um modelo que consuma os dados e seja capaz de fazer previsões. No ML, não há modelos predefinidos que funcionem bem em todas as situações.
Ao procurar o modelo correto, pode facilmente acontecer que construamos, treinemos, visualizemos suas previsões e descartemos o modelo, porque não estamos felizes com as previsões que ele faz. Nesse caso, devemos recuar e construir outro modelo e repetir o processo novamente.
3. Paradigmas de ML
No ML, com base no tipo de dados de entrada que temos à nossa disposição, podemos destacar três paradigmas principais:
-
aprendizado supervisionado (classificação de imagens, reconhecimento de objetos, análise de sentimentos)
-
aprendizado não supervisionado (detecção de anomalias) *aprendizagem por reforço (estratégias de jogo)
O caso que vamos descrever* neste tutorial pertence ao aprendizado supervisionado. *
4. ML Toolbox
No ML, há um conjunto de ferramentas que podemos aplicar ao criar um modelo. Vamos mencionar alguns deles:
-
Regressão linear
-
Regressão logística
-
Redes neurais
-
Máquina de vetores de suporte
-
k-vizinhos mais próximos
*Podemos combinar várias ferramentas ao criar um modelo com alta capacidade de previsão.* De fato, para este tutorial, nosso modelo usará regressão logística e redes neurais.
5. Bibliotecas ML
Embora o Java não seja a linguagem mais popular para a prototipagem de modelos de ML, o tem reputação de ser uma ferramenta confiável para criar software robusto em muitas áreas, incluindo o ML. Portanto, podemos encontrar bibliotecas ML escritas em Java.
Nesse contexto, podemos mencionar a biblioteca padrão de fato Tensorflow, que também possui uma versão Java. Outro destaque a ser mencionado é uma biblioteca de aprendizado profundo chamada https://www..com/deeplearning4j [Deeplearning4j]. Essa é uma ferramenta muito poderosa e vamos usá-la neste tutorial também.
6. Regressão logística no reconhecimento de dígitos
A idéia principal da regressão logística é construir um modelo que preveja os rótulos dos dados de entrada da maneira mais precisa possível.
Treinamos o modelo até que a chamada função de perda ou função objetivo atinja algum valor mínimo. A função de perda depende das previsões reais do modelo e das esperadas (os rótulos dos dados de entrada). Nosso objetivo é minimizar a divergência entre as previsões reais do modelo e as esperadas.
Se não estivermos satisfeitos com esse valor mínimo, devemos construir outro modelo e executar o treinamento novamente.
Para ver a regressão logística em ação, nós a ilustramos no reconhecimento de dígitos manuscritos. Este problema já se tornou clássico. A biblioteca Deeplearning4j possui uma série de examples que mostram como usar sua API . A parte relacionada a código deste tutorial é fortemente baseada em https://github.com/deeplearning4j/dl4j-examples/blob/master/dl4j-examples/src/main/java/org/deeplearning4j/examples/convolution/mnist/MnistClassifier.java [MNIST Classifier].
6.1. Dados de entrada
Como dados de entrada, usamos o conhecido MNIST database de dígitos manuscritos. Como dados de entrada, temos imagens em escala de cinza de 28 × 28 pixels. Cada imagem tem um rótulo natural, que é o dígito que a imagem representa:
https://www..com/wp-content/uploads/2019/09/ml2.png [imagem: https://www..com/wp-content/uploads/2019/09/ml2.png [imagem] ]
Para estimar a eficiência do modelo que vamos construir, dividimos os dados de entrada em conjuntos de treinamento e teste:
DataSetIterator train = new RecordReaderDataSetIterator(...);
DataSetIterator test = new RecordReaderDataSetIterator(...);
Uma vez rotuladas e divididas as imagens de entrada nos dois conjuntos, o estágio "elaboração de dados" termina e podemos passar para a "construção do modelo".
6.2. Model Building
Como mencionamos, não há modelos que funcionem bem em todas as situações. No entanto, após muitos anos de pesquisa em ML, os cientistas descobriram modelos que apresentam um desempenho muito bom no reconhecimento de dígitos manuscritos. Aqui, usamos o chamado modelo LeNet-5.
O LeNet-5 é uma rede neural que consiste em uma série de camadas que transformam a imagem de 28 × 28 pixels em um vetor de dez dimensões:
https://www..com/wp-content/uploads/2019/09/ml3.png [imagem: https://www..com/wp-content/uploads/2019/09/ml3.png [imagem] ] + O vetor de saída tridimensional contém probabilidades de que o rótulo da imagem de entrada seja 0 ou 1 ou 2 e assim por diante.
Por exemplo, se o vetor de saída tiver o seguinte formato:
{0.1, 0.0, 0.3, 0.2, 0.1, 0.1, 0.0, 0.1, 0.1, 0.0}
significa que a probabilidade de a imagem de entrada ser zero é 0,1, uma é 0, e duas é 0,3, etc. Vemos que a probabilidade máxima (0,3) corresponde ao rótulo 3.
Vamos mergulhar nos detalhes da construção de modelos. Omitimos detalhes específicos de Java e nos concentramos nos conceitos de ML.
Configuramos o modelo criando um objeto MultiLayerNetwork:
MultiLayerNetwork model = new MultiLayerNetwork(config);
Em seu construtor, devemos passar um objeto MultiLayerConfiguration. Este é o próprio objeto que descreve a geometria da rede neural. Para definir a geometria da rede, devemos definir todas as camadas.
Vamos mostrar como fazemos isso com o primeiro e o segundo:
ConvolutionLayer layer1 = new ConvolutionLayer
.Builder(5, 5).nIn(channels)
.stride(1, 1)
.nOut(20)
.activation(Activation.IDENTITY)
.build();
SubsamplingLayer layer2 = new SubsamplingLayer
.Builder(SubsamplingLayer.PoolingType.MAX)
.kernelSize(2, 2)
.stride(2, 2)
.build();
Vemos que as definições de camadas contêm uma quantidade considerável de parâmetros ad-hoc que impactam significativamente em todo o desempenho da rede. É exatamente aqui que nossa capacidade de encontrar um bom modelo na paisagem de todos se torna crucial.
Agora, estamos prontos para construir o objeto MultiLayerConfiguration:
MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
//preparation steps
.list()
.layer(layer1)
.layer(layer2)
//other layers and final steps
.build();
que passamos para o construtor MultiLayerNetwork.
6.3. Treinamento
O modelo que construímos contém 431080 parâmetros ou pesos. Não vamos fornecer aqui o cálculo exato desse número, mas devemos estar cientes de que apenas a * primeira camada tem mais de 24x24x20 = 11520 pesos. *
A fase de treinamento é tão simples quanto:
model.fit(train);
Inicialmente, os parâmetros 431080 possuem alguns valores aleatórios, mas após o treinamento, eles adquirem alguns valores que determinam o desempenho do modelo. Podemos avaliar a previsão do modelo:
Evaluation eval = model.evaluate(test);
logger.info(eval.stats());
O modelo LeNet-5 atinge uma precisão bastante alta de quase 99%, mesmo em apenas uma iteração de treinamento (época). Se queremos obter maior precisão, devemos fazer mais iterações usando um for-loop simples:
for (int i = 0; i < epochs; i++) {
model.fit(train);
train.reset();
test.reset();
}
6.4. Predição
Agora, enquanto treinamos o modelo e estamos felizes com suas previsões nos dados de teste, podemos experimentar o modelo com algumas informações absolutamente novas. Para esse fim, vamos criar uma nova classe MnistPrediction na qual carregaremos uma imagem de um arquivo que selecionamos no sistema de arquivos:
INDArray image = new NativeImageLoader(height, width, channels).asMatrix(file);
new ImagePreProcessingScaler(0, 1).transform(image);
A variável image contém nossa imagem sendo reduzida para 28 × 28 em escala de cinza. Podemos alimentá-lo com o nosso modelo:
INDArray output = model.output(image);
A variável output conterá as probabilidades da imagem ser zero, um, dois etc.
Vamos agora tocar um pouco e escrever um dígito 2, digitalizar esta imagem e alimentar o modelo. Podemos obter algo assim:
https://www..com/wp-content/uploads/2019/09/ml4.png [imagem: https://www..com/wp-content/uploads/2019/09/ml4.png [imagem] ] + Como vemos, o componente com valor máximo de 0,99 possui o índice dois. Isso significa que o modelo reconheceu corretamente nosso dígito manuscrito.
7. Conclusão
Neste tutorial, descrevemos os conceitos gerais de aprendizado de máquina. Ilustramos esses conceitos no exemplo de regressão logística que aplicamos ao reconhecimento de dígitos manuscritos.
Como sempre, podemos encontrar os trechos de código correspondentes em nosso GitHub repository.