/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.contrib.rnn.example;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.jblas.DoubleMatrix;
import org.neuroph.contrib.rnn.LSTM;
import org.neuroph.contrib.rnn.RNN;
import org.neuroph.contrib.rnn.bptt.BackPropagationThroughTime;
import org.neuroph.contrib.rnn.bptt.LSTMBackPropagationThroughTime;
import org.neuroph.contrib.rnn.util.LossFunction;
import org.neuroph.contrib.rnn.util.MatrixInitializer;
import org.neuroph.contrib.rnn.util.SequenceModeller;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.learning.LearningRule;

/**
 * Example program that trains an {@link LSTM} on a stock-price training file
 * and then runs it over a separate test file, printing the predicted symbol
 * sequence and an aggregate loss figure.
 *
 * <p>Both data files are loaded with {@code DataSet.createFromFile} and turned
 * into symbol sequences by {@link SequenceModeller}; the network is trained
 * with {@link LSTMBackPropagationThroughTime}.
 */
public class LSTMStockPricePredictionExample {
    /** Training data file, resolved relative to the working directory. */
    private static final String TRAINING_FILE = "google-stock-price-train.csv";
    /** Test data file, resolved relative to the working directory. */
    private static final String TEST_FILE = "google-stock-price-test.csv";
    /** Column counts and delimiter handed to {@code DataSet.createFromFile}. */
    private static final int FILE_INPUTS_COUNT = 3;
    private static final int FILE_OUTPUTS_COUNT = 1;
    private static final String FILE_DELIMITER = ",";

    private static final int HIDDEN_COUNT = 100;
    private static final int MAX_ITERATIONS = 100;
    private static final double LEARNING_RATE = 0.8;

    /**
     * A sequence needs a first symbol to feed in and at least one following
     * symbol to compare the prediction against.
     */
    private static final int MIN_SEQUENCE_LENGTH = 2;

    /**
     * Program entry point; trains the network and then evaluates it.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        trainNetwork();
    }

    /**
     * Loads the training and test sets, builds an LSTM whose input size equals
     * the number of distinct symbols reported by the training set's
     * {@link SequenceModeller}, trains it, and hands it to
     * {@link #testNetwork(RNN, DataSet)}.
     */
    private static void trainNetwork() {
        DataSet trainingSet = DataSet.createFromFile(
                TRAINING_FILE, FILE_INPUTS_COUNT, FILE_OUTPUTS_COUNT, FILE_DELIMITER);
        DataSet testSet = DataSet.createFromFile(
                TEST_FILE, FILE_INPUTS_COUNT, FILE_OUTPUTS_COUNT, FILE_DELIMITER);

        SequenceModeller sequenceModeller = new SequenceModeller(trainingSet);
        int inputsCount = sequenceModeller.getCharIndex().size();

        System.out.println("Creating neural network...");
        LSTM lstm = new LSTM(inputsCount, HIDDEN_COUNT,
                new MatrixInitializer(MatrixInitializer.Type.Uniform, 0.1, 0.0, 0.0));
        LSTMBackPropagationThroughTime bptt = new LSTMBackPropagationThroughTime();
        bptt.setLearningRate(LEARNING_RATE);
        lstm.setLearningRule((LearningRule) bptt);

        System.out.println("Training network...");
        bptt.learn(trainingSet, MAX_ITERATIONS);
        System.out.println("Training completed.");

        testNetwork(lstm, testSet);
    }

    /**
     * Runs the network over every sequence of the test set, printing the first
     * symbol of each sequence followed by the symbol predicted at each
     * timestep, and finally the accumulated loss divided by the total number
     * of symbols processed, together with the elapsed wall-clock time.
     *
     * <p>Sequences shorter than {@value #MIN_SEQUENCE_LENGTH} symbols are
     * skipped: they yield no prediction and would otherwise fail on
     * {@code charAt(0)} or reach {@code propagate} with a negative timestep.
     *
     * @param lstm    the trained network; its learning rule must be a
     *                {@link BackPropagationThroughTime}
     * @param testSet the data set to evaluate on
     */
    private static void testNetwork(RNN lstm, DataSet testSet) {
        // NOTE(review): the symbol tables come from a modeller built on the test
        // set, not the one used for training - confirm both produce the same
        // symbol indexing and vector size.
        SequenceModeller sequenceModeller = new SequenceModeller(testSet);
        Map<Integer, String> indexChar = sequenceModeller.getIndexChar();
        Map<String, DoubleMatrix> charVector = sequenceModeller.getCharVector();
        List<String> sequences = sequenceModeller.getSequence();

        System.out.println("Test set:");
        testSet.forEach(System.out::println);
        System.out.println("Prediction:");

        // The learning rule does not change between sequences, so look it up once.
        BackPropagationThroughTime bptt = (BackPropagationThroughTime) lstm.getLearningRule();

        double totalError = 0.0;
        long symbolCount = 0L;
        long startMillis = System.currentTimeMillis();

        for (String seq : sequences) {
            if (seq.length() < MIN_SEQUENCE_LENGTH) {
                continue;
            }
            HashMap<String, DoubleMatrix> valuesInTimesteps = new HashMap<String, DoubleMatrix>();
            System.out.print(seq.charAt(0));

            int lastTimestep = seq.length() - 2;
            for (int timestep = 0; timestep <= lastTimestep; ++timestep) {
                DoubleMatrix input = charVector.get(String.valueOf(seq.charAt(timestep)));
                valuesInTimesteps.put("input" + timestep, input);
                lstm.activate(timestep, valuesInTimesteps);

                // activate() is expected to have stored "output<t>" in the map.
                DoubleMatrix predictedResult = lstm.decode(valuesInTimesteps.get("output" + timestep));
                valuesInTimesteps.put("predictedResult" + timestep, predictedResult);

                // The target at each timestep is the next symbol of the sequence.
                DoubleMatrix result = charVector.get(String.valueOf(seq.charAt(timestep + 1)));
                valuesInTimesteps.put("result" + timestep, result);

                System.out.print(indexChar.get(predictedResult.argmax()));
                totalError += LossFunction.getMeanCategoricalCrossEntropy(predictedResult, result);
            }
            System.out.println();

            // NOTE(review): propagate() is given a learning rate, so it presumably
            // updates the network while it is being evaluated - confirm this is
            // intended for a test pass.
            bptt.propagate(valuesInTimesteps, lastTimestep, bptt.getLearningRate());

            // NOTE(review): the loss is summed over length - 1 predictions but
            // normalised by length; kept as-is so the figure stays comparable with
            // any other code using the same convention - confirm.
            symbolCount += seq.length();
        }

        double elapsedSeconds = (System.currentTimeMillis() - startMillis) / 1000.0;
        if (symbolCount == 0L) {
            // Avoid reporting 0.0 / 0.0 (NaN) when nothing could be evaluated.
            System.out.println("Error = n/a (no sequence with at least " + MIN_SEQUENCE_LENGTH
                    + " symbols), time = " + elapsedSeconds + "s");
            return;
        }
        System.out.println("Error = " + totalError / symbolCount + ", time = " + elapsedSeconds + "s");
    }
}

