/*
 * Decompiled with CFR 0.152.
 */
package com.github.stanfordfuturedata.momentsketch;

import com.github.stanfordfuturedata.momentsketch.optimizer.FunctionWithHessian;
import org.apache.commons.math3.util.FastMath;

/**
 * Discretized maximum-entropy loss expressed in Chebyshev-polynomial moments.
 *
 * <p>The interval [-1, 1] is sampled at {@code nGrid} evenly spaced points {@code x_i}
 * (endpoints included). For a coefficient vector {@code lambd} of length
 * {@code k = d_mus.length}, each grid point gets the weight
 * {@code w_i = exp(sum_j lambd[j] * T_j(x_i))}, where {@code T_j} is the j-th Chebyshev
 * polynomial of the first kind. The class then exposes:
 * <ul>
 *   <li>value: {@code sum_i w_i - sum_j lambd[j] * d_mus[j]}</li>
 *   <li>gradient: {@code grad[j] = mu_j - d_mus[j]}, with {@code mu_j = sum_i T_j(x_i) * w_i}</li>
 *   <li>Hessian: {@code hess[i][j] = (mu_{i+j} + mu_{|i-j|}) / 2}, which follows from the
 *       product identity {@code T_i * T_j = (T_{i+j} + T_{|i-j|}) / 2}</li>
 * </ul>
 *
 * <p>Instances hold mutable scratch state that is overwritten by every {@link #computeAll}
 * call. The arrays returned by the getters are the internal buffers, not copies.
 */
public class DMaxentLoss
implements FunctionWithHessian {
    /** Number of moment constraints, i.e. {@code d_mus.length}. */
    protected int dim;
    /** Number of grid points on [-1, 1]. */
    protected int nGrid;
    /** Target Chebyshev moments the weighted grid should reproduce. */
    protected double[] d_mus;
    /** Grid points, evenly spaced from -1 to 1 inclusive. */
    protected double[] xs;
    /** {@code cpVals[j][i] = T_j(xs[i])} for {@code j < 2 * dim}. */
    protected double[][] cpVals;
    /** Coefficients from the last {@link #setLambd} call. Aliased, not copied. */
    protected double[] lambd;
    /** Per-grid-point weights from the last {@link #computeAll}. */
    protected double[] weights;
    /** {@code mus[j] = sum_i cpVals[j][i] * weights[i]} for {@code j < 2 * dim}. */
    protected double[] mus;
    /** Gradient buffer of length {@code dim}. */
    protected double[] grad;
    /** Hessian buffer of shape {@code dim x dim}. */
    protected double[][] hess;

    /**
     * Builds the grid and precomputes the Chebyshev polynomial values on it.
     *
     * @param d_mus target Chebyshev moments; must contain at least one entry. The array is
     *     stored by reference.
     * @param nGrid number of evenly spaced grid points on [-1, 1]; must be at least 2 so the
     *     grid spacing {@code 2 / (nGrid - 1)} is finite
     * @throws NullPointerException if {@code d_mus} is null
     * @throws IllegalArgumentException if {@code d_mus} is empty or {@code nGrid < 2}
     */
    public DMaxentLoss(double[] d_mus, int nGrid) {
        if (d_mus.length == 0) {
            throw new IllegalArgumentException("d_mus must contain at least one moment");
        }
        if (nGrid < 2) {
            // nGrid == 1 would divide by zero below and silently turn every result into NaN.
            throw new IllegalArgumentException("nGrid must be at least 2, got " + nGrid);
        }
        this.dim = d_mus.length;
        this.nGrid = nGrid;
        this.d_mus = d_mus;

        this.xs = new double[nGrid];
        for (int i = 0; i < nGrid; ++i) {
            this.xs[i] = i * 2.0 / (nGrid - 1) - 1.0;
        }

        // The Hessian needs moments up to index 2*dim - 2, so precompute 2*dim polynomials.
        int nPolys = 2 * this.dim;
        this.cpVals = new double[nPolys][nGrid];
        for (int i = 0; i < nGrid; ++i) {
            this.cpVals[0][i] = 1.0;
            this.cpVals[1][i] = this.xs[i];
        }
        // Chebyshev three-term recurrence: T_j(x) = 2x * T_{j-1}(x) - T_{j-2}(x).
        for (int j = 2; j < nPolys; ++j) {
            double[] cur = this.cpVals[j];
            double[] prev = this.cpVals[j - 1];
            double[] prevPrev = this.cpVals[j - 2];
            for (int i = 0; i < nGrid; ++i) {
                cur[i] = 2.0 * this.xs[i] * prev[i] - prevPrev[i];
            }
        }

        this.weights = new double[nGrid];
        this.mus = new double[nPolys];
        this.grad = new double[this.dim];
        this.hess = new double[this.dim][this.dim];
    }

    /**
     * Sets the current coefficient vector. The array is stored by reference, so later
     * mutations by the caller are visible to {@link #getValue()}.
     *
     * @param newLambd coefficients; must have at least {@code dim} entries
     */
    public void setLambd(double[] newLambd) {
        this.lambd = newLambd;
    }

    /**
     * Delegates to {@link #computeAll}, so the gradient and Hessian are refreshed as well.
     *
     * @param point coefficient vector
     * @param tol unused
     */
    @Override
    public void computeOnlyValue(double[] point, double tol) {
        this.computeAll(point, tol);
    }

    /**
     * Evaluates the weights, moments, gradient and Hessian at {@code point}.
     *
     * @param point coefficient vector of length at least {@code dim}. It is stored by
     *     reference through {@link #setLambd}.
     * @param tol unused
     */
    @Override
    public void computeAll(double[] point, double tol) {
        this.setLambd(point);

        // Weight per grid point: exp of the Chebyshev expansion with the current coefficients.
        for (int i = 0; i < this.nGrid; ++i) {
            double exponent = 0.0;
            for (int j = 0; j < this.dim; ++j) {
                exponent += this.lambd[j] * this.cpVals[j][i];
            }
            this.weights[i] = FastMath.exp(exponent);
        }

        // Weighted Chebyshev moments up to order 2*dim - 1.
        for (int j = 0; j < 2 * this.dim; ++j) {
            double[] poly = this.cpVals[j];
            double moment = 0.0;
            for (int i = 0; i < this.nGrid; ++i) {
                moment += poly[i] * this.weights[i];
            }
            this.mus[j] = moment;
        }

        for (int j = 0; j < this.dim; ++j) {
            this.grad[j] = this.mus[j] - this.d_mus[j];
        }

        // sum_x T_i T_j w = (mu_{i+j} + mu_{|i-j|}) / 2 by the Chebyshev product identity.
        for (int i = 0; i < this.dim; ++i) {
            for (int j = 0; j < this.dim; ++j) {
                this.hess[i][j] = 0.5 * (this.mus[i + j] + this.mus[Math.abs(i - j)]);
            }
        }
    }

    /**
     * Returns the internal weight buffer filled by the last {@link #computeAll}.
     *
     * @return per-grid-point weights (not a copy)
     */
    public double[] getWeights() {
        return this.weights;
    }

    @Override
    public int dim() {
        return this.dim;
    }

    /**
     * Returns {@code mus[0] - sum_j lambd[j] * d_mus[j]}. Since {@code T_0 = 1},
     * {@code mus[0]} is the total weight. Call {@link #computeAll} first.
     *
     * @return loss value for the current coefficients
     */
    @Override
    public double getValue() {
        double dot = 0.0;
        for (int j = 0; j < this.d_mus.length; ++j) {
            dot += this.lambd[j] * this.d_mus[j];
        }
        return this.mus[0] - dot;
    }

    /** @return the gradient buffer filled by the last {@link #computeAll} (not a copy) */
    @Override
    public double[] getGradient() {
        return this.grad;
    }

    /** @return the Hessian buffer filled by the last {@link #computeAll} (not a copy) */
    @Override
    public double[][] getHessian() {
        return this.hess;
    }
}

