/*
 * Decompiled with CFR 0.152.
 */
package org.apache.commons.math3.distribution.fitting;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
import org.apache.commons.math3.distribution.MultivariateNormalDistribution;
import org.apache.commons.math3.exception.ConvergenceException;
import org.apache.commons.math3.exception.DimensionMismatchException;
import org.apache.commons.math3.exception.NotStrictlyPositiveException;
import org.apache.commons.math3.exception.NumberIsTooLargeException;
import org.apache.commons.math3.exception.NumberIsTooSmallException;
import org.apache.commons.math3.exception.util.Localizable;
import org.apache.commons.math3.exception.util.LocalizedFormats;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.stat.correlation.Covariance;
import org.apache.commons.math3.util.MathArrays;
import org.apache.commons.math3.util.Pair;

/**
 * Expectation-maximization (EM) fitting of a mixture of multivariate normal
 * distributions to a fixed data set.
 * <p>
 * Typical use: build an initial mixture (for example with
 * {@link #estimate(double[][], int)}), call one of the {@code fit} methods,
 * then read the result through {@link #getFittedModel()} and
 * {@link #getLogLikelihood()}.
 * </p>
 * <p>
 * Instances hold mutable state that each call to {@code fit} overwrites; no
 * synchronization is performed.
 * </p>
 */
public class MultivariateNormalMixtureExpectationMaximization {
    /** Default maximum number of iterations allowed per fit. */
    private static final int DEFAULT_MAX_ITERATIONS = 1000;
    /** Default convergence threshold on the change of the average log-likelihood. */
    private static final double DEFAULT_THRESHOLD = 1.0E-5;
    /** Defensive copy of the data to fit; each row is one point passed to {@code density}. */
    private final double[][] data;
    /** Model produced by the most recent {@code fit}; {@code null} until {@code fit} is called. */
    private MixtureMultivariateNormalDistribution fittedModel;
    /** Average log-likelihood of the data computed in the last E-step; 0 before any fit. */
    private double logLikelihood = 0.0;

    /**
     * Creates an EM fitter for the given data. The rows are copied.
     *
     * @param data data to fit; all rows must have the same length, and that
     *             length must be at least 2.
     * @throws NotStrictlyPositiveException if {@code data} has no rows.
     * @throws DimensionMismatchException if the rows have differing lengths.
     * @throws NumberIsTooSmallException if the rows have fewer than 2 columns.
     */
    public MultivariateNormalMixtureExpectationMaximization(double[][] data) throws NotStrictlyPositiveException, DimensionMismatchException, NumberIsTooSmallException {
        if (data.length < 1) {
            throw new NotStrictlyPositiveException(data.length);
        }
        final int numCols = data[0].length;
        // Every row is filled with a defensive copy below, so no inner arrays are pre-allocated.
        this.data = new double[data.length][];
        for (int i = 0; i < data.length; i++) {
            if (data[i].length != numCols) {
                // Jagged arrays are not allowed.
                throw new DimensionMismatchException(data[i].length, numCols);
            }
            if (data[i].length < 2) {
                throw new NumberIsTooSmallException(LocalizedFormats.NUMBER_TOO_SMALL, data[i].length, 2, true);
            }
            this.data[i] = MathArrays.copyOf(data[i], data[i].length);
        }
    }

    /**
     * Fits a mixture model to the data supplied to the constructor, starting
     * from {@code initialMixture}. E- and M-steps alternate until the average
     * log-likelihood changes by no more than {@code threshold} between two
     * consecutive iterations.
     * <p>
     * The loop test {@code numIterations++ <= maxIterations} allows up to
     * {@code maxIterations + 1} iterations.
     * </p>
     *
     * @param initialMixture model used to start the iterations; its components'
     *        means must have as many entries as the data has columns.
     * @param maxIterations maximum number of iterations allowed (at least 1).
     * @param threshold convergence threshold on the change of the average
     *        log-likelihood; must not be NaN or smaller than {@link Double#MIN_VALUE}.
     * @throws SingularMatrixException if a re-estimated covariance matrix is
     *         singular (propagated while building the updated mixture).
     * @throws NotStrictlyPositiveException if {@code maxIterations < 1}, or if
     *         {@code threshold} is NaN or smaller than {@link Double#MIN_VALUE}.
     * @throws DimensionMismatchException if the dimension of the initial
     *         mixture's means differs from the number of data columns.
     * @throws ConvergenceException if the log-likelihood has not converged
     *         within the allowed number of iterations.
     */
    public void fit(MixtureMultivariateNormalDistribution initialMixture, int maxIterations, double threshold) throws SingularMatrixException, NotStrictlyPositiveException, DimensionMismatchException {
        if (maxIterations < 1) {
            throw new NotStrictlyPositiveException(maxIterations);
        }
        // NaN fails every comparison, so it must be rejected explicitly.
        // Otherwise the loop below never runs and the initial mixture is
        // silently reported as converged.
        if (Double.isNaN(threshold) || threshold < Double.MIN_VALUE) {
            throw new NotStrictlyPositiveException(threshold);
        }

        final int n = data.length;
        // The constructor rejected jagged data, so every row has this length.
        final int numCols = data[0].length;

        final List<Pair<Double, MultivariateNormalDistribution>> initialComponents = initialMixture.getComponents();
        final int k = initialComponents.size();
        final int numMeanColumns = initialComponents.get(0).getSecond().getMeans().length;
        if (numMeanColumns != numCols) {
            throw new DimensionMismatchException(numMeanColumns, numCols);
        }

        int numIterations = 0;
        double previousLogLikelihood = 0.0;
        // With -infinity here, the first convergence test fails for any finite threshold.
        logLikelihood = Double.NEGATIVE_INFINITY;

        fittedModel = new MixtureMultivariateNormalDistribution(initialComponents);

        while (numIterations++ <= maxIterations && Math.abs(previousLogLikelihood - logLikelihood) > threshold) {
            previousLogLikelihood = logLikelihood;
            double sumLogLikelihood = 0.0;

            // Unpack the current model into parallel arrays of weights and components.
            final List<Pair<Double, MultivariateNormalDistribution>> components = fittedModel.getComponents();
            final double[] weights = new double[k];
            final MultivariateNormalDistribution[] mvns = new MultivariateNormalDistribution[k];
            for (int j = 0; j < k; j++) {
                weights[j] = components.get(j).getFirst();
                mvns[j] = components.get(j).getSecond();
            }

            // E-step: gamma[i][j] is component j's share of the mixture density at row i.
            final double[][] gamma = new double[n][k];
            // Per-component sum of gamma over all rows.
            final double[] gammaSums = new double[k];
            // Per-component sum of the gamma-weighted rows.
            final double[][] gammaDataProdSums = new double[k][numCols];
            for (int i = 0; i < n; i++) {
                final double rowDensity = fittedModel.density(data[i]);
                sumLogLikelihood += Math.log(rowDensity);
                for (int j = 0; j < k; j++) {
                    gamma[i][j] = weights[j] * mvns[j].density(data[i]) / rowDensity;
                    gammaSums[j] += gamma[i][j];
                    for (int col = 0; col < numCols; col++) {
                        gammaDataProdSums[j][col] += gamma[i][j] * data[i][col];
                    }
                }
            }
            logLikelihood = sumLogLikelihood / n;

            // M-step: re-estimate weights and means from gamma.
            final double[] newWeights = new double[k];
            final double[][] newMeans = new double[k][numCols];
            for (int j = 0; j < k; j++) {
                newWeights[j] = gammaSums[j] / n;
                for (int col = 0; col < numCols; col++) {
                    newMeans[j][col] = gammaDataProdSums[j][col] / gammaSums[j];
                }
            }

            // M-step: accumulate the gamma-weighted outer products of (row - new mean).
            final RealMatrix[] newCovMats = new RealMatrix[k];
            for (int j = 0; j < k; j++) {
                newCovMats[j] = new Array2DRowRealMatrix(numCols, numCols);
            }
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < k; j++) {
                    final RealMatrix vec = new Array2DRowRealMatrix(MathArrays.ebeSubtract(data[i], newMeans[j]));
                    final RealMatrix dataCov = vec.multiply(vec.transpose()).scalarMultiply(gamma[i][j]);
                    newCovMats[j] = newCovMats[j].add(dataCov);
                }
            }

            // Normalize each covariance by its component's total gamma and convert to arrays.
            final double[][][] newCovMatArrays = new double[k][][];
            for (int j = 0; j < k; j++) {
                newCovMatArrays[j] = newCovMats[j].scalarMultiply(1.0 / gammaSums[j]).getData();
            }

            fittedModel = new MixtureMultivariateNormalDistribution(newWeights, newMeans, newCovMatArrays);
        }

        if (Math.abs(previousLogLikelihood - logLikelihood) > threshold) {
            // The iteration limit was reached before the log-likelihood settled.
            throw new ConvergenceException();
        }
    }

    /**
     * Fits a mixture model to the data supplied to the constructor, using the
     * default maximum iteration count (1000) and convergence threshold (1e-5).
     *
     * @param initialMixture model used to start the iterations.
     * @throws SingularMatrixException if a re-estimated covariance matrix is singular.
     * @throws NotStrictlyPositiveException never for the defaults; declared for
     *         consistency with {@link #fit(MixtureMultivariateNormalDistribution, int, double)}.
     */
    public void fit(MixtureMultivariateNormalDistribution initialMixture) throws SingularMatrixException, NotStrictlyPositiveException {
        fit(initialMixture, DEFAULT_MAX_ITERATIONS, DEFAULT_THRESHOLD);
    }

    /**
     * Builds a mixture model suitable for initializing {@code fit}.
     * <p>
     * Rows are sorted by the mean of their entries and split into
     * {@code numComponents} consecutive bins of (nearly) equal size. Each bin
     * yields one equally weighted component. That component's mean is the
     * bin's column means, and its covariance is computed by {@link Covariance}
     * over the bin.
     * </p>
     *
     * @param data data to bin; all rows must have the same length.
     * @param numComponents number of mixture components, between 2 and the
     *        number of rows inclusive.
     * @return a mixture with {@code numComponents} equally weighted components.
     * @throws NotStrictlyPositiveException if {@code data} has fewer than 2 rows.
     * @throws NumberIsTooSmallException if {@code numComponents < 2}.
     * @throws NumberIsTooLargeException if {@code numComponents} exceeds the number of rows.
     * @throws DimensionMismatchException if the rows have differing lengths.
     */
    public static MixtureMultivariateNormalDistribution estimate(double[][] data, int numComponents) throws NotStrictlyPositiveException, DimensionMismatchException {
        if (data.length < 2) {
            throw new NotStrictlyPositiveException(data.length);
        }
        if (numComponents < 2) {
            throw new NumberIsTooSmallException(numComponents, 2, true);
        }
        if (numComponents > data.length) {
            throw new NumberIsTooLargeException(numComponents, data.length, true);
        }

        final int numRows = data.length;
        final int numCols = data[0].length;

        final DataRow[] sortedData = new DataRow[numRows];
        for (int i = 0; i < numRows; i++) {
            if (data[i].length != numCols) {
                // Without this check, jagged input is either silently truncated
                // or fails with ArrayIndexOutOfBoundsException while binning.
                throw new DimensionMismatchException(data[i].length, numCols);
            }
            sortedData[i] = new DataRow(data[i]);
        }
        Arrays.sort(sortedData);

        final double weight = 1.0 / numComponents;
        final List<Pair<Double, MultivariateNormalDistribution>> components =
            new ArrayList<Pair<Double, MultivariateNormalDistribution>>(numComponents);

        for (int binIndex = 0; binIndex < numComponents; binIndex++) {
            // Bin covers sortedData[minIndex, maxIndex). Long arithmetic is used
            // because binIndex * numRows can overflow int for large inputs; the
            // quotient itself never exceeds numRows.
            final int minIndex = (int) ((long) binIndex * numRows / numComponents);
            final int maxIndex = (int) ((long) (binIndex + 1) * numRows / numComponents);
            final int numBinRows = maxIndex - minIndex;

            final double[][] binData = new double[numBinRows][numCols];
            final double[] columnMeans = new double[numCols];
            for (int i = minIndex, iBin = 0; i < maxIndex; i++, iBin++) {
                final double[] row = sortedData[i].getRow();
                for (int j = 0; j < numCols; j++) {
                    columnMeans[j] += row[j];
                    binData[iBin][j] = row[j];
                }
            }
            MathArrays.scaleInPlace(1.0 / numBinRows, columnMeans);

            // NOTE(review): when numRows < 2 * numComponents some bin holds a
            // single row, which Covariance may reject — confirm.
            final double[][] covMat = new Covariance(binData).getCovarianceMatrix().getData();
            final MultivariateNormalDistribution mvn = new MultivariateNormalDistribution(columnMeans, covMat);
            components.add(new Pair<Double, MultivariateNormalDistribution>(weight, mvn));
        }
        return new MixtureMultivariateNormalDistribution(components);
    }

    /**
     * Returns the average log-likelihood of the data computed during the last
     * E-step of the most recent {@code fit}. This is 0 before any fit.
     *
     * @return the average log-likelihood.
     */
    public double getLogLikelihood() {
        return logLikelihood;
    }

    /**
     * Returns a copy of the model produced by the most recent {@code fit}.
     * Calling this before any {@code fit} throws a NullPointerException,
     * because no model exists yet.
     *
     * @return the fitted model.
     */
    public MixtureMultivariateNormalDistribution getFittedModel() {
        return new MixtureMultivariateNormalDistribution(fittedModel.getComponents());
    }

    /**
     * Wraps a data row and orders rows by the arithmetic mean of their entries.
     * <p>
     * The natural ordering is inconsistent with {@link #equals(Object)}: two
     * different rows with the same mean compare as 0 but are not equal.
     * </p>
     */
    private static class DataRow implements Comparable<DataRow> {
        /** The wrapped row (shared with the caller, not copied). */
        private final double[] row;
        /** Mean of the row's entries; the sort key. */
        private final double mean;

        /**
         * @param data row to wrap.
         */
        DataRow(double[] data) {
            this.row = data;
            double sum = 0.0;
            for (double value : data) {
                sum += value;
            }
            this.mean = sum / data.length;
        }

        /** Orders rows by ascending mean. */
        @Override
        public int compareTo(DataRow other) {
            return Double.compare(mean, other.mean);
        }

        /**
         * Element-wise exact comparison. {@code Arrays.equals} matches
         * {@code Arrays.hashCode} (both use {@code Double.doubleToLongBits}),
         * so the equals/hashCode contract holds.
         */
        @Override
        public boolean equals(Object other) {
            if (this == other) {
                return true;
            }
            if (other instanceof DataRow) {
                return Arrays.equals(row, ((DataRow) other).row);
            }
            return false;
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(row);
        }

        /**
         * @return the wrapped row (not a copy).
         */
        public double[] getRow() {
            return row;
        }
    }
}

