/*
 * Decompiled with CFR 0.152.
 */
package org.neuroph.contrib.rnn;

import java.util.Map;
import org.jblas.DoubleMatrix;
import org.neuroph.contrib.rnn.RNN;
import org.neuroph.contrib.rnn.util.Activation;
import org.neuroph.contrib.rnn.util.MatrixInitializer;

public final class GRU
extends RNN {
    private DoubleMatrix resetGateInputWeight;
    private DoubleMatrix resetGateOutputWeight;
    private DoubleMatrix resetGateBias;
    private DoubleMatrix updateGateInputWeight;
    private DoubleMatrix updateGateOutputWeight;
    private DoubleMatrix updateGateBias;
    private DoubleMatrix memoryCellInputWeight;
    private DoubleMatrix memoryCellOutputWeight;
    private DoubleMatrix memoryCellBias;
    private DoubleMatrix outputWeight;
    private DoubleMatrix outputBias;

    public GRU(int inputSize, int outputSize, MatrixInitializer matrixInitializer) {
        this.inputSize = inputSize;
        this.outputSize = outputSize;
        if (matrixInitializer.getType() == MatrixInitializer.Type.Uniform) {
            this.setUniformWeights(matrixInitializer);
        } else if (matrixInitializer.getType() == MatrixInitializer.Type.Gaussian) {
            this.setGaussianWeights(matrixInitializer);
        }
    }

    public DoubleMatrix getResetGateInputWeight() {
        return this.resetGateInputWeight;
    }

    public void setResetGateInputWeight(DoubleMatrix resetGateInputWeight) {
        this.resetGateInputWeight = resetGateInputWeight;
    }

    public DoubleMatrix getResetGateOutputWeight() {
        return this.resetGateOutputWeight;
    }

    public void setResetGateOutputWeight(DoubleMatrix resetGateOutputWeight) {
        this.resetGateOutputWeight = resetGateOutputWeight;
    }

    public DoubleMatrix getResetGateBias() {
        return this.resetGateBias;
    }

    public void setResetGateBias(DoubleMatrix resetGateBias) {
        this.resetGateBias = resetGateBias;
    }

    public DoubleMatrix getUpdateGateInputWeight() {
        return this.updateGateInputWeight;
    }

    public void setUpdateGateInputWeight(DoubleMatrix updateGateInputWeight) {
        this.updateGateInputWeight = updateGateInputWeight;
    }

    public DoubleMatrix getUpdateGateOutputWeight() {
        return this.updateGateOutputWeight;
    }

    public void setUpdateGateOutputWeight(DoubleMatrix updateGateOutputWeight) {
        this.updateGateOutputWeight = updateGateOutputWeight;
    }

    public DoubleMatrix getUpdateGateBias() {
        return this.updateGateBias;
    }

    public void setUpdateGateBias(DoubleMatrix updateGateBias) {
        this.updateGateBias = updateGateBias;
    }

    public DoubleMatrix getMemoryCellInputWeight() {
        return this.memoryCellInputWeight;
    }

    public void setMemoryCellInputWeight(DoubleMatrix memoryCellInputWeight) {
        this.memoryCellInputWeight = memoryCellInputWeight;
    }

    public DoubleMatrix getMemoryCellOutputWeight() {
        return this.memoryCellOutputWeight;
    }

    public void setMemoryCellOutputWeight(DoubleMatrix memoryCellOutputWeight) {
        this.memoryCellOutputWeight = memoryCellOutputWeight;
    }

    public DoubleMatrix getMemoryCellBias() {
        return this.memoryCellBias;
    }

    public void setMemoryCellBias(DoubleMatrix memoryCellBias) {
        this.memoryCellBias = memoryCellBias;
    }

    public DoubleMatrix getOutputWeight() {
        return this.outputWeight;
    }

    public void setOutputWeight(DoubleMatrix outputWeight) {
        this.outputWeight = outputWeight;
    }

    public DoubleMatrix getOutputBias() {
        return this.outputBias;
    }

    public void setOutputBias(DoubleMatrix outputBias) {
        this.outputBias = outputBias;
    }

    @Override
    public void activate(int timestep, Map<String, DoubleMatrix> valuesInTimesteps) {
        DoubleMatrix input = valuesInTimesteps.get("input" + timestep);
        DoubleMatrix previousOutput = null;
        previousOutput = timestep == 0 ? new DoubleMatrix(1, this.outputSize) : valuesInTimesteps.get("output" + (timestep - 1));
        DoubleMatrix resetActivation = Activation.logistic(input.mmul(this.resetGateInputWeight).add(previousOutput.mmul(this.resetGateOutputWeight)).add(this.resetGateBias));
        DoubleMatrix updateActivation = Activation.logistic(input.mmul(this.updateGateInputWeight).add(previousOutput.mmul(this.updateGateOutputWeight)).add(this.updateGateBias));
        DoubleMatrix memoryCellGate = Activation.tanh(input.mmul(this.memoryCellInputWeight).add(resetActivation.mul(previousOutput).mmul(this.memoryCellOutputWeight)).add(this.memoryCellBias));
        DoubleMatrix output = DoubleMatrix.ones((int)1, (int)updateActivation.columns).sub(updateActivation).mul(previousOutput).add(updateActivation.mul(memoryCellGate));
        valuesInTimesteps.put("resetActivation" + timestep, resetActivation);
        valuesInTimesteps.put("updateActivation" + timestep, updateActivation);
        valuesInTimesteps.put("memoryCellGate" + timestep, memoryCellGate);
        valuesInTimesteps.put("output" + timestep, output);
    }

    @Override
    public DoubleMatrix decode(DoubleMatrix matrix) {
        return Activation.softmax(matrix.mmul(this.outputWeight).add(this.outputBias));
    }

    @Override
    protected void setUniformWeights(MatrixInitializer matrixInitializer) {
        this.resetGateInputWeight = matrixInitializer.uniform(this.inputSize, this.outputSize);
        this.resetGateOutputWeight = matrixInitializer.uniform(this.outputSize, this.outputSize);
        this.resetGateBias = new DoubleMatrix(1, this.outputSize);
        this.updateGateInputWeight = matrixInitializer.uniform(this.inputSize, this.outputSize);
        this.updateGateOutputWeight = matrixInitializer.uniform(this.outputSize, this.outputSize);
        this.updateGateBias = new DoubleMatrix(1, this.outputSize);
        this.memoryCellInputWeight = matrixInitializer.uniform(this.inputSize, this.outputSize);
        this.memoryCellOutputWeight = matrixInitializer.uniform(this.outputSize, this.outputSize);
        this.memoryCellBias = new DoubleMatrix(1, this.outputSize);
        this.outputWeight = matrixInitializer.uniform(this.outputSize, this.inputSize);
        this.outputBias = new DoubleMatrix(1, this.inputSize);
    }

    @Override
    protected void setGaussianWeights(MatrixInitializer matrixInitializer) {
        this.resetGateInputWeight = matrixInitializer.gaussian(this.inputSize, this.outputSize);
        this.resetGateOutputWeight = matrixInitializer.gaussian(this.outputSize, this.outputSize);
        this.resetGateBias = new DoubleMatrix(1, this.outputSize);
        this.updateGateInputWeight = matrixInitializer.gaussian(this.inputSize, this.outputSize);
        this.updateGateOutputWeight = matrixInitializer.gaussian(this.outputSize, this.outputSize);
        this.updateGateBias = new DoubleMatrix(1, this.outputSize);
        this.memoryCellInputWeight = matrixInitializer.gaussian(this.inputSize, this.outputSize);
        this.memoryCellOutputWeight = matrixInitializer.gaussian(this.outputSize, this.outputSize);
        this.memoryCellBias = new DoubleMatrix(1, this.outputSize);
        this.outputWeight = matrixInitializer.gaussian(this.outputSize, this.inputSize);
        this.outputBias = new DoubleMatrix(1, this.inputSize);
    }
}

