Introduktion til Tensorflow til Java

1. Oversigt

TensorFlow er en open source-bibliotek til dataflytsprogrammering. Dette blev oprindeligt udviklet af Google og er tilgængeligt for en bred vifte af platforme. Selvom TensorFlow kan arbejde på en enkelt kerne, kan det som nyd godt af flere tilgængelige CPU'er, GPU'er eller TPU'er.

I denne vejledning gennemgår vi det grundlæggende i TensorFlow og hvordan man bruger det i Java. Bemærk, at TensorFlow Java API er en eksperimentel API og derfor ikke er omfattet af nogen stabilitetsgaranti. Vi dækker senere i vejledningen mulige brugssager til brug af TensorFlow Java API.

2. Grundlæggende

TensorFlow-beregning drejer sig stort set omkring to grundlæggende begreber: Graf og session. Lad os gå igennem dem hurtigt for at få den nødvendige baggrund for at gennemgå resten af ​​vejledningen.

2.1. TensorFlow-graf

Lad os til at begynde med forstå de grundlæggende byggesten i TensorFlow-programmer. Beregninger er repræsenteret som grafer i TensorFlow. En graf er typisk en rettet acyklisk graf over operationer og data, for eksempel:

Ovenstående billede repræsenterer beregningsgrafen for følgende ligning:

f (x, y) = z = a * x + b * y

En beregningsgraf for TensorFlow består af to elementer:

  1. Tensor: Dette er kerneenheden for data i TensorFlow. De er repræsenteret som kanterne i en beregningsgraf, der viser strømmen af ​​data gennem grafen. En tensor kan have en form med et vilkårligt antal dimensioner. Antallet af dimensioner i en tensor kaldes normalt dens rang. Så en skalar er en rang 0 tensor, en vektor er en rang 1 tensor, en matrix er en rang 2 tensor, og så videre og så videre.
  2. Drift: Dette er noderne i en beregningsgraf. De henviser til en bred vifte af beregninger, der kan ske på de tensorer, der føder til operationen. De resulterer ofte også i tensorer, der stammer fra operationen i en beregningsgraf.

2.2. TensorFlow-session

Nu er en TensorFlow-graf kun en skematisk oversigt over beregningen, som faktisk ikke indeholder nogen værdier. Sådan en graf skal køres inde i det, der kaldes en TensorFlow-session, for at tensorerne i grafen skal evalueres. Sessionen kan tage en masse tensorer at evaluere ud fra en graf som inputparametre. Derefter løber den bagud i grafen og kører alle de noder, der er nødvendige for at evaluere disse tensorer.

Med denne viden er vi nu klar til at tage dette og anvende det på Java API!

3. Maven-opsætning

Vi opretter et hurtigt Maven-projekt for at oprette og køre en TensorFlow-graf i Java. Vi har bare brug for tensorflow afhængighed:

 org.tensorflow tensorflow 1.12.0 

4. Oprettelse af grafen

Lad os nu prøve at opbygge den graf, vi diskuterede i det foregående afsnit ved hjælp af TensorFlow Java API. Mere præcist til denne vejledning bruger vi TensorFlow Java API til at løse funktionen repræsenteret af følgende ligning:

z = 3 * x + 2 * y

Det første trin er at erklære og initialisere en graf:

Grafgraf = ny graf ()

Nu skal vi definere alle nødvendige operationer. Huske på, at operationer i TensorFlow forbruger og producerer nul eller flere tensorer. Desuden er hver knude i grafen en operation inklusive konstanter og pladsholdere. Dette kan virke kontraintuitivt, men hold det med et øjeblik!

Klassen Kurve har en generisk funktion kaldet opBuilder () at bygge enhver form for operation på TensorFlow.

4.1. Definition af konstanter

Til at begynde med skal vi definere konstante operationer i vores graf ovenfor. Bemærk, at en konstant drift har brug for en tensor for dens værdi:

Betjening a = graph.opBuilder ("Const", "a") .setAttr ("dtype", DataType.fromClass (Double.class)) .setAttr ("value", Tensor.create (3.0, Double.class)). build (); Funktion b = graph.opBuilder ("Const", "b") .setAttr ("dtype", DataType.fromClass (Double.class)) .setAttr ("value", Tensor.create (2.0, Double.class)). build ();

Her har vi defineret en Operation af konstant type, fodring i Tensor med Dobbelt værdier 2,0 og 3,0. Det kan virke lidt overvældende til at begynde med, men sådan er det lige nu i Java API. Disse konstruktioner er meget mere koncise på sprog som Python.

4.2. Definition af pladsholdere

Mens vi har brug for at give værdier til vores konstanter, pladsholdere har ikke brug for en værdi på definitionstidspunktet. Værdierne til pladsholdere skal leveres, når grafen køres inde i en session. Vi gennemgår den del senere i vejledningen.

Lad os nu se, hvordan vi kan definere vores pladsholdere:

Operation x = graph.opBuilder ("Placeholder", "x") .setAttr ("dtype", DataType.fromClass (Double.class)) .build (); Funktion y = graph.opBuilder ("Placeholder", "y") .setAttr ("dtype", DataType.fromClass (Double.class)) .build ();

Bemærk, at vi ikke behøvede at give nogen værdi til vores pladsholdere. Disse værdier tilføres som Tensorer når du kører.

4.3. Definition af funktioner

Endelig er vi nødt til at definere de matematiske operationer i vores ligning, nemlig multiplikation og tilføjelse for at få resultatet.

Disse er igen intet andet end Operations i TensorFlow og Graph.opBuilder () er endnu en gang praktisk:

Operation ax = graph.opBuilder ("Mul", "ax") .addInput (a.output (0)) .addInput (x.output (0)) .build (); Betjening af = graph.opBuilder ("Mul", "af") .addInput (b.output (0)) .addInput (y.output (0)) .build (); Funktion z = graph.opBuilder ("Tilføj", "z") .addInput (ax.output (0)) .addInput (by.output (0)) .build ();

Her har vi defineret der Operation, to til at multiplicere vores input og den sidste til at opsummere de mellemliggende resultater. Bemærk, at operationer her modtager tensorer, der kun er resultatet af vores tidligere operationer.

Bemærk, at vi får output Tensor fra Operation ved hjælp af indeks '0'. Som vi diskuterede tidligere, en Operation kan resultere i en eller flere Tensor og derfor, mens vi henter et håndtag til det, skal vi nævne indekset. Da vi ved, at vores operationer kun returnerer en Tensor, '0' fungerer fint!

5. Visualisering af grafen

Det er svært at holde en fane på grafen, da den vokser i størrelse. Dette gør det vigtigt at visualisere det på en eller anden måde. Vi kan altid oprette en håndtegning som den lille graf, vi oprettede tidligere, men det er ikke praktisk for større grafer. TensorFlow leverer et værktøj kaldet TensorBoard for at lette dette.

Desværre har Java API ikke mulighed for at generere en begivenhedsfil, der forbruges af TensorBoard. Men ved hjælp af API'er i Python kan vi generere en begivenhedsfil som:

writer = tf.summary.FileWriter ('.') ...... writer.add_graph (tf.get_default_graph ()) writer.flush ()

Du skal ikke gider, hvis dette ikke giver mening i forbindelse med Java, dette er blevet tilføjet her bare for fuldstændighedens skyld og ikke nødvendigt for at fortsætte resten af ​​vejledningen.

Vi kan nu indlæse og visualisere begivenhedsfilen i TensorBoard som:

tensorboard --logdir.

TensorBoard kommer som en del af TensorFlow-installationen.

Bemærk ligheden mellem denne og den manuelt tegnede graf tidligere!

6. Arbejde med session

Vi har nu oprettet en beregningsgraf til vores enkle ligning i TensorFlow Java API. Men hvordan kører vi det? Før vi tager fat på det, lad os se, hvad staten er Kurve vi har netop oprettet på dette tidspunkt. Hvis vi prøver at udskrive output fra vores endelige Operation “Z”:

System.out.println (z.output (0));

Dette vil resultere i noget som:

Dette var ikke, hvad vi forventede! Men hvis vi husker det, vi diskuterede tidligere, giver det faktisk mening. Det Kurve vi netop har defineret, er ikke kørt endnu, så tensorerne deri har faktisk ingen faktisk værdi. Outputtet ovenfor siger bare, at dette vil være en Tensor af typen Dobbelt.

Lad os nu definere en Session at køre vores Kurve:

Sessionssession = ny session (graf)

Endelig er vi nu klar til at køre vores graf og få det output, vi har forventet:

Tensor tensor = sess.runner (). Hent ("z"). Feed ("x", Tensor.create (3.0, Double.class)). Feed ("y", Tensor.create (6.0, Double.class) ) .run (). get (0) .expect (Double.class); System.out.println (tensor.doubleValue ());

Så hvad laver vi her? Det skal være ret intuitivt:

  • Få en Løber fra Session
  • Definer Operation at hente ved navn “z”
  • Giv tensorer til vores pladsholdere "x" og "y"
  • Kør Kurve i Session

Og nu ser vi den skalære output:

21.0

Dette var hvad vi forventede, er det ikke!

7. Brugssagen til Java API

På dette tidspunkt kan TensorFlow lyde som overdreven til at udføre grundlæggende operationer. Men selvfølgelig, TensorFlow er beregnet til at køre grafer meget meget større end dette.

Derudover de tensorer, det beskæftiger sig med i virkelige modeller, er meget større i størrelse og rang. Dette er de faktiske maskinlæringsmodeller, hvor TensorFlow finder sin reelle brug.

Det er ikke svært at se, at arbejde med kernen API i TensorFlow kan blive meget besværligt, når grafens størrelse øges. Til denne ende, TensorFlow leverer API'er på højt niveau som Keras til at arbejde med komplekse modeller. Desværre er der endnu kun lidt eller ingen officiel støtte til Keras på Java.

Det kan vi dog bruge Python til at definere og træne komplekse modeller enten direkte i TensorFlow eller ved hjælp af API'er på højt niveau som Keras. Efterfølgende kan vi eksporter en uddannet model og brug den i Java ved hjælp af TensorFlow Java API.

Hvorfor vil vi nu gøre noget lignende? Dette er især nyttigt i situationer, hvor vi vil bruge funktioner til maskinindlæring i eksisterende klienter, der kører på Java. For eksempel anbefale billedtekst til brugerbilleder på en Android-enhed. Ikke desto mindre er der flere tilfælde, hvor vi er interesserede i output fra en maskinlæringsmodel, men ikke nødvendigvis ønsker at oprette og træne den model i Java.

Det er her TensorFlow Java API finder størstedelen af ​​brugen. Vi gennemgår, hvordan dette kan opnås i det næste afsnit.

8. Brug af gemte modeller

Vi forstår nu, hvordan vi kan gemme en model i TensorFlow til filsystemet og indlæse den tilbage muligvis på et helt andet sprog og platform. TensorFlow leverer API'er til at generere modelfiler i en sprog- og platformneutral struktur kaldet Protocol Buffer.

8.1. Gemme modeller i filsystemet

Vi begynder med at definere den samme graf, som vi oprettede tidligere i Python, og gemme den i filsystemet.

Lad os se, at vi kan gøre dette i Python:

importer tensorflow som tf-graf = tf.Graph () builder = tf.saved_model.builder.SavedModelBuilder ('./ model') med graph.as_default (): a = tf.constant (2, name = "a") b = tf.constant (3, name = "b") x = tf.placeholder (tf.int32, name = "x") y = tf.placeholder (tf.int32, name = "y") z = tf.math. tilføj (a * x, b * y, navn = "z") sess = tf.Session () sess.run (z, feed_dict = {x: 2, y: 3}) builder.add_meta_graph_and_variables (sess, [tf. saved_model.tag_constants.SERVING]) builder.save ()

Som fokus for denne tutorial i Java, lad os ikke være meget opmærksomme på detaljerne i denne kode i Python, bortset fra det faktum, at den genererer en fil kaldet “saved_model.pb”. Vær opmærksom på at passere kortfattetheden ved at definere en lignende graf sammenlignet med Java!

8.2. Indlæser modeller fra filsystemet

Vi indlæser nu “saved_model.pb” i Java. Java TensorFlow API har SavedModelBundle at arbejde med gemte modeller:

SavedModelBundle model = SavedModelBundle.load ("./ model", "serve"); Tensor tensor = model.session (). Runner (). Hente ("z") .feed ("x", Tensor.create (3, Integer.class)) .feed ("y", Tensor.create (3, Integer.class)) .run (). Get (0) .expect (Integer.class); System.out.println (tensor.intValue ());

Det skal nu være ret intuitivt at forstå, hvad ovenstående kode gør. Det indlæser simpelthen modelgrafen fra protokolbufferen og gør sessionen tilgængelig deri. Derefter kan vi stort set gøre noget med denne graf, som vi ville have gjort for en lokalt defineret graf.

9. Konklusion

For at opsummere gik vi i denne vejledning gennem de grundlæggende begreber relateret til TensorFlow-beregningsgrafen. Vi så, hvordan man bruger TensorFlow Java API til at oprette og køre en sådan graf. Derefter talte vi om brugssagerne til Java API med hensyn til TensorFlow.

I processen forstod vi også, hvordan man visualiserer grafen ved hjælp af TensorBoard, og gemme og genindlæse en model ved hjælp af Protocol Buffer.

Som altid er koden til eksemplerne tilgængelig på GitHub.