Logistisk regression i Java

1. Introduktion

Logistisk regression er et vigtigt instrument i værktøjskassen til maskinindlæring (ML).

I denne vejledning vi udforsker hovedideen bag logistisk regression.

Lad os først starte med en kort oversigt over ML-paradigmer og algoritmer.

2. Oversigt

ML giver os mulighed for at løse problemer, som vi kan formulere menneskeligt. Denne kendsgerning kan dog være en udfordring for os softwareudviklere. Vi har vænnet os til at løse de problemer, vi kan formulere i computervenlige termer. For eksempel kan vi som mennesker let registrere objekterne på et foto eller etablere stemningen i en sætning. Hvordan kunne vi formulere et sådant problem for en computer?

For at komme med en løsning, i ML er der et specielt stadium kaldet uddannelse. I løbet af dette trin føder vi inputdataene til vores algoritme, så den forsøger at komme med et optimalt sæt parametre (de såkaldte vægte). Jo flere inputdata vi kan føje til algoritmen, jo mere præcise forudsigelser kan vi forvente af den.

Træning er en del af en iterativ ML-arbejdsgang:

Vi starter med at indhente data. Ofte kommer dataene fra forskellige kilder. Derfor er vi nødt til at gøre det til det samme format. Vi bør også kontrollere, at datasættet retfærdigt repræsenterer undersøgelsesdomænet. Hvis modellen aldrig er blevet trænet på røde æbler, kan den næppe forudsige det.

Dernæst skal vi bygge en model, der forbruger dataene og vil være i stand til at forudsige. I ML er der ingen foruddefinerede modeller, der fungerer godt i alle situationer.

Når vi søger efter den rigtige model, kan det let ske, at vi bygger en model, træner den, ser dens forudsigelser og kasserer modellen, fordi vi ikke er tilfredse med de forudsigelser, den giver. I dette tilfælde skal vi træde tilbage og opbygge en anden model og gentage processen igen.

3. ML-paradigmer

I ML, baseret på hvilken type inputdata vi har til rådighed, kan vi udpege tre hovedparadigmer:

  • overvåget læring (billedklassificering, genkendelse af genstande, sentimentanalyse)
  • Uovervåget læring (detektion af anomali)
  • forstærkningslæring (spilstrategier)

Sagen, som vi skal beskrive i denne tutorial tilhører overvåget læring.

4. ML Værktøjskasse

I ML er der et sæt værktøjer, som vi kan anvende, når vi bygger en model. Lad os nævne nogle af dem:

  • Lineær regression
  • Logistisk regression
  • Neurale netværk
  • Support Vector Machine
  • k-nærmeste naboer

Vi kan kombinere flere værktøjer, når vi bygger en model, der har høj forudsigelighed. Faktisk til denne vejledning bruger vores model logistisk regression og neurale netværk.

5. ML-biblioteker

Selvom Java ikke er det mest populære sprog til prototyping af ML-modeller,det har et ry som et pålideligt værktøj til at skabe robust software inden for mange områder inklusive ML. Derfor finder vi muligvis ML-biblioteker skrevet i Java.

I denne sammenhæng kan vi nævne det de facto standardbibliotek Tensorflow, som også har en Java-version. En anden værd at nævne er et dyb læringsbibliotek kaldet Deeplearning4j. Dette er et meget kraftfuldt værktøj, og vi vil også bruge det i denne vejledning.

6. Logistisk regression ved ciffergenkendelse

Hovedideen med logistisk regression er at opbygge en model, der forudsiger etiketterne på inputdataene så præcist som muligt.

Vi træner modellen, indtil den såkaldte tabsfunktion eller objektive funktion når en minimal værdi. Tabsfunktionen afhænger af de aktuelle modelforudsigelser og forventede (etiketterne på inputdataene). Vores mål er at minimere forskellen mellem aktuelle modelforudsigelser og de forventede.

Hvis vi ikke er tilfredse med denne minimumsværdi, skal vi bygge en anden model og udføre træningen igen.

For at se logistisk regression i aktion illustrerer vi det med anerkendelsen af ​​håndskrevne cifre. Dette problem er allerede blevet et klassisk problem. Deeplearning4j-biblioteket har en række realistiske eksempler, der viser, hvordan man bruger dets API. Den kode-relaterede del af denne tutorial er stærkt baseret på MNIST klassifikator.

6.1. Indtastningsdata

Som inputdata bruger vi den velkendte MNIST-database med håndskrevne cifre. Som inputdata har vi 28 × 28 pixel gråtonebilleder. Hvert billede har en naturlig etiket, som er det ciffer, som billedet repræsenterer:

For at estimere effektiviteten af ​​den model, vi skal bygge, deler vi inputdataene i træning og testsæt:

DataSetIterator-tog = nyt RecordReaderDataSetIterator (...); DataSetIterator test = ny RecordReaderDataSetIterator (...);

Når vi har fået inputbillederne mærket og opdelt i de to sæt, er "dataudarbejdelse" -fasen overstået, og vi går muligvis videre til "modelbygningen".

6.2. Modelbygning

Som vi har nævnt, er der ingen modeller, der fungerer godt i enhver situation. Ikke desto mindre har forskere efter mange års forskning i ML fundet modeller, der fungerer meget godt i at genkende håndskrevne cifre. Her bruger vi den såkaldte LeNet-5-model.

LeNet-5 er et neuralt netværk, der består af en række lag, der omdanner 28 × 28 pixelbillede til en ti-dimensionel vektor:

Den ti-dimensionelle outputvektor indeholder sandsynligheder for, at inputbilledets etiket enten er 0 eller 1 eller 2 osv.

For eksempel, hvis outputvektoren har følgende form:

{0.1, 0.0, 0.3, 0.2, 0.1, 0.1, 0.0, 0.1, 0.1, 0.0}

det betyder, at sandsynligheden for, at inputbilledet er nul er 0,1, at en er 0, at to er 0,3 osv. Vi ser, at den maksimale sandsynlighed (0,3) svarer til etiket 3.

Lad os dykke ned i detaljer om modelbygning. Vi udelader Java-specifikke detaljer og koncentrerer os om ML-koncepter.

Vi satte modellen op ved at oprette en MultiLayerNetwork objekt:

MultiLayerNetwork model = ny MultiLayerNetwork (config);

I sin konstruktør skal vi passere en MultiLayerConfiguration objekt. Dette er selve objektet, der beskriver geometrien i det neurale netværk. For at definere netværksgeometrien skal vi definere hvert lag.

Lad os vise, hvordan vi gør dette med den første og den anden:

ConvolutionLayer layer1 = nyt ConvolutionLayer .Builder (5, 5) .nIn (kanaler) .stride (1, 1) .nOut (20) .aktivering (Activation.IDENTITY) .build (); SubsamplingLayer layer2 = nyt SubsamplingLayer .Builder (SubsamplingLayer.PoolingType.MAX) .kernelSize (2, 2) .stride (2, 2) .build ();

Vi ser, at lagdefinitioner indeholder en betydelig mængde ad-hoc-parametre, som påvirker hele netværksydelsen markant. Det er her, vores evne til at finde en god model i landskabet for alle bliver afgørende.

Nu er vi klar til at konstruere MultiLayerConfiguration objekt:

MultiLayerConfiguration config = ny NeuralNetConfiguration.Builder () // forberedelsestrin. Liste () .lag (lag1) .lag (lag2) // andre lag og sidste trin .build ();

at vi videregiver til MultiLayerNetwork konstruktør.

6.3. Uddannelse

Den model, vi konstruerede, indeholder 431080 parametre eller vægte. Vi vil ikke her give den nøjagtige beregning af dette tal, men vi skal være opmærksomme på, at bare tdet første lag har mere end 24x24x20 = 11520 vægte.

Træningsstadiet er så simpelt som:

model.fit (tog); 

Oprindeligt har 431080-parametrene nogle tilfældige værdier, men efter træningen tilegner de sig nogle værdier, der bestemmer modelens ydeevne. Vi evaluerer modelens forudsigelighed:

Evaluering eval = model.evaluate (test); logger.info (eval.stats ());

LeNet-5-modellen opnår en ganske høj nøjagtighed på næsten 99%, selv i kun en enkelt trænings-iteration (epoke). Hvis vi ønsker at opnå højere nøjagtighed, bør vi foretage flere iterationer ved hjælp af en almindelig for-loop:

til (int i = 0; i <epoker; i ++) {model.fit (tog); train.reset (); test.reset (); } 

6.4. Forudsigelse

Når vi nu har trænet modellen, og vi er tilfredse med dens forudsigelser på testdataene, kan vi prøve modellen på noget helt nyt input. Lad os til dette formål oprette en ny klasse Mnistforudsigelse hvor vi indlæser et billede fra en fil, som vi vælger fra filsystemet:

INDArray billede = ny NativeImageLoader (højde, bredde, kanaler) .asMatrix (fil); ny ImagePreProcessingScaler (0, 1) .transform (billede);

Variablen billede indeholder vores billede reduceret til 28 × 28 gråtoner. Vi kan give det til vores model:

INDArray output = model.output (billede);

Variablen produktion vil indeholde sandsynligheden for, at billedet er nul, en, to osv.

Lad os nu spille lidt og skrive et ciffer 2, digitalisere dette billede og give det modellen. Vi får muligvis noget som dette:

Som vi ser, har komponenten med maksimal værdi 0,99 indeks to. Det betyder, at modellen korrekt har genkendt vores håndskrevne ciffer.

7. Konklusion

I denne vejledning beskrev vi de generelle begreber maskinindlæring. Vi illustrerede disse begreber på et logistisk regressionseksempel, som vi anvendte på en håndskrevet ciffergenkendelse.

Som altid finder vi muligvis de tilsvarende kodestykker på vores GitHub-lager.


$config[zx-auto] not found$config[zx-overlay] not found