Introdução ao Tensorflow para Java
1. Visão geral
TensorFlow é umopen source library for dataflow programming. Este foi originalmente desenvolvido pelo Google e está disponível para uma ampla variedade de plataformas. Embora o TensorFlow possa funcionar em um único núcleo, ele pode atéeasily benefit from multiple CPU, GPU or TPU available.
Neste tutorial, veremos os fundamentos do TensorFlow e como usá-lo em Java. Observe que a API Java do TensorFlow é uma API experimental e, portanto, não é coberta por nenhuma garantia de estabilidade. Cobriremos posteriormente no tutorial os possíveis casos de uso para usar a API TensorFlow Java.
2. Fundamentos
O cálculo do TensorFlow gira basicamente em torno detwo fundamental concepts: Graph and Session. Vamos examiná-los rapidamente para obter a base necessária para percorrer o restante do tutorial.
2.1. Gráfico de TensorFlow
Para começar, vamos entender os blocos de construção fundamentais dos programas TensorFlow. Computations are represented as graphs in TensorFlow. Um gráfico geralmente é um gráfico acíclico direcionado de operações e dados, por exemplo:
A figura acima representa o gráfico computacional da seguinte equação:
f(x, y) = z = a*x + b*y
Um gráfico computacional TensorFlow consiste em dois elementos:
-
Tensor: These are the core unit of data in TensorFlow. Eles são representados como as arestas em um gráfico computacional, representando o fluxo de dados através do gráfico. Um tensor pode ter uma forma com qualquer número de dimensões. O número de dimensões em um tensor é geralmente referido como sua classificação. Portanto, um escalar é um tensor de classificação 0, um vetor é um tensor de classificação 1, uma matriz é um tensor de classificação 2 e assim por diante.
-
Operation: These are the nodes in a computational graph. Eles se referem a uma ampla variedade de cálculos que podem acontecer nos tensores que alimentam a operação. Eles geralmente resultam em tensores que emanam da operação em um gráfico computacional.
2.2. Sessão TensorFlow
Agora, um gráfico TensorFlow é um mero esquema da computação que na verdade não contém valores. Tala graph must be run inside what is called a TensorFlow session for the tensors in the graph to be evaluated. A sessão pode levar vários tensores para avaliar a partir de um gráfico como parâmetros de entrada. Em seguida, ele retrocede no gráfico e executa todos os nós necessários para avaliar esses tensores.
Com esse conhecimento, agora estamos prontos para pegar isso e aplicá-lo à API Java!
3. Configuração do Maven
Vamos configurar um projeto Maven rápido para criar e executar um gráfico TensorFlow em Java. Precisamos apenas detensorflow dependency:
org.tensorflow
tensorflow
1.12.0
4. Criando o gráfico
Agora vamos tentar construir o gráfico que discutimos na seção anterior usando a API TensorFlow Java. Mais precisamente, neste tutorial, usaremos a API TensorFlow Java para resolver a função representada pela seguinte equação:
z = 3*x + 2*y
O primeiro passo é declarar e inicializar um gráfico:
Graph graph = new Graph()
Agora, temos que definir todas as operações necessárias. Lembre-se de queoperations in TensorFlow consume and produce zero or more tensors. Além disso, todo nó no gráfico é uma operação incluindo constantes e espaços reservados. Isso pode parecer contra-intuitivo, mas aceite isso por um momento!
A classeGraph tem uma função genérica chamadaopBuilder() para construir qualquer tipo de operação no TensorFlow.
4.1. Definindo constantes
Para começar, vamos definir as operações constantes em nosso gráfico acima. Observe que aconstant operation will need a tensor for its value:
Operation a = graph.opBuilder("Const", "a")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.create(3.0, Double.class))
.build();
Operation b = graph.opBuilder("Const", "b")
.setAttr("dtype", DataType.fromClass(Double.class))
.setAttr("value", Tensor.create(2.0, Double.class))
.build();
Aqui, definimos umOperation do tipo constante, alimentandoTensor com os valoresDouble 2,0 e 3,0. Pode parecer um pouco opressor no início, mas é assim que funciona na API Java por enquanto. Essas construções são muito mais concisas em linguagens como Python.
4.2. Definindo espaços reservados
Embora precisemos fornecer valores para nossas constantes,placeholders don’t need a value at definition-time. Os valores para espaços reservados precisam ser fornecidos quando o gráfico é executado dentro de uma sessão. Veremos essa parte mais tarde no tutorial.
Por enquanto, vamos ver como podemos definir nossos marcadores de posição:
Operation x = graph.opBuilder("Placeholder", "x")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
Operation y = graph.opBuilder("Placeholder", "y")
.setAttr("dtype", DataType.fromClass(Double.class))
.build();
Observe que não precisamos fornecer nenhum valor para nossos espaços reservados. Esses valores serão alimentados comoTensors quando executados.
4.3. Definindo Funções
Finalmente, precisamos definir as operações matemáticas de nossa equação, a saber, multiplicação e adição para obter o resultado.
Novamente, nada mais são do queOperations no TensorFlow eGraph.opBuilder() é útil mais uma vez:
Operation ax = graph.opBuilder("Mul", "ax")
.addInput(a.output(0))
.addInput(x.output(0))
.build();
Operation by = graph.opBuilder("Mul", "by")
.addInput(b.output(0))
.addInput(y.output(0))
.build();
Operation z = graph.opBuilder("Add", "z")
.addInput(ax.output(0))
.addInput(by.output(0))
.build();
Aqui, definimosOperation, dois para multiplicar nossas entradas e o último para somar os resultados intermediários. Observe que as operações aqui recebem tensores que nada mais são do que resultados de nossas operações anteriores.
Observe que estamos obtendo a saídaTensor deOperation usando o índice '0'. Como discutimos anteriormente,an Operation can result in one or more Tensor e, portanto, ao recuperar um identificador para ele, precisamos mencionar o índice. Como sabemos que nossas operações estão retornando apenas umTensor, '0' funciona perfeitamente!
5. Visualizando o gráfico
É difícil manter uma guia no gráfico à medida que aumenta de tamanho. Isso o tornaimportant to visualize it in some way. Sempre podemos criar um desenho à mão como o pequeno gráfico que criamos anteriormente, mas não é prático para gráficos maiores. TensorFlow provides a utility called TensorBoard to facilitate this.
Infelizmente, a API Java não tem a capacidade de gerar um arquivo de evento que é consumido pelo TensorBoard. Mas, usando APIs em Python, podemos gerar um arquivo de evento como:
writer = tf.summary.FileWriter('.')
......
writer.add_graph(tf.get_default_graph())
writer.flush()
Por favor, não se preocupe se isso não fizer sentido no contexto de Java, isso foi adicionado aqui apenas para fins de integridade e não é necessário continuar o restante do tutorial.
Agora podemos carregar e visualizar o arquivo de evento no TensorBoard como:
tensorboard --logdir .
O TensorBoard vem como parte da instalação do TensorFlow.
Observe a semelhança entre este e o gráfico desenhado manualmente anteriormente!
6. Trabalhando com Sessão
Agora criamos um gráfico computacional para nossa equação simples na API Java do TensorFlow. Mas como o executamos? Antes de abordar isso, vamos ver qual é o estado deGraph que acabamos de criar neste ponto. Se tentarmos imprimir a saída de nossoOperation “z” final:
System.out.println(z.output(0));
Isso resultará em algo como:
dtype=DOUBLE>
Não era isso que esperávamos! Mas se lembrarmos do que discutimos anteriormente, isso realmente faz sentido. The Graph we have just defined has not been run yet, so the tensors therein do not actually hold any actual value. A saída acima apenas diz que será umTensor do tipoDouble.
Vamos agora definir umSession para executar nossoGraph:
Session sess = new Session(graph)
Finalmente, agora estamos prontos para executar nosso gráfico e obter a saída que esperávamos:
Tensor tensor = sess.runner().fetch("z")
.feed("x", Tensor.create(3.0, Double.class))
.feed("y", Tensor.create(6.0, Double.class))
.run().get(0).expect(Double.class);
System.out.println(tensor.doubleValue());
Então o que estamos fazendo aqui? Deve ser bastante intuitivo:
-
Obtenha umRunner deSession
-
Defina oOperation para buscar por seu nome “z”
-
Alimente tensores para nossos marcadores "x" e "y"
-
Execute oGraph noSession
E agora vemos a saída escalar:
21.0
Isso é o que esperávamos, não é?
7. A API de Caso de Uso para Java
Nesse ponto, o TensorFlow pode parecer um exagero na execução de operações básicas. Mas, é claro,TensorFlow is meant to run graphs much much larger do que isso.
Além disso,the tensors it deals with in real-world models are much larger in size and rank. Estes são os modelos reais de aprendizado de máquina em que o TensorFlow encontra seu uso real.
Não é difícil ver que trabalhar com a API principal no TensorFlow pode se tornar muito complicado à medida que o tamanho do gráfico aumenta. Para isso,TensorFlow provides high-level APIs like Keras to work with complex models. Infelizmente, ainda há pouco ou nenhum suporte oficial para o Keras em Java.
No entanto, podemosuse Python to define and train complex models diretamente no TensorFlow ou usando APIs de alto nível como Keras. Posteriormente, podemosexport a trained model and use that in Java usando a API TensorFlow Java.
Agora, por que queremos fazer algo assim? Isso é particularmente útil para situações em que queremos usar os recursos habilitados para aprendizado de máquina em clientes existentes em execução em Java. Por exemplo, recomendando legenda para imagens de usuário em um dispositivo Android. No entanto, existem várias instâncias em que estamos interessados na saída de um modelo de aprendizado de máquina, mas não necessariamente queremos criar e treinar esse modelo em Java.
É aqui que a API Java do TensorFlow encontra a maior parte de seu uso. Veremos como isso pode ser alcançado na próxima seção.
8. Usando modelos salvos
Agora entenderemos como podemos salvar um modelo no TensorFlow no sistema de arquivos e carregá-lo de volta, possivelmente em uma plataforma e linguagem completamente diferente. TensorFlow provides APIs to generate model files in a language and platform neutral structure called Protocol Buffer.
8.1. Salvando modelos no sistema de arquivos
Começaremos definindo o mesmo gráfico que criamos anteriormente no Python e salvando-o no sistema de arquivos.
Vamos ver se podemos fazer isso em Python:
import tensorflow as tf
graph = tf.Graph()
builder = tf.saved_model.builder.SavedModelBuilder('./model')
with graph.as_default():
a = tf.constant(2, name='a')
b = tf.constant(3, name='b')
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
z = tf.math.add(a*x, b*y, name='z')
sess = tf.Session()
sess.run(z, feed_dict = {x: 2, y: 3})
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
builder.save()
Como o foco deste tutorial em Java, não vamos prestar muita atenção aos detalhes deste código em Python, exceto pelo fato de que ele gera um arquivo chamado “saved_model.pb”. Observe de maneira concisa a definição de um gráfico semelhante comparado ao Java!
8.2. Carregando modelos do sistema de arquivos
Agora vamos carregar “saved_model.pb” em Java. A API Java TensorFlow temSavedModelBundle para funcionar com modelos salvos:
SavedModelBundle model = SavedModelBundle.load("./model", "serve");
Tensor tensor = model.session().runner().fetch("z")
.feed("x", Tensor.create(3, Integer.class))
.feed("y", Tensor.create(3, Integer.class))
.run().get(0).expect(Integer.class);
System.out.println(tensor.intValue());
Agora deve ser bastante intuitivo entender o que o código acima está fazendo. Simplesmente carrega o gráfico do modelo a partir do buffer do protocolo e disponibiliza a sessão no mesmo. A partir daí, podemos praticamente fazer qualquer coisa com esse gráfico, como teríamos feito em um gráfico definido localmente.
9. Conclusão
Em resumo, neste tutorial, examinamos os conceitos básicos relacionados ao gráfico computacional TensorFlow. Vimos como usar a API Java do TensorFlow para criar e executar esse gráfico. Em seguida, conversamos sobre os casos de uso da API Java em relação ao TensorFlow.
No processo, também entendemos como visualizar o gráfico usando o TensorBoard e salvar e recarregar um modelo usando o Buffer de Protocolo.
Como sempre, o código dos exemplos está disponívelover on GitHub.