/*
 * Decompiled with CFR 0.152.
 */
package smile.regression;

import java.io.Serializable;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.Math;
import smile.math.matrix.BiconjugateGradient;
import smile.math.matrix.Matrix;
import smile.math.matrix.NaiveMatrix;
import smile.math.matrix.Preconditioner;
import smile.math.matrix.RowMajorMatrix;
import smile.math.matrix.SparseMatrix;
import smile.math.special.Beta;
import smile.regression.Regression;
import smile.regression.RegressionTrainer;

public class LASSO
implements Regression<double[]>,
Serializable {
    private static final long serialVersionUID = 1L;
    private static final Logger logger = LoggerFactory.getLogger(LASSO.class);
    private int p;
    private double lambda;
    private double b;
    private double[] w;
    private double ym;
    private double[] center;
    private double[] scale;
    private double[] residuals;
    private double RSS;
    private double error;
    private int df;
    private double RSquared;
    private double adjustedRSquared;
    private double F;
    private double pvalue;

    public LASSO(double[][] x, double[] y, double lambda) {
        this(x, y, lambda, 1.0E-4, 1000);
    }

    public LASSO(double[][] x, double[] y, double lambda, double tol, int maxIter) {
        int i;
        int j;
        int n = x.length;
        int p = x[0].length;
        this.center = Math.colMean((double[][])x);
        RowMajorMatrix X = new RowMajorMatrix(n, p);
        for (int i2 = 0; i2 < n; ++i2) {
            for (int j2 = 0; j2 < p; ++j2) {
                X.set(i2, j2, x[i2][j2] - this.center[j2]);
            }
        }
        this.scale = new double[p];
        for (j = 0; j < p; ++j) {
            for (i = 0; i < n; ++i) {
                int n2 = j;
                this.scale[n2] = this.scale[n2] + Math.sqr((double)X.get(i, j));
            }
            this.scale[j] = Math.sqrt((double)(this.scale[j] / (double)n));
        }
        for (j = 0; j < p; ++j) {
            if (Math.isZero((double)this.scale[j])) continue;
            for (i = 0; i < n; ++i) {
                X.div(i, j, this.scale[j]);
            }
        }
        this.train((Matrix)X, y, lambda, tol, maxIter);
        for (j = 0; j < p; ++j) {
            if (Math.isZero((double)this.scale[j])) continue;
            int n3 = j;
            this.w[n3] = this.w[n3] / this.scale[j];
        }
        this.b = this.ym - Math.dot((double[])this.w, (double[])this.center);
        this.fitness((Matrix)new NaiveMatrix(x), y);
    }

    public LASSO(Matrix x, double[] y, double lambda) {
        this(x, y, lambda, 1.0E-4, 1000);
    }

    public LASSO(Matrix x, double[] y, double lambda, double tol, int maxIter) {
        this.train(x, y, lambda, tol, maxIter);
        this.fitness(x, y);
    }

    private void train(Matrix x, double[] y, double lambda, double tol, int maxIter) {
        int ntiter;
        if (x.nrows() != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.nrows(), y.length));
        }
        if (lambda <= 0.0) {
            throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
        }
        if (tol <= 0.0) {
            throw new IllegalArgumentException("Invalid tolerance: " + tol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        int MU = 2;
        double ALPHA = 0.01;
        double BETA = 0.5;
        int MAX_LS_ITER = 100;
        int pcgmaxi = 5000;
        double eta = 0.001;
        int pitr = 0;
        int n = x.nrows();
        this.p = x.ncols();
        double[] Y = new double[n];
        this.ym = Math.mean((double[])y);
        for (int i = 0; i < n; ++i) {
            Y[i] = y[i] - this.ym;
        }
        double t = Math.min((double)Math.max((double)1.0, (double)(1.0 / lambda)), (double)((double)(2 * this.p) / 0.001));
        double pobj = 0.0;
        double dobj = Double.NEGATIVE_INFINITY;
        double s = Double.POSITIVE_INFINITY;
        this.w = new double[this.p];
        this.b = this.ym;
        double[] u = new double[this.p];
        double[] z = new double[n];
        double[][] f = new double[2][this.p];
        Arrays.fill(u, 1.0);
        for (int i = 0; i < this.p; ++i) {
            f[0][i] = this.w[i] - u[i];
            f[1][i] = -this.w[i] - u[i];
        }
        double[] neww = new double[this.p];
        double[] newu = new double[this.p];
        double[] newz = new double[n];
        double[][] newf = new double[2][this.p];
        double[] dx = new double[this.p];
        double[] du = new double[this.p];
        double[] dxu = new double[2 * this.p];
        double[] grad = new double[2 * this.p];
        double[] diagxtx = new double[this.p];
        Arrays.fill(diagxtx, 2.0);
        double[] nu = new double[n];
        double[] xnu = new double[this.p];
        double[] q1 = new double[this.p];
        double[] q2 = new double[this.p];
        double[] d1 = new double[this.p];
        double[] d2 = new double[this.p];
        double[][] gradphi = new double[2][this.p];
        double[] prb = new double[this.p];
        double[] prs = new double[this.p];
        PCGMatrix pcg = new PCGMatrix(x, d1, d2, prb, prs);
        for (ntiter = 0; ntiter <= maxIter; ++ntiter) {
            int lsiter;
            double error;
            double gap;
            int i;
            x.ax(this.w, z);
            for (int i2 = 0; i2 < n; ++i2) {
                int n2 = i2;
                z[n2] = z[n2] - Y[i2];
                nu[i2] = 2.0 * z[i2];
            }
            x.atx(nu, xnu);
            double maxXnu = Math.normInf((double[])xnu);
            if (maxXnu > lambda) {
                double lnu = lambda / maxXnu;
                i = 0;
                while (i < n) {
                    int n3 = i++;
                    nu[n3] = nu[n3] * lnu;
                }
            }
            pobj = Math.dot((double[])z, (double[])z) + lambda * Math.norm1((double[])this.w);
            dobj = Math.max((double)(-0.25 * Math.dot((double[])nu, (double[])nu) - Math.dot((double[])nu, (double[])Y)), (double)dobj);
            if (ntiter % 10 == 0) {
                logger.info(String.format("LASSO: primal and dual objective function value after %3d iterations: %.5g\t%.5g%n", ntiter, pobj, dobj));
            }
            if ((gap = pobj - dobj) / dobj < tol) {
                logger.info(String.format("LASSO: primal and dual objective function value after %3d iterations: %.5g\t%.5g%n", ntiter, pobj, dobj));
                break;
            }
            if (s >= 0.5) {
                t = Math.max((double)Math.min((double)((double)(2 * this.p * 2) / gap), (double)(2.0 * t)), (double)t);
            }
            for (i = 0; i < this.p; ++i) {
                double q1i = 1.0 / (u[i] + this.w[i]);
                double q2i = 1.0 / (u[i] - this.w[i]);
                q1[i] = q1i;
                q2[i] = q2i;
                d1[i] = (q1i * q1i + q2i * q2i) / t;
                d2[i] = (q1i * q1i - q2i * q2i) / t;
            }
            x.atx(z, gradphi[0]);
            for (i = 0; i < this.p; ++i) {
                gradphi[0][i] = 2.0 * gradphi[0][i] - (q1[i] - q2[i]) / t;
                gradphi[1][i] = lambda - (q1[i] + q2[i]) / t;
                grad[i] = -gradphi[0][i];
                grad[i + this.p] = -gradphi[1][i];
            }
            for (i = 0; i < this.p; ++i) {
                prb[i] = diagxtx[i] + d1[i];
                prs[i] = prb[i] * d1[i] - d2[i] * d2[i];
            }
            double normg = Math.norm((double[])grad);
            double pcgtol = Math.min((double)0.1, (double)(0.001 * gap / Math.min((double)1.0, (double)normg)));
            if (ntiter != 0 && pitr == 0) {
                pcgtol *= 0.1;
            }
            if ((error = BiconjugateGradient.solve((Matrix)pcg, (Preconditioner)pcg, (double[])grad, (double[])dxu, (double)pcgtol, (int)1, (int)5000)) > pcgtol) {
                pitr = 5000;
            }
            for (int i3 = 0; i3 < this.p; ++i3) {
                dx[i3] = dxu[i3];
                du[i3] = dxu[i3 + this.p];
            }
            double phi = Math.dot((double[])z, (double[])z) + lambda * Math.sum((double[])u) - this.sumlogneg(f) / t;
            s = 1.0;
            double gdx = Math.dot((double[])grad, (double[])dxu);
            for (lsiter = 0; lsiter < 100; ++lsiter) {
                int i4;
                for (i4 = 0; i4 < this.p; ++i4) {
                    neww[i4] = this.w[i4] + s * dx[i4];
                    newu[i4] = u[i4] + s * du[i4];
                    newf[0][i4] = neww[i4] - newu[i4];
                    newf[1][i4] = -neww[i4] - newu[i4];
                }
                if (Math.max((double[][])newf) < 0.0) {
                    x.ax(neww, newz);
                    for (i4 = 0; i4 < n; ++i4) {
                        int n4 = i4;
                        newz[n4] = newz[n4] - Y[i4];
                    }
                    double newphi = Math.dot((double[])newz, (double[])newz) + lambda * Math.sum((double[])newu) - this.sumlogneg(newf) / t;
                    if (newphi - phi <= 0.01 * s * gdx) break;
                }
                s = 0.5 * s;
            }
            if (lsiter == 100) {
                logger.error("LASSO: Too many iterations of line search.");
                break;
            }
            System.arraycopy(neww, 0, this.w, 0, this.p);
            System.arraycopy(newu, 0, u, 0, this.p);
            System.arraycopy(newf[0], 0, f[0], 0, this.p);
            System.arraycopy(newf[1], 0, f[1], 0, this.p);
        }
        if (ntiter == maxIter) {
            logger.error("LASSO: Too many iterations.");
        }
    }

    private void fitness(Matrix x, double[] y) {
        int n = y.length;
        double[] yhat = new double[n];
        x.ax(this.w, yhat);
        double TSS = 0.0;
        this.RSS = 0.0;
        double ybar = Math.mean((double[])y);
        this.residuals = new double[n];
        for (int i = 0; i < n; ++i) {
            double r;
            this.residuals[i] = r = y[i] - yhat[i] - this.b;
            this.RSS += Math.sqr((double)r);
            TSS += Math.sqr((double)(y[i] - ybar));
        }
        this.error = Math.sqrt((double)(this.RSS / (double)(n - this.p - 1)));
        this.df = n - this.p - 1;
        this.RSquared = 1.0 - this.RSS / TSS;
        this.adjustedRSquared = 1.0 - (1.0 - this.RSquared) * (double)(n - 1) / (double)(n - this.p - 1);
        this.F = (TSS - this.RSS) * (double)(n - this.p - 1) / (this.RSS * (double)this.p);
        int df1 = this.p;
        int df2 = n - this.p - 1;
        this.pvalue = Beta.regularizedIncompleteBetaFunction((double)(0.5 * (double)df2), (double)(0.5 * (double)df1), (double)((double)df2 / ((double)df2 + (double)df1 * this.F)));
    }

    private double sumlogneg(double[][] f) {
        int m = f.length;
        int n = f[0].length;
        double sum = 0.0;
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                sum += Math.log((double)(-f[i][j]));
            }
        }
        return sum;
    }

    public double[] coefficients() {
        return this.w;
    }

    public double intercept() {
        return this.b;
    }

    public double shrinkage() {
        return this.lambda;
    }

    @Override
    public double predict(double[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        return Math.dot((double[])x, (double[])this.w) + this.b;
    }

    public double[] residuals() {
        return this.residuals;
    }

    public double RSS() {
        return this.RSS;
    }

    public double error() {
        return this.error;
    }

    public int df() {
        return this.df;
    }

    public double RSquared() {
        return this.RSquared;
    }

    public double adjustedRSquared() {
        return this.adjustedRSquared;
    }

    public double ftest() {
        return this.F;
    }

    public double pvalue() {
        return this.pvalue;
    }

    public String toString() {
        StringBuilder builder = new StringBuilder();
        builder.append("LASSO:\n");
        double[] r = (double[])this.residuals.clone();
        builder.append("\nResiduals:\n");
        builder.append("\t       Min\t        1Q\t    Median\t        3Q\t       Max\n");
        builder.append(String.format("\t%10.4f\t%10.4f\t%10.4f\t%10.4f\t%10.4f%n", Math.min((double[])r), Math.q1((double[])r), Math.median((double[])r), Math.q3((double[])r), Math.max((double[])r)));
        builder.append("\nCoefficients:\n");
        builder.append("            Estimate\n");
        builder.append(String.format("Intercept%11.4f%n", this.b));
        for (int i = 0; i < this.p; ++i) {
            builder.append(String.format("Var %d\t %11.4f%n", i + 1, this.w[i]));
        }
        builder.append(String.format("\nResidual standard error: %.4f on %d degrees of freedom%n", this.error, this.df));
        builder.append(String.format("Multiple R-squared: %.4f,    Adjusted R-squared: %.4f%n", this.RSquared, this.adjustedRSquared));
        builder.append(String.format("F-statistic: %.4f on %d and %d DF,  p-value: %.4g%n", this.F, this.p, this.df, this.pvalue));
        return builder.toString();
    }

    class PCGMatrix
    implements Matrix,
    Preconditioner {
        Matrix A;
        Matrix AtA;
        double[] d1;
        double[] d2;
        double[] prb;
        double[] prs;
        double[] ax;
        double[] atax;

        PCGMatrix(Matrix A, double[] d1, double[] d2, double[] prb, double[] prs) {
            this.A = A;
            this.d1 = d1;
            this.d2 = d2;
            this.prb = prb;
            this.prs = prs;
            int n = A.nrows();
            this.ax = new double[n];
            this.atax = new double[LASSO.this.p];
            if (A.ncols() < 10000 && !(A instanceof SparseMatrix)) {
                this.AtA = A.ata();
            }
        }

        public int nrows() {
            return 2 * LASSO.this.p;
        }

        public int ncols() {
            return 2 * LASSO.this.p;
        }

        public double[] ax(double[] x, double[] y) {
            if (this.AtA != null) {
                this.AtA.ax(x, this.atax);
            } else {
                this.A.ax(x, this.ax);
                this.A.atx(this.ax, this.atax);
            }
            for (int i = 0; i < LASSO.this.p; ++i) {
                y[i] = 2.0 * this.atax[i] + this.d1[i] * x[i] + this.d2[i] * x[i + LASSO.this.p];
                y[i + ((LASSO)LASSO.this).p] = this.d2[i] * x[i] + this.d1[i] * x[i + LASSO.this.p];
            }
            return y;
        }

        public double[] atx(double[] x, double[] y) {
            return this.ax(x, y);
        }

        public void asolve(double[] b, double[] x) {
            for (int i = 0; i < LASSO.this.p; ++i) {
                x[i] = (this.d1[i] * b[i] - this.d2[i] * b[i + LASSO.this.p]) / this.prs[i];
                x[i + ((LASSO)LASSO.this).p] = (-this.d2[i] * b[i] + this.prb[i] * b[i + LASSO.this.p]) / this.prs[i];
            }
        }

        public Matrix transpose() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        public Matrix aat() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        public Matrix ata() {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        public double get(int i, int j) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        public double apply(int i, int j) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        public double[] axpy(double[] x, double[] y) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        public double[] axpy(double[] x, double[] y, double b) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        public double[] atxpy(double[] x, double[] y) {
            throw new UnsupportedOperationException("Not supported yet.");
        }

        public double[] atxpy(double[] x, double[] y, double b) {
            throw new UnsupportedOperationException("Not supported yet.");
        }
    }

    public static class Trainer
    extends RegressionTrainer<double[]> {
        private double lambda;
        private double tol = 0.001;
        private int maxIter = 1000;

        public Trainer(double lambda) {
            if (lambda < 0.0) {
                throw new IllegalArgumentException("Invalid shrinkage/regularization parameter lambda = " + lambda);
            }
            this.lambda = lambda;
        }

        public Trainer setTolerance(double tol) {
            if (tol <= 0.0) {
                throw new IllegalArgumentException("Invalid tolerance: " + tol);
            }
            this.tol = tol;
            return this;
        }

        public Trainer setMaxNumIteration(int maxIter) {
            if (maxIter <= 0) {
                throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
            }
            this.maxIter = maxIter;
            return this;
        }

        public LASSO train(double[][] x, double[] y) {
            return new LASSO(x, y, this.lambda, this.tol, this.maxIter);
        }

        public LASSO train(Matrix x, double[] y) {
            return new LASSO(x, y, this.lambda, this.tol, this.maxIter);
        }
    }
}

