/*
 * Decompiled with CFR 0.152.
 */
package smile.classification;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;
import smile.classification.Classifier;
import smile.classification.ClassifierTrainer;
import smile.math.Math;
import smile.math.distance.Metric;
import smile.math.matrix.ColumnMajorMatrix;
import smile.math.matrix.DenseMatrix;
import smile.math.matrix.QRDecomposition;
import smile.math.rbf.GaussianRadialBasis;
import smile.math.rbf.RadialBasisFunction;
import smile.util.SmileUtils;

public class RBFNetwork<T>
implements Classifier<T>,
Serializable {
    private static final long serialVersionUID = 1L;
    private int k;
    private T[] centers;
    private DenseMatrix w;
    private Metric<T> distance;
    private RadialBasisFunction[] rbf;
    private boolean normalized;

    public RBFNetwork(T[] x, int[] y, Metric<T> distance, RadialBasisFunction rbf, T[] centers) {
        this(x, y, distance, rbf, centers, false);
    }

    public RBFNetwork(T[] x, int[] y, Metric<T> distance, RadialBasisFunction[] rbf, T[] centers) {
        this(x, y, distance, rbf, centers, false);
    }

    public RBFNetwork(T[] x, int[] y, Metric<T> distance, RadialBasisFunction rbf, T[] centers, boolean normalized) {
        this(x, y, distance, RBFNetwork.rep(rbf, centers.length), centers, normalized);
    }

    private static RadialBasisFunction[] rep(RadialBasisFunction rbf, int k) {
        Object[] arr = new RadialBasisFunction[k];
        Arrays.fill(arr, rbf);
        return arr;
    }

    public RBFNetwork(T[] x, int[] y, Metric<T> distance, RadialBasisFunction[] rbf, T[] centers, boolean normalized) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The sizes of X and Y don't match: %d != %d", x.length, y.length));
        }
        if (rbf.length != centers.length) {
            throw new IllegalArgumentException(String.format("The sizes of RBF functions and centers don't match: %d != %d", rbf.length, centers.length));
        }
        int[] labels = Math.unique((int[])y);
        Arrays.sort(labels);
        for (int i = 0; i < labels.length; ++i) {
            if (labels[i] < 0) {
                throw new IllegalArgumentException("Negative class label: " + labels[i]);
            }
            if (i <= 0 || labels[i] - labels[i - 1] <= 1) continue;
            throw new IllegalArgumentException("Missing class: " + labels[i] + 1);
        }
        this.k = labels.length;
        if (this.k < 2) {
            throw new IllegalArgumentException("Only one class.");
        }
        this.centers = centers;
        this.distance = distance;
        this.rbf = rbf;
        this.normalized = normalized;
        int n = x.length;
        int m = rbf.length;
        this.w = new ColumnMajorMatrix(m + 1, this.k);
        ColumnMajorMatrix G = new ColumnMajorMatrix(n, m + 1);
        ColumnMajorMatrix b = new ColumnMajorMatrix(n, this.k);
        for (int i = 0; i < n; ++i) {
            double sum = 0.0;
            for (int j = 0; j < m; ++j) {
                double r = rbf[j].f(distance.d(x[i], centers[j]));
                G.set(i, j, r);
                sum += r;
            }
            G.set(i, m, 1.0);
            if (normalized) {
                b.set(i, y[i], sum);
                continue;
            }
            b.set(i, y[i], 1.0);
        }
        QRDecomposition qr = new QRDecomposition((DenseMatrix)G);
        qr.solve((DenseMatrix)b, this.w);
    }

    @Override
    public int predict(T x) {
        int j;
        int j2;
        double[] sumw = new double[this.k];
        double sum = 0.0;
        for (int i = 0; i < this.rbf.length; ++i) {
            double f = this.rbf[i].f(this.distance.d(x, this.centers[i]));
            sum += f;
            for (j2 = 0; j2 < this.k; ++j2) {
                int n = j2;
                sumw[n] = sumw[n] + this.w.get(i, j2) * f;
            }
        }
        if (this.normalized) {
            for (j = 0; j < this.k; ++j) {
                sumw[j] = (sumw[j] + this.w.get(this.centers.length, j)) / sum;
            }
        } else {
            for (j = 0; j < this.k; ++j) {
                int n = j;
                sumw[n] = sumw[n] + this.w.get(this.centers.length, j);
            }
        }
        double max = Double.NEGATIVE_INFINITY;
        int y = 0;
        for (j2 = 0; j2 < this.k; ++j2) {
            if (!(max < sumw[j2])) continue;
            max = sumw[j2];
            y = j2;
        }
        return y;
    }

    public static class Trainer<T>
    extends ClassifierTrainer<T> {
        private int m = 10;
        private Metric<T> distance;
        private RadialBasisFunction[] rbf;
        private boolean normalized = false;

        public Trainer(Metric<T> distance) {
            this.distance = distance;
        }

        public Trainer setRBF(RadialBasisFunction rbf, int m) {
            this.m = m;
            this.rbf = RBFNetwork.rep(rbf, m);
            return this;
        }

        public Trainer setRBF(RadialBasisFunction[] rbf) {
            this.m = rbf.length;
            this.rbf = rbf;
            return this;
        }

        public Trainer setNormalized(boolean normalized) {
            this.normalized = normalized;
            return this;
        }

        @Override
        public RBFNetwork<T> train(T[] x, int[] y) {
            Object[] centers = (Object[])Array.newInstance(x.getClass().getComponentType(), this.m);
            GaussianRadialBasis gaussian = SmileUtils.learnGaussianRadialBasis(x, centers, this.distance);
            if (this.rbf == null) {
                return new RBFNetwork<Object>(x, y, this.distance, (RadialBasisFunction)gaussian, centers, this.normalized);
            }
            return new RBFNetwork<Object>(x, y, this.distance, this.rbf, centers, this.normalized);
        }

        public RBFNetwork<T> train(T[] x, int[] y, T[] centers) {
            return new RBFNetwork<T>(x, y, this.distance, this.rbf, centers, this.normalized);
        }
    }
}

