/*
 * Decompiled with CFR 0.152.
 */
package org.encog.neural.rbf.training;

import org.encog.mathutil.rbf.RadialBasisFunction;
import org.encog.ml.TrainingImplementationType;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.train.BasicTraining;
import org.encog.neural.networks.training.TrainingError;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.rbf.RBFNetwork;
import org.encog.neural.rbf.training.SVD;
import org.encog.util.ObjectPair;
import org.encog.util.simple.TrainingSetUtil;

/**
 * One-pass trainer for an {@link RBFNetwork} that delegates the actual fit to
 * {@link SVD#svdfit}.
 *
 * <p>The trainer is registered as {@link TrainingImplementationType#OnePass},
 * cannot be continued, and only accepts networks that report exactly one
 * output.
 */
public class SVDTraining
extends BasicTraining {
    /** The network whose flat weight array is read and rewritten by {@link #iteration()}. */
    private final RBFNetwork network;

    /**
     * Creates a trainer for the given network and training data.
     *
     * @param network  the RBF network to train; must report an output count of 1
     * @param training the training data handed to {@link #setTraining}
     * @throws TrainingError if the network's output count is not 1
     */
    public SVDTraining(RBFNetwork network, MLDataSet training) {
        super(TrainingImplementationType.OnePass);
        if (network.getOutputCount() != 1) {
            throw new TrainingError("SVD requires an output layer with a single neuron.");
        }
        this.setTraining(training);
        this.network = network;
    }

    /**
     * Reports whether training can be continued.
     *
     * @return always {@code false}
     */
    @Override
    public boolean canContinue() {
        return false;
    }

    /**
     * Fills {@code matrix} row by row from consecutive elements of {@code flat},
     * beginning at index {@code start}.
     *
     * <p>Each row is filled according to its own length, so a matrix with zero
     * rows is a no-op and jagged matrices are filled completely. For
     * rectangular matrices the element order is row-major.
     *
     * @param flat   the source array
     * @param start  the index in {@code flat} of the first element to read
     * @param matrix the destination matrix, overwritten in place
     */
    public void flatToMatrix(double[] flat, int start, double[][] matrix) {
        int index = start;
        for (double[] row : matrix) {
            for (int c = 0; c < row.length; ++c) {
                row[c] = flat[index++];
            }
        }
    }

    /**
     * Returns the network being trained.
     *
     * @return the network passed to the constructor
     */
    @Override
    public RBFNetwork getMethod() {
        return this.network;
    }

    /**
     * Runs the training pass.
     *
     * <p>Steps, in order: snapshot the network's radial basis functions,
     * convert the training set to arrays, load the current flat weights
     * (from offset 0) into a {@code [rbfCount][outputCount]} matrix, call
     * {@link SVD#svdfit} with those values, store its return value as this
     * trainer's error, and finally copy the matrix back into the flat weights
     * (again at offset 0) so that any changes made to the matrix by the fit
     * end up in the network.
     */
    @Override
    public void iteration() {
        // Fetch the RBF array once instead of once per element.
        RadialBasisFunction[] rbf = this.network.getRBF();
        int length = rbf.length;
        RadialBasisFunction[] funcs = new RadialBasisFunction[length];
        System.arraycopy(rbf, 0, funcs, 0, length);

        ObjectPair<double[][], double[][]> data = TrainingSetUtil.trainingToArray(this.getTraining());
        double[][] matrix = new double[length][this.network.getOutputCount()];
        this.flatToMatrix(this.network.getFlat().getWeights(), 0, matrix);
        this.setError(SVD.svdfit(data.getA(), data.getB(), matrix, funcs));
        this.matrixToFlat(matrix, this.network.getFlat().getWeights(), 0);
    }

    /**
     * Writes {@code matrix} row by row into consecutive elements of
     * {@code flat}, beginning at index {@code start}.
     *
     * <p>Each row is written according to its own length, so a matrix with
     * zero rows is a no-op and jagged matrices are written completely. For
     * rectangular matrices the element order is row-major.
     *
     * @param matrix the source matrix
     * @param flat   the destination array, overwritten in place
     * @param start  the index in {@code flat} of the first element to write
     */
    public void matrixToFlat(double[][] matrix, double[] flat, int start) {
        int index = start;
        for (double[] row : matrix) {
            for (int c = 0; c < row.length; ++c) {
                flat[index++] = row[c];
            }
        }
    }

    /**
     * Pausing is not supported by this trainer.
     *
     * @return always {@code null}
     */
    @Override
    public TrainingContinuation pause() {
        return null;
    }

    /**
     * Does nothing; this trainer keeps no resumable state.
     *
     * @param state ignored
     */
    @Override
    public void resume(TrainingContinuation state) {
    }
}

