/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.standardscaler;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.java.typeutils.TupleTypeInfo;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.feature.standardscaler.StandardScalerModel;
import org.apache.flink.ml.feature.standardscaler.StandardScalerModelData;
import org.apache.flink.ml.feature.standardscaler.StandardScalerParams;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

public class StandardScaler
implements Estimator<StandardScaler, StandardScalerModel>,
StandardScalerParams<StandardScaler> {
    /** Parameter values of this estimator, keyed by their parameter definition. */
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /**
     * Creates a scaler and fills {@link #paramMap} through
     * {@code ParamUtils.initializeMapWithDefaultValues}.
     */
    public StandardScaler() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Builds a {@link StandardScalerModel} from a single input table.
     *
     * <p>The table is converted to a data stream and run through two operators: "computeMeta",
     * which emits (sum, squared sum, element count) triples, and "buildModel", which runs with
     * parallelism 1 and turns those triples into one {@link StandardScalerModelData} record. The
     * resulting stream is handed to the model as its model-data table, and this estimator's
     * parameters are copied onto the model.
     *
     * @param inputs exactly one table; the column named by {@code getInputCol()} is read from it
     * @return the model wired to the (lazily computed) model-data table
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    @Override
    public StandardScalerModel fit(Table... inputs) {
        Preconditions.checkArgument(
                inputs.length == 1, "StandardScaler.fit expects exactly one input table.");

        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        // Stage 1: every subtask accumulates its own partial statistics.
        DataStream<Tuple3<DenseVector, DenseVector, Long>> sumAndSquaredSumAndWeight =
                tEnv.toDataStream(inputs[0])
                        .transform(
                                "computeMeta",
                                new TupleTypeInfo<>(
                                        TypeInformation.of(DenseVector.class),
                                        TypeInformation.of(DenseVector.class),
                                        BasicTypeInfo.LONG_TYPE_INFO),
                                new ComputeMetaOperator(this.getInputCol()));

        // Stage 2: a single subtask merges all partial statistics into the model data.
        DataStream<StandardScalerModelData> modelData =
                sumAndSquaredSumAndWeight
                        .transform(
                                "buildModel",
                                TypeInformation.of(StandardScalerModelData.class),
                                new BuildModelOperator())
                        .setParallelism(1);

        StandardScalerModel model =
                new StandardScalerModel().setModelData(tEnv.fromDataStream(modelData));
        ParamUtils.updateExistingParams(model, this.paramMap);
        return model;
    }

    /**
     * Persists this estimator by delegating to {@code ReadWriteUtils.saveMetadata}.
     *
     * @param path target location, forwarded unchanged
     * @throws IOException if the delegate fails to write
     */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Restores a {@link StandardScaler} from {@code path} via
     * {@code ReadWriteUtils.loadStageParam}.
     *
     * @param tEnv table environment; not used by this implementation
     * @param path location to read from, forwarded unchanged
     * @return the loaded estimator
     * @throws IOException if the delegate fails to read
     */
    public static StandardScaler load(StreamTableEnvironment tEnv, String path) throws IOException {
        StandardScaler restored = (StandardScaler) ReadWriteUtils.loadStageParam(path);
        return restored;
    }

    /**
     * Exposes the parameter map backing this estimator.
     *
     * @return the internal map itself (not a copy), so changes made through it are visible here
     */
    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    private static class ComputeMetaOperator
    extends AbstractStreamOperator<Tuple3<DenseVector, DenseVector, Long>>
    implements OneInputStreamOperator<Row, Tuple3<DenseVector, DenseVector, Long>>,
    BoundedOneInput {
        private ListState<DenseVector> sumState;
        private ListState<DenseVector> squaredSumState;
        private ListState<Long> numElementsState;
        private DenseVector sum;
        private DenseVector squaredSum;
        private long numElements;
        private final String inputCol;

        public ComputeMetaOperator(String inputCol) {
            this.inputCol = inputCol;
        }

        public void endInput() {
            if (this.numElements > 0L) {
                this.output.collect((Object)new StreamRecord((Object)Tuple3.of((Object)this.sum, (Object)this.squaredSum, (Object)this.numElements)));
            }
        }

        public void processElement(StreamRecord<Row> element) {
            Vector inputVec = (Vector)((Row)element.getValue()).getField(this.inputCol);
            if (this.numElements == 0L) {
                this.sum = new DenseVector(inputVec.size());
                this.squaredSum = new DenseVector(inputVec.size());
            }
            BLAS.axpy(1.0, inputVec, this.sum);
            BLAS.hDot(inputVec, inputVec);
            BLAS.axpy(1.0, inputVec, this.squaredSum);
            ++this.numElements;
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            this.sumState = context.getOperatorStateStore().getListState(new ListStateDescriptor("sumState", TypeInformation.of(DenseVector.class)));
            this.squaredSumState = context.getOperatorStateStore().getListState(new ListStateDescriptor("squaredSumState", TypeInformation.of(DenseVector.class)));
            this.numElementsState = context.getOperatorStateStore().getListState(new ListStateDescriptor("numElementsState", (TypeInformation)BasicTypeInfo.LONG_TYPE_INFO));
            this.sum = OperatorStateUtils.getUniqueElement(this.sumState, "sumState").orElse(null);
            this.squaredSum = OperatorStateUtils.getUniqueElement(this.squaredSumState, "squaredSumState").orElse(null);
            this.numElements = OperatorStateUtils.getUniqueElement(this.numElementsState, "numElementsState").orElse(0L);
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            if (this.numElements > 0L) {
                this.sumState.update(Collections.singletonList(this.sum));
                this.squaredSumState.update(Collections.singletonList(this.squaredSum));
                this.numElementsState.update(Collections.singletonList(this.numElements));
            }
        }
    }

    private static class BuildModelOperator
    extends AbstractStreamOperator<StandardScalerModelData>
    implements OneInputStreamOperator<Tuple3<DenseVector, DenseVector, Long>, StandardScalerModelData>,
    BoundedOneInput {
        private ListState<DenseVector> sumState;
        private ListState<DenseVector> squaredSumState;
        private ListState<Long> numElementsState;
        private DenseVector sum;
        private DenseVector squaredSum;
        private long numElements;

        private BuildModelOperator() {
        }

        public void endInput() {
            double[] std;
            double[] mean;
            if (this.numElements > 0L) {
                BLAS.scal(1.0 / (double)this.numElements, this.sum);
                mean = this.sum.values;
                std = this.squaredSum.values;
                if (this.numElements > 1L) {
                    for (int i = 0; i < mean.length; ++i) {
                        std[i] = Math.sqrt((this.squaredSum.values[i] - (double)this.numElements * mean[i] * mean[i]) / (double)(this.numElements - 1L));
                    }
                } else {
                    Arrays.fill(std, 0.0);
                }
            } else {
                throw new RuntimeException("The training set is empty.");
            }
            this.output.collect((Object)new StreamRecord((Object)new StandardScalerModelData(Vectors.dense(mean), Vectors.dense(std))));
        }

        public void processElement(StreamRecord<Tuple3<DenseVector, DenseVector, Long>> element) {
            Tuple3 value = (Tuple3)element.getValue();
            if (this.numElements == 0L) {
                this.sum = (DenseVector)value.f0;
                this.squaredSum = (DenseVector)value.f1;
                this.numElements = (Long)value.f2;
            } else {
                BLAS.axpy(1.0, (Vector)value.f0, this.sum);
                BLAS.axpy(1.0, (Vector)value.f1, this.squaredSum);
                this.numElements += ((Long)value.f2).longValue();
            }
        }

        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            this.sumState = context.getOperatorStateStore().getListState(new ListStateDescriptor("sumState", TypeInformation.of(DenseVector.class)));
            this.squaredSumState = context.getOperatorStateStore().getListState(new ListStateDescriptor("squaredSumState", TypeInformation.of(DenseVector.class)));
            this.numElementsState = context.getOperatorStateStore().getListState(new ListStateDescriptor("numElementsState", (TypeInformation)BasicTypeInfo.LONG_TYPE_INFO));
            this.sum = OperatorStateUtils.getUniqueElement(this.sumState, "sumState").orElse(null);
            this.squaredSum = OperatorStateUtils.getUniqueElement(this.squaredSumState, "squaredSumState").orElse(null);
            this.numElements = OperatorStateUtils.getUniqueElement(this.numElementsState, "numElementsState").orElse(0L);
        }

        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            if (this.numElements > 0L) {
                this.sumState.update(Collections.singletonList(this.sum));
                this.squaredSumState.update(Collections.singletonList(this.squaredSum));
                this.numElementsState.update(Collections.singletonList(this.numElements));
            }
        }
    }
}

