/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.contrib.rnn;

import java.util.Map;
import org.jblas.DoubleMatrix;
import org.neuroph.contrib.rnn.RNN;
import org.neuroph.contrib.rnn.util.Activation;
import org.neuroph.contrib.rnn.util.MatrixInitializer;

/**
 * Long short-term memory (LSTM) recurrent layer backed by jblas {@link DoubleMatrix} values.
 *
 * <p>Every gate receives three terms through separate full weight matrices: the current input, the
 * previous output, and the memory-cell state (peephole-style connections). The input and forget
 * gates use the previous cell state. The output gate uses the freshly computed cell state.
 *
 * <p>Parameter shapes, as allocated by the initializers:
 * <ul>
 *   <li>input-facing weights: {@code inputSize x outputSize}</li>
 *   <li>recurrent and cell-state weights: {@code outputSize x outputSize}</li>
 *   <li>gate / cell biases: {@code 1 x outputSize}</li>
 *   <li>decoder weight: {@code outputSize x inputSize}</li>
 *   <li>decoder bias: {@code 1 x inputSize}</li>
 * </ul>
 */
public final class LSTM extends RNN {
    private DoubleMatrix inputGateInputWeight;
    private DoubleMatrix inputGateOutputWeight;
    private DoubleMatrix inputGateMemoryCellWeight;
    private DoubleMatrix inputGateBias;
    private DoubleMatrix forgetGateInputWeight;
    private DoubleMatrix forgetGateOutputWeight;
    private DoubleMatrix forgetGateMemoryCellWeight;
    private DoubleMatrix forgetGateBias;
    private DoubleMatrix memoryCellInputWeight;
    private DoubleMatrix memoryCellOutputWeight;
    private DoubleMatrix memoryCellBias;
    private DoubleMatrix outputGateInputWeight;
    private DoubleMatrix outputGateOutputWeight;
    private DoubleMatrix outputGateMemoryCellWeight;
    private DoubleMatrix outputGateBias;
    private DoubleMatrix outputWeight;
    private DoubleMatrix outputBias;

    /**
     * Creates an LSTM layer and draws its initial parameters from {@code matrixInitializer}.
     *
     * <p>Parameters are only initialized for the {@code Uniform} and {@code Gaussian} initializer
     * types. For any other type every weight and bias stays {@code null}. They must then be
     * supplied through the setters before {@link #activate} or {@link #decode} is called.
     *
     * @param inputSize number of input columns (also the width of the decoded output)
     * @param outputSize number of hidden / memory-cell units
     * @param matrixInitializer source of the initial random weights
     * @throws NullPointerException if {@code matrixInitializer} is {@code null}
     */
    public LSTM(int inputSize, int outputSize, MatrixInitializer matrixInitializer) {
        if (matrixInitializer == null) {
            throw new NullPointerException("matrixInitializer must not be null");
        }
        this.inputSize = inputSize;
        this.outputSize = outputSize;
        if (matrixInitializer.getType() == MatrixInitializer.Type.Uniform) {
            this.setUniformWeights(matrixInitializer);
        } else if (matrixInitializer.getType() == MatrixInitializer.Type.Gaussian) {
            this.setGaussianWeights(matrixInitializer);
        }
    }

    // ---- Parameter accessors -------------------------------------------------------------
    // Getters return the live matrices (no defensive copy) and setters store the given reference
    // as-is, so in-place modifications made by callers affect this layer directly.

    public DoubleMatrix getInputGateInputWeight() {
        return this.inputGateInputWeight;
    }

    public void setInputGateInputWeight(DoubleMatrix inputGateInputWeight) {
        this.inputGateInputWeight = inputGateInputWeight;
    }

    public DoubleMatrix getInputGateOutputWeight() {
        return this.inputGateOutputWeight;
    }

    public void setInputGateOutputWeight(DoubleMatrix inputGateOutputWeight) {
        this.inputGateOutputWeight = inputGateOutputWeight;
    }

    public DoubleMatrix getInputGateMemoryCellWeight() {
        return this.inputGateMemoryCellWeight;
    }

    public void setInputGateMemoryCellWeight(DoubleMatrix inputGateMemoryCellWeight) {
        this.inputGateMemoryCellWeight = inputGateMemoryCellWeight;
    }

    public DoubleMatrix getInputGateBias() {
        return this.inputGateBias;
    }

    public void setInputGateBias(DoubleMatrix inputGateBias) {
        this.inputGateBias = inputGateBias;
    }

    public DoubleMatrix getForgetGateInputWeight() {
        return this.forgetGateInputWeight;
    }

    public void setForgetGateInputWeight(DoubleMatrix forgetGateInputWeight) {
        this.forgetGateInputWeight = forgetGateInputWeight;
    }

    public DoubleMatrix getForgetGateOutputWeight() {
        return this.forgetGateOutputWeight;
    }

    public void setForgetGateOutputWeight(DoubleMatrix forgetGateOutputWeight) {
        this.forgetGateOutputWeight = forgetGateOutputWeight;
    }

    public DoubleMatrix getForgetGateMemoryCellWeight() {
        return this.forgetGateMemoryCellWeight;
    }

    public void setForgetGateMemoryCellWeight(DoubleMatrix forgetGateMemoryCellWeight) {
        this.forgetGateMemoryCellWeight = forgetGateMemoryCellWeight;
    }

    public DoubleMatrix getForgetGateBias() {
        return this.forgetGateBias;
    }

    public void setForgetGateBias(DoubleMatrix forgetGateBias) {
        this.forgetGateBias = forgetGateBias;
    }

    public DoubleMatrix getMemoryCellInputWeight() {
        return this.memoryCellInputWeight;
    }

    public void setMemoryCellInputWeight(DoubleMatrix memoryCellInputWeight) {
        this.memoryCellInputWeight = memoryCellInputWeight;
    }

    public DoubleMatrix getMemoryCellOutputWeight() {
        return this.memoryCellOutputWeight;
    }

    public void setMemoryCellOutputWeight(DoubleMatrix memoryCellOutputWeight) {
        this.memoryCellOutputWeight = memoryCellOutputWeight;
    }

    public DoubleMatrix getMemoryCellBias() {
        return this.memoryCellBias;
    }

    public void setMemoryCellBias(DoubleMatrix memoryCellBias) {
        this.memoryCellBias = memoryCellBias;
    }

    public DoubleMatrix getOutputGateInputWeight() {
        return this.outputGateInputWeight;
    }

    public void setOutputGateInputWeight(DoubleMatrix outputGateInputWeight) {
        this.outputGateInputWeight = outputGateInputWeight;
    }

    public DoubleMatrix getOutputGateOutputWeight() {
        return this.outputGateOutputWeight;
    }

    public void setOutputGateOutputWeight(DoubleMatrix outputGateOutputWeight) {
        this.outputGateOutputWeight = outputGateOutputWeight;
    }

    public DoubleMatrix getOutputGateMemoryCellWeight() {
        return this.outputGateMemoryCellWeight;
    }

    public void setOutputGateMemoryCellWeight(DoubleMatrix outputGateMemoryCellWeight) {
        this.outputGateMemoryCellWeight = outputGateMemoryCellWeight;
    }

    public DoubleMatrix getOutputGateBias() {
        return this.outputGateBias;
    }

    public void setOutputGateBias(DoubleMatrix outputGateBias) {
        this.outputGateBias = outputGateBias;
    }

    public DoubleMatrix getOutputWeight() {
        return this.outputWeight;
    }

    public void setOutputWeight(DoubleMatrix outputWeight) {
        this.outputWeight = outputWeight;
    }

    public DoubleMatrix getOutputBias() {
        return this.outputBias;
    }

    public void setOutputBias(DoubleMatrix outputBias) {
        this.outputBias = outputBias;
    }

    // ---- Forward pass ----------------------------------------------------------------------

    /**
     * Computes one forward step and records every intermediate value in {@code valuesInTimesteps}.
     *
     * <p><b>Reads:</b>
     * <ul>
     *   <li>{@code "input" + timestep}, always.</li>
     *   <li>{@code "output" + (timestep - 1)} and {@code "memoryCellActivation" + (timestep - 1)},
     *       when {@code timestep != 0}.</li>
     * </ul>
     * At {@code timestep == 0} the previous output and cell state are zero {@code 1 x outputSize}
     * matrices. Because of that, the input is expected to be a single row with {@code inputSize}
     * columns.
     *
     * <p><b>Writes</b> (each key suffixed with {@code timestep}):
     * <ul>
     *   <li>{@code inputActivation}, {@code forgetActivation}: input and forget gates.</li>
     *   <li>{@code memoryCellGate}: tanh candidate.</li>
     *   <li>{@code memoryCellActivation}: new cell state.</li>
     *   <li>{@code outputActivation}: output gate.</li>
     *   <li>{@code outputActivationGate}: tanh of the new cell state.</li>
     *   <li>{@code output}: output gate times tanh of the cell state.</li>
     * </ul>
     *
     * @param timestep index of the step to compute
     * @param valuesInTimesteps per-timestep value store; read from and written to
     * @throws IllegalArgumentException if a required entry is missing from {@code valuesInTimesteps}
     */
    @Override
    public void activate(int timestep, Map<String, DoubleMatrix> valuesInTimesteps) {
        DoubleMatrix input = requireValue(valuesInTimesteps, "input" + timestep);
        DoubleMatrix previousOutputActivation;
        DoubleMatrix previousMemoryCellActivation;
        if (timestep == 0) {
            previousOutputActivation = new DoubleMatrix(1, this.outputSize);
            previousMemoryCellActivation = previousOutputActivation.dup();
        } else {
            previousOutputActivation = requireValue(valuesInTimesteps, "output" + (timestep - 1));
            previousMemoryCellActivation =
                    requireValue(valuesInTimesteps, "memoryCellActivation" + (timestep - 1));
        }

        // Input and forget gates peek at the *previous* cell state.
        DoubleMatrix inputActivation = Activation.logistic(
                inputPlusRecurrent(input, this.inputGateInputWeight,
                                previousOutputActivation, this.inputGateOutputWeight)
                        .add(previousMemoryCellActivation.mmul(this.inputGateMemoryCellWeight))
                        .add(this.inputGateBias));
        DoubleMatrix forgetActivation = Activation.logistic(
                inputPlusRecurrent(input, this.forgetGateInputWeight,
                                previousOutputActivation, this.forgetGateOutputWeight)
                        .add(previousMemoryCellActivation.mmul(this.forgetGateMemoryCellWeight))
                        .add(this.forgetGateBias));
        // Candidate cell content has no cell-state term.
        DoubleMatrix memoryCellGate = Activation.tanh(
                inputPlusRecurrent(input, this.memoryCellInputWeight,
                                previousOutputActivation, this.memoryCellOutputWeight)
                        .add(this.memoryCellBias));
        DoubleMatrix memoryCellActivation = forgetActivation.mul(previousMemoryCellActivation)
                .add(inputActivation.mul(memoryCellGate));
        // The output gate peeks at the *new* cell state, so it must be computed after it.
        DoubleMatrix outputActivation = Activation.logistic(
                inputPlusRecurrent(input, this.outputGateInputWeight,
                                previousOutputActivation, this.outputGateOutputWeight)
                        .add(memoryCellActivation.mmul(this.outputGateMemoryCellWeight))
                        .add(this.outputGateBias));
        DoubleMatrix outputActivationGate = Activation.tanh(memoryCellActivation);
        DoubleMatrix output = outputActivation.mul(outputActivationGate);

        valuesInTimesteps.put("inputActivation" + timestep, inputActivation);
        valuesInTimesteps.put("forgetActivation" + timestep, forgetActivation);
        valuesInTimesteps.put("memoryCellGate" + timestep, memoryCellGate);
        valuesInTimesteps.put("memoryCellActivation" + timestep, memoryCellActivation);
        valuesInTimesteps.put("outputActivation" + timestep, outputActivation);
        valuesInTimesteps.put("outputActivationGate" + timestep, outputActivationGate);
        valuesInTimesteps.put("output" + timestep, output);
    }

    /**
     * Projects an output row back to {@code inputSize} columns and applies {@code Activation.softmax}.
     *
     * @param matrix an output of {@link #activate} ({@code outputSize} columns)
     * @return {@code softmax(matrix * outputWeight + outputBias)}
     */
    @Override
    public DoubleMatrix decode(DoubleMatrix matrix) {
        return Activation.softmax(matrix.mmul(this.outputWeight).add(this.outputBias));
    }

    // ---- Initialization --------------------------------------------------------------------

    /** Initializes all weights by sampling {@code matrixInitializer.uniform}; biases start at zero. */
    @Override
    protected void setUniformWeights(MatrixInitializer matrixInitializer) {
        this.initializeParameters(matrixInitializer, false);
    }

    /** Initializes all weights by sampling {@code matrixInitializer.gaussian}; biases start at zero. */
    @Override
    protected void setGaussianWeights(MatrixInitializer matrixInitializer) {
        this.initializeParameters(matrixInitializer, true);
    }

    /**
     * Allocates every parameter. Weights are drawn from the chosen distribution; biases are
     * allocated as {@code new DoubleMatrix(1, n)}, which jblas fills with zeros.
     *
     * <p>The draw order is fixed (input gate, forget gate, memory cell, output gate, decoder) so
     * that a seeded initializer always yields the same parameters.
     *
     * @param matrixInitializer source of random matrices
     * @param gaussian {@code true} to call {@code gaussian(rows, cols)}, {@code false} for
     *     {@code uniform(rows, cols)}
     */
    private void initializeParameters(MatrixInitializer matrixInitializer, boolean gaussian) {
        this.inputGateInputWeight = sample(matrixInitializer, gaussian, this.inputSize, this.outputSize);
        this.inputGateOutputWeight = sample(matrixInitializer, gaussian, this.outputSize, this.outputSize);
        this.inputGateMemoryCellWeight = sample(matrixInitializer, gaussian, this.outputSize, this.outputSize);
        this.inputGateBias = new DoubleMatrix(1, this.outputSize);
        this.forgetGateInputWeight = sample(matrixInitializer, gaussian, this.inputSize, this.outputSize);
        this.forgetGateOutputWeight = sample(matrixInitializer, gaussian, this.outputSize, this.outputSize);
        this.forgetGateMemoryCellWeight = sample(matrixInitializer, gaussian, this.outputSize, this.outputSize);
        this.forgetGateBias = new DoubleMatrix(1, this.outputSize);
        this.memoryCellInputWeight = sample(matrixInitializer, gaussian, this.inputSize, this.outputSize);
        this.memoryCellOutputWeight = sample(matrixInitializer, gaussian, this.outputSize, this.outputSize);
        this.memoryCellBias = new DoubleMatrix(1, this.outputSize);
        this.outputGateInputWeight = sample(matrixInitializer, gaussian, this.inputSize, this.outputSize);
        this.outputGateOutputWeight = sample(matrixInitializer, gaussian, this.outputSize, this.outputSize);
        this.outputGateMemoryCellWeight = sample(matrixInitializer, gaussian, this.outputSize, this.outputSize);
        this.outputGateBias = new DoubleMatrix(1, this.outputSize);
        this.outputWeight = sample(matrixInitializer, gaussian, this.outputSize, this.inputSize);
        this.outputBias = new DoubleMatrix(1, this.inputSize);
    }

    /** Draws one {@code rows x columns} matrix from the uniform or Gaussian initializer. */
    private static DoubleMatrix sample(MatrixInitializer matrixInitializer, boolean gaussian,
                                       int rows, int columns) {
        return gaussian
                ? matrixInitializer.gaussian(rows, columns)
                : matrixInitializer.uniform(rows, columns);
    }

    // ---- Helpers ---------------------------------------------------------------------------

    /**
     * Returns {@code input * inputWeight + previousOutput * recurrentWeight}, the part shared by
     * every gate pre-activation. It is evaluated in the same order as a hand-written chain, so
     * the floating-point results are identical.
     */
    private static DoubleMatrix inputPlusRecurrent(DoubleMatrix input, DoubleMatrix inputWeight,
                                                   DoubleMatrix previousOutput,
                                                   DoubleMatrix recurrentWeight) {
        return input.mmul(inputWeight).add(previousOutput.mmul(recurrentWeight));
    }

    /**
     * Looks up {@code key} and fails fast with a descriptive message instead of a later,
     * unexplained NullPointerException inside the matrix arithmetic.
     */
    private static DoubleMatrix requireValue(Map<String, DoubleMatrix> values, String key) {
        DoubleMatrix value = values.get(key);
        if (value == null) {
            throw new IllegalArgumentException(
                    "valuesInTimesteps has no entry for '" + key
                            + "' (was the previous timestep activated?)");
        }
        return value;
    }
}

