/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.contrib.rnn.bptt;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;
import org.neuroph.contrib.rnn.RNN;
import org.neuroph.contrib.rnn.util.LossFunction;
import org.neuroph.contrib.rnn.util.SequenceModeller;
import org.neuroph.core.data.DataSet;
import org.neuroph.nnet.learning.BackPropagation;

/**
 * Base class for back-propagation-through-time learning rules driving an {@link RNN}.
 *
 * <p>{@link #learn(DataSet, int)} runs the training loop. For every character sequence
 * produced by {@link SequenceModeller} it does the following:
 * <ul>
 *   <li>runs the network forward one timestep at a time;</li>
 *   <li>accumulates the categorical cross-entropy between each prediction and the next
 *       character;</li>
 *   <li>hands the recorded per-timestep values to {@link #propagate(Map, int, double)}, which
 *       concrete subclasses implement.</li>
 * </ul>
 */
public abstract class BackPropagationThroughTime
extends BackPropagation {

    /**
     * Trains the network for a fixed number of passes over the training set.
     *
     * <p>Sequences shorter than three characters are skipped. After each pass, one line is
     * printed to standard output. It contains the mean cross-entropy loss per predicted
     * character and the elapsed time in seconds. If no sequence qualified, the printed error
     * is {@code NaN} (0.0 / 0).
     *
     * @param trainingSet   data set from which {@link SequenceModeller} derives the character
     *                      sequences and the per-character vectors
     * @param maxIterations number of full passes over all sequences
     */
    public void learn(DataSet trainingSet, int maxIterations) {
        SequenceModeller sequenceModeller = new SequenceModeller(trainingSet);
        Map<String, DoubleMatrix> charVector = sequenceModeller.getCharVector();
        List<String> sequence = sequenceModeller.getSequence();
        // Resolved once rather than per sequence; assumes neither the network nor its
        // learning rule is replaced while training runs.
        RNN rnn = (RNN) this.getNeuralNetwork();
        BackPropagationThroughTime bptt = (BackPropagationThroughTime) rnn.getLearningRule();
        for (int i = 0; i < maxIterations; ++i) {
            double error = 0.0;
            // Count of next-character predictions scored: a sequence of length L yields L - 1.
            long predictions = 0L;
            long startNanos = System.nanoTime();
            for (String seq : sequence) {
                if (seq.length() < 3) {
                    continue;
                }
                // Declared as HashMap to match the original call into RNN.activate, whose
                // parameter type is not visible from this file.
                HashMap<String, DoubleMatrix> valuesInTimesteps = new HashMap<String, DoubleMatrix>();
                error += forwardSequence(rnn, seq, charVector, valuesInTimesteps);
                // The forward pass filled timesteps 0 .. seq.length() - 2; BPTT starts from the last.
                bptt.propagate(valuesInTimesteps, seq.length() - 2, bptt.getLearningRate());
                predictions += seq.length() - 1;
            }
            double elapsedSeconds = (System.nanoTime() - startNanos) / 1.0e9;
            System.out.println("Iteration = " + (i + 1) + ", error = " + error / predictions
                    + ", time = " + elapsedSeconds + "s");
        }
    }

    /**
     * Runs the forward pass over one sequence and returns its summed cross-entropy loss.
     *
     * <p>For each timestep {@code t} in {@code 0 .. seq.length() - 2}, this method stores the
     * following entries in {@code valuesInTimesteps}:
     * <ul>
     *   <li>{@code "input" + t}: the vector of character {@code t};</li>
     *   <li>{@code "predictedResult" + t}: the decoded {@code "output" + t}, which
     *       {@code rnn.activate} is expected to have written;</li>
     *   <li>{@code "result" + t}: the vector of character {@code t + 1}.</li>
     * </ul>
     *
     * @param rnn               network being trained
     * @param seq               character sequence to run forward
     * @param charVector        vector for each single-character string
     * @param valuesInTimesteps map that receives the per-timestep values
     * @return sum of the per-timestep mean categorical cross-entropy losses
     */
    private double forwardSequence(RNN rnn, String seq, Map<String, DoubleMatrix> charVector,
                                   HashMap<String, DoubleMatrix> valuesInTimesteps) {
        double sequenceError = 0.0;
        for (int timestep = 0; timestep < seq.length() - 1; ++timestep) {
            DoubleMatrix input = charVector.get(String.valueOf(seq.charAt(timestep)));
            valuesInTimesteps.put("input" + timestep, input);
            rnn.activate(timestep, valuesInTimesteps);
            DoubleMatrix predictedResult = rnn.decode(valuesInTimesteps.get("output" + timestep));
            valuesInTimesteps.put("predictedResult" + timestep, predictedResult);
            DoubleMatrix result = charVector.get(String.valueOf(seq.charAt(timestep + 1)));
            valuesInTimesteps.put("result" + timestep, result);
            sequenceError += LossFunction.getMeanCategoricalCrossEntropy(predictedResult, result);
        }
        return sequenceError;
    }

    /**
     * Performs the backward pass for one sequence whose forward-pass values were recorded by
     * {@link #learn(DataSet, int)}.
     *
     * @param valuesInTimesteps per-timestep values from the forward pass. Keys are
     *                          {@code "input" + t}, {@code "output" + t},
     *                          {@code "predictedResult" + t} and {@code "result" + t}, plus
     *                          whatever {@code RNN.activate} stores itself.
     * @param lastTimestep      index of the last timestep that was run forward
     * @param learningRate      learning rate; {@code learn} passes the network learning rule's
     *                          {@code getLearningRate()}
     */
    public abstract void propagate(Map<String, DoubleMatrix> valuesInTimesteps, int lastTimestep, double learningRate);

    /**
     * Presumably applies the parameter updates computed during the backward pass to
     * {@code rnn}. This is inferred from the name only; the method is not called from this
     * class.
     *
     * <p>NOTE(review): the parameter meanings below mirror {@link #propagate} by analogy.
     * TODO confirm them against the implementations.
     *
     * @param valuesInTimesteps presumably the per-timestep values map
     * @param lastTimestep      presumably the last timestep index
     * @param learningRate      presumably the learning rate
     * @param rnn               network whose parameters are updated
     */
    protected abstract void updateParameters(Map<String, DoubleMatrix> valuesInTimesteps, int lastTimestep, double learningRate, RNN rnn);

    /**
     * Computes {@code m * (1 - m)} element-wise.
     *
     * <p>This is the logistic-sigmoid derivative expressed in terms of the sigmoid's output
     * {@code m}. The result has the same shape as {@code matrix}.
     *
     * @param matrix values {@code m}; not modified
     * @return a new matrix holding {@code m * (1 - m)}
     */
    protected DoubleMatrix deriveExp(DoubleMatrix matrix) {
        return matrix.mul(matrix.rsub(1.0));
    }

    /**
     * Computes {@code 1 - m^2} element-wise.
     *
     * <p>This is the tanh derivative expressed in terms of tanh's output {@code m}. The result
     * has the same shape as {@code matrix}. The previous version always returned a
     * {@code 1 x n} row, whatever the input shape.
     *
     * @param matrix values {@code m}; not modified
     * @return a new matrix holding {@code 1 - m^2}
     */
    protected DoubleMatrix deriveTanh(DoubleMatrix matrix) {
        // pow returns a fresh matrix, so subtracting from 1 in place is safe.
        return MatrixFunctions.pow(matrix, 2.0).rsubi(1.0);
    }
}

