/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.util.Arrays;
import smile.classification.Classifier;
import smile.classification.ClassifierTrainer;
import smile.math.Math;
import smile.math.matrix.ColumnMajorMatrix;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.EigenValueDecomposition;
import smile.projection.Projection;

/**
 * Fisher's linear discriminant (FLD).
 * <p>
 * Training finds up to {@code min(k - 1, p)} linear directions that maximize the ratio of
 * between-class scatter to total scatter. The class solves the generalized symmetric
 * eigenproblem {@code B d = lambda T d}, where:
 * <ul>
 * <li>{@code T} is the total covariance of the training samples;</li>
 * <li>{@code B} is the unweighted average of outer products of the centered class means.</li>
 * </ul>
 * A sample is classified by projecting it onto these directions and picking the class whose
 * projected mean is nearest.
 */
public class FLD implements Classifier<double[]>, Projection<double[]>, Serializable {
    private static final long serialVersionUID = 1L;
    /** Dimensionality of the input space. */
    private final int p;
    /** Number of classes; labels are 0..k-1. */
    private final int k;
    /** Column mean of the training samples. */
    private final double[] mean;
    /** Per-class means, centered by subtracting {@link #mean}. */
    private final double[][] mu;
    /** Projection matrix of size p x L; column j is the j-th discriminant direction. */
    private final double[][] scaling;
    /** The overall mean projected by {@link #scaling}. */
    private final double[] smean;
    /** The centered class means projected by {@link #scaling}. */
    private final double[][] smu;

    /**
     * Trains FLD with the default mapped dimensionality and the default tolerance 1E-4.
     *
     * @param x training samples
     * @param y class labels, which must be the contiguous range 0..k-1
     */
    public FLD(double[][] x, int[] y) {
        this(x, y, -1);
    }

    /**
     * Trains FLD with the default tolerance 1E-4.
     *
     * @param x training samples
     * @param y class labels, which must be the contiguous range 0..k-1
     * @param L dimensionality of the mapped space; a value of 0 or less selects min(k - 1, p)
     */
    public FLD(double[][] x, int[] y, int L) {
        this(x, y, L, 1.0E-4);
    }

    /**
     * Trains FLD.
     *
     * @param x training samples; all rows must have the same positive length p
     * @param y class labels, which must be the contiguous range 0..k-1
     * @param L dimensionality of the mapped space; a value of 0 or less selects min(k - 1, p)
     * @param tol the total covariance is rejected as singular if any eigenvalue is below
     *            {@code tol * tol} or is not positive
     * @throws IllegalArgumentException on mismatched sizes, invalid labels, fewer than two
     *         classes, negative tol, too few samples, ragged rows, an L that is too large,
     *         or a (near-)singular covariance matrix
     */
    public FLD(double[][] x, int[] y, int L, double tol) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }

        this.k = countClasses(y);
        if (k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        if (tol < 0.0) {
            throw new IllegalArgumentException("Invalid tol: " + tol);
        }
        if (x.length <= k) {
            throw new IllegalArgumentException(String.format("Sample size is too small: %d <= %d", x.length, k));
        }
        if (L >= k) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space is too high: %d >= %d", L, k));
        }

        int n = x.length;
        this.p = x[0].length;
        if (p == 0) {
            throw new IllegalArgumentException("Empty input vectors.");
        }
        for (int i = 0; i < n; i++) {
            if (x[i].length != p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x[i].length, p));
            }
        }
        // U below is p x p, so at most p directions exist.
        if (L > p) {
            throw new IllegalArgumentException(String.format("The dimensionality of mapped space is too high: %d > %d", L, p));
        }
        if (L <= 0) {
            L = (k - 1 < p) ? k - 1 : p;
        }

        this.mean = Math.colMean(x);

        // Class means, centered on the overall mean. Every class 0..k-1 is present
        // (countClasses guarantees this), so ni[c] > 0.
        int[] ni = new int[k];
        this.mu = new double[k][p];
        for (int i = 0; i < n; i++) {
            int c = y[i];
            ni[c]++;
            double[] muc = mu[c];
            for (int j = 0; j < p; j++) {
                muc[j] += x[i][j];
            }
        }
        for (int c = 0; c < k; c++) {
            for (int j = 0; j < p; j++) {
                mu[c][j] = mu[c][j] / ni[c] - mean[j];
            }
        }

        // Total covariance T: fill the lower triangle, normalize, then mirror.
        ColumnMajorMatrix T = new ColumnMajorMatrix(p, p);
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < p; j++) {
                double dj = x[i][j] - mean[j];
                for (int l = 0; l <= j; l++) {
                    T.add(j, l, dj * (x[i][l] - mean[l]));
                }
            }
        }
        for (int j = 0; j < p; j++) {
            for (int l = 0; l <= j; l++) {
                T.div(j, l, (double) n);
                T.set(l, j, T.get(j, l));
            }
        }

        // Between-class scatter B: unweighted average of outer products of centered means.
        ColumnMajorMatrix B = new ColumnMajorMatrix(p, p);
        for (int c = 0; c < k; c++) {
            for (int j = 0; j < p; j++) {
                for (int l = 0; l <= j; l++) {
                    B.add(j, l, mu[c][j] * mu[c][l]);
                }
            }
        }
        for (int j = 0; j < p; j++) {
            for (int l = 0; l <= j; l++) {
                B.div(j, l, (double) k);
                B.set(l, j, B.get(j, l));
            }
        }

        // T = U S U^T. Reject a (near-)singular T; keep w = S^{-1/2} for whitening.
        EigenValueDecomposition eigen = new EigenValueDecomposition(T, true);
        double minEigenValue = tol * tol;
        double[] s = eigen.getEigenValues();
        double[] w = new double[s.length];
        for (int i = 0; i < s.length; i++) {
            if (s[i] < minEigenValue || s[i] <= 0.0) {
                throw new IllegalArgumentException("The covariance matrix is close to singular.");
            }
            w[i] = 1.0 / java.lang.Math.sqrt(s[i]);
        }
        DenseMatrix U = eigen.getEigenVectors();

        // Reduce B d = lambda T d to the symmetric problem M v = lambda v, where
        // M = S^{-1/2} U^T B U S^{-1/2} and d = U S^{-1/2} v. The symmetric eigensolver
        // needs this form; T^{-1} B is not symmetric in general.
        DenseMatrix UB = (DenseMatrix) U.atbmm(B);
        DenseMatrix M = (DenseMatrix) UB.abmm(U);
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < p; j++) {
                M.mul(i, j, w[i] * w[j]);
            }
        }

        EigenValueDecomposition discriminant = new EigenValueDecomposition(M, true);
        double[] lambda = discriminant.getEigenValues();
        DenseMatrix V = discriminant.getEigenVectors();
        // Sort explicitly so the most discriminative directions come first,
        // whatever order the solver returns.
        int[] order = descendingOrder(lambda);

        // scaling = U S^{-1/2} V[:, order[0..L-1]]
        this.scaling = new double[p][L];
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < L; j++) {
                int col = order[j];
                double sum = 0.0;
                for (int m = 0; m < p; m++) {
                    sum += U.get(i, m) * w[m] * V.get(m, col);
                }
                scaling[i][j] = sum;
            }
        }

        this.smean = new double[L];
        Math.atx(scaling, mean, smean);
        this.smu = Math.abmm(mu, scaling);
    }

    /**
     * Validates class labels and returns the number of classes.
     * Labels must be non-negative and cover the contiguous range 0..k-1.
     */
    private static int countClasses(int[] y) {
        int[] labels = Math.unique(y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; i++) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            if (i == 0 && labels[i] > 0) {
                throw new IllegalArgumentException("Missing class: 0");
            }
            if (i > 0 && labels[i] - labels[i - 1] > 1) {
                throw new IllegalArgumentException("Missing class: " + (labels[i - 1] + 1));
            }
        }
        return labels.length;
    }

    /** Returns the indices of {@code values}, ordered by decreasing value (stable for ties). */
    private static int[] descendingOrder(double[] values) {
        int[] order = new int[values.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Insertion sort: p is small, and this avoids boxing.
        for (int i = 1; i < order.length; i++) {
            int idx = order[i];
            int j = i - 1;
            while (j >= 0 && values[order[j]] < values[idx]) {
                order[j + 1] = order[j];
                j--;
            }
            order[j + 1] = idx;
        }
        return order;
    }

    /**
     * Predicts the class of {@code x}. The prediction is the class whose projected mean is
     * nearest to the projection of {@code x}.
     *
     * @throws IllegalArgumentException if {@code x.length != p}
     */
    @Override
    public int predict(double[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        double[] wx = this.project(x);
        int best = 0;
        double nearest = Double.POSITIVE_INFINITY;
        for (int i = 0; i < this.k; ++i) {
            double d = Math.distance(wx, this.smu[i]);
            if (d < nearest) {
                nearest = d;
                best = i;
            }
        }
        return best;
    }

    /**
     * Projects {@code x} onto the discriminant directions, centered on the training mean.
     *
     * @throws IllegalArgumentException if {@code x.length != p}
     */
    @Override
    public double[] project(double[] x) {
        if (x.length != this.p) {
            throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x.length, this.p));
        }
        double[] y = new double[this.scaling[0].length];
        Math.atx(this.scaling, x, y);
        Math.minus(y, this.smean);
        return y;
    }

    /**
     * Projects each row of {@code x}; see {@link #project(double[])}.
     *
     * @throws IllegalArgumentException if any row's length differs from p
     */
    public double[][] project(double[][] x) {
        double[][] y = new double[x.length][this.scaling[0].length];
        for (int i = 0; i < x.length; ++i) {
            if (x[i].length != this.p) {
                throw new IllegalArgumentException(String.format("Invalid input vector size: %d, expected: %d", x[i].length, this.p));
            }
            Math.atx(this.scaling, x[i], y[i]);
            Math.minus(y[i], this.smean);
        }
        return y;
    }

    /**
     * Returns the p x L projection matrix; column j is the j-th discriminant direction.
     * This is the internal array itself, not a copy.
     */
    public double[][] getProjection() {
        return this.scaling;
    }

    /** Builder-style trainer for {@link FLD}. */
    public static class Trainer extends ClassifierTrainer<double[]> {
        /** Mapped dimensionality; -1 selects the default in the FLD constructor. */
        private int L = -1;
        /** Singularity tolerance passed to the FLD constructor. */
        private double tol = 1.0E-4;

        /**
         * Sets the dimensionality of the mapped space.
         *
         * @throws IllegalArgumentException if {@code L < 1}
         */
        public Trainer setDimension(int L) {
            if (L < 1) {
                throw new IllegalArgumentException("Invalid mapping space dimension: " + L);
            }
            this.L = L;
            return this;
        }

        /**
         * Sets the tolerance used to detect a singular covariance matrix.
         *
         * @throws IllegalArgumentException if {@code tol < 0}
         */
        public Trainer setTolerance(double tol) {
            if (tol < 0.0) {
                throw new IllegalArgumentException("Invalid tol: " + tol);
            }
            this.tol = tol;
            return this;
        }

        /** Trains an {@link FLD} with the configured dimension and tolerance. */
        public FLD train(double[][] x, int[] y) {
            return new FLD(x, y, this.L, this.tol);
        }
    }
}

