/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.vectorindexer;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModel;
import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModelData;
import org.apache.flink.ml.feature.vectorindexer.VectorIndexerParams;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.BoundedOneInput;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.types.AbstractDataType;
import org.apache.flink.table.types.DataType;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/**
 * An {@link Estimator} that scans the vectors of the input column and, for every vector index whose
 * values take at most {@code maxCategories} distinct values, builds a map from each of those values
 * to a category index. The resulting maps become the model data of the returned {@link
 * VectorIndexerModel}.
 */
public class VectorIndexer
        implements Estimator<VectorIndexer, VectorIndexerModel>,
                VectorIndexerParams<VectorIndexer> {

    /**
     * Type of the per-column distinct-value arrays exchanged between operators and kept in operator
     * state. A {@code null} element marks a column that has more than {@code maxCategories}
     * distinct values.
     */
    private static final TypeInformation<List<Double>[]> DISTINCT_DOUBLES_TYPE_INFO =
            Types.OBJECT_ARRAY(Types.LIST(Types.DOUBLE));

    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    public VectorIndexer() {
        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
    }

    /**
     * Fits a {@link VectorIndexerModel} on exactly one input table.
     *
     * @param inputs a single table containing the vector column named by {@code getInputCol()}
     * @return a model whose model data holds one category map per qualifying vector index
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    @Override
    public VectorIndexerModel fit(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1);
        final int maxCategories = getMaxCategories();
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();

        // Every subtask collects the distinct values of its own partition, per vector index.
        DataStream<List<Double>[]> localDistinctDoubles =
                tEnv.toDataStream(inputs[0])
                        .transform(
                                "computeDistinctDoublesOperator",
                                DISTINCT_DOUBLES_TYPE_INFO,
                                new ComputeDistinctDoublesOperator(getInputCol(), maxCategories));

        // The partial results are then merged into a single global result.
        DataStream<List<Double>[]> distinctDoubles =
                DataStreamUtils.reduce(localDistinctDoubles, new MergeDistinctDoubles());

        DataStream<VectorIndexerModelData> modelData =
                distinctDoubles.map(
                        new ModelGenerator(maxCategories), VectorIndexerModelData.TYPE_INFO);
        modelData.getTransformation().setParallelism(1);

        Schema schema =
                Schema.newBuilder()
                        .column(
                                "categoryMaps",
                                DataTypes.MAP(
                                        DataTypes.INT(),
                                        DataTypes.MAP(DataTypes.DOUBLE(), DataTypes.INT())))
                        .build();
        VectorIndexerModel model =
                new VectorIndexerModel().setModelData(tEnv.fromDataStream(modelData, schema));
        ParamUtils.updateExistingParams(model, paramMap);
        return model;
    }

    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    /**
     * Loads a {@link VectorIndexer} whose parameters were previously saved at {@code path}.
     *
     * @param tEnv not used by this implementation
     * @param path directory the estimator metadata was saved to
     */
    public static VectorIndexer load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (VectorIndexer) ReadWriteUtils.loadStageParam(path);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return paramMap;
    }

    /**
     * Merges two partial per-column distinct-value arrays. A column that is {@code null} on either
     * side stays {@code null}; otherwise the merged column holds the union of both sides' values.
     * {@code value1} is updated in place and returned.
     */
    private static class MergeDistinctDoubles implements ReduceFunction<List<Double>[]> {
        @Override
        public List<Double>[] reduce(List<Double>[] value1, List<Double>[] value2) {
            // Each subtask only validates vector sizes locally; partitions seeing vectors of
            // different sizes would otherwise cause an out-of-bounds access or dropped columns.
            Preconditions.checkState(
                    value1.length == value2.length,
                    "The size of the all input vectors should be the same.");
            for (int i = 0; i < value1.length; i++) {
                if (value1[i] == null || value2[i] == null) {
                    value1[i] = null;
                } else {
                    HashSet<Double> union = new HashSet<>(value1[i]);
                    union.addAll(value2[i]);
                    value1[i] = new ArrayList<>(union);
                }
            }
            return value1;
        }
    }

    /**
     * Builds {@link VectorIndexerModelData} from the merged distinct values of every vector index.
     *
     * <p>Indices whose column is {@code null} or holds more than {@code maxCategories} distinct
     * values get no category map. For every other index the distinct values are sorted ascending
     * and mapped to their position. If 0.0 is among them, it is moved to position 0 (and the values
     * sorted before it shift up by one), so 0.0 always maps to category 0.
     */
    private static class ModelGenerator
            implements MapFunction<List<Double>[], VectorIndexerModelData> {
        private final int maxCategories;

        public ModelGenerator(int maxCategories) {
            this.maxCategories = maxCategories;
        }

        @Override
        public VectorIndexerModelData map(List<Double>[] distinctDoubles) {
            Map<Integer, Map<Double, Integer>> categoryMaps = new HashMap<>();
            for (int i = 0; i < distinctDoubles.length; i++) {
                List<Double> columnValues = distinctDoubles[i];
                if (columnValues == null || columnValues.size() > maxCategories) {
                    continue;
                }
                double[] values = columnValues.stream().mapToDouble(Double::doubleValue).toArray();
                Arrays.sort(values);

                int zeroIndex = Arrays.binarySearch(values, 0.0);
                if (zeroIndex > 0) {
                    // Shift the values sorted before 0.0 one slot right and put 0.0 first.
                    System.arraycopy(values, 0, values, 1, zeroIndex);
                    values[0] = 0.0;
                }

                Map<Double, Integer> valueAndIndex = new HashMap<>(values.length);
                for (int index = 0; index < values.length; index++) {
                    valueAndIndex.put(values[index], index);
                }
                categoryMaps.put(i, valueAndIndex);
            }
            return new VectorIndexerModelData(categoryMaps);
        }
    }

    /**
     * Collects, per vector index, the distinct values seen by this subtask and emits them once the
     * bounded input ends. A column is dropped (set to {@code null}) as soon as it exceeds {@code
     * maxCategories} distinct values.
     */
    private static class ComputeDistinctDoublesOperator
            extends AbstractStreamOperator<List<Double>[]>
            implements OneInputStreamOperator<Row, List<Double>[]>, BoundedOneInput {
        private final String inputCol;
        private final int maxCategories;

        /** Distinct values per vector index; {@code null} until the first record arrives. */
        private HashSet<Double>[] doublesByColumn;

        private ListState<List<Double>[]> doublesByColumnState;

        public ComputeDistinctDoublesOperator(String inputCol, int maxCategories) {
            this.inputCol = inputCol;
            this.maxCategories = maxCategories;
        }

        @Override
        public void endInput() {
            if (doublesByColumn != null) {
                output.collect(new StreamRecord<>(convertToListArray(doublesByColumn)));
            }
            doublesByColumnState.clear();
        }

        @Override
        public void processElement(StreamRecord<Row> element) {
            Vector vector = (Vector) element.getValue().getField(inputCol);
            if (doublesByColumn == null) {
                @SuppressWarnings("unchecked")
                HashSet<Double>[] sets = new HashSet[vector.size()];
                for (int i = 0; i < sets.length; i++) {
                    sets[i] = new HashSet<>();
                }
                doublesByColumn = sets;
            }
            Preconditions.checkState(
                    vector.size() == doublesByColumn.length,
                    "The size of the all input vectors should be the same.");

            double[] values = vector.toDense().values;
            for (int i = 0; i < values.length; i++) {
                HashSet<Double> distinct = doublesByColumn[i];
                if (distinct == null) {
                    continue;
                }
                distinct.add(values[i]);
                if (distinct.size() > maxCategories) {
                    // Too many distinct values: stop tracking this column.
                    doublesByColumn[i] = null;
                }
            }
        }

        @Override
        public void initializeState(StateInitializationContext context) throws Exception {
            super.initializeState(context);
            doublesByColumnState =
                    context.getOperatorStateStore()
                            .getListState(
                                    new ListStateDescriptor<>(
                                            "doublesByColumnState", DISTINCT_DOUBLES_TYPE_INFO));
            OperatorStateUtils.getUniqueElement(doublesByColumnState, "doublesByColumnState")
                    .ifPresent(restored -> doublesByColumn = convertToHashSetArray(restored));
        }

        @Override
        public void snapshotState(StateSnapshotContext context) throws Exception {
            super.snapshotState(context);
            if (doublesByColumn != null) {
                doublesByColumnState.update(
                        Collections.singletonList(convertToListArray(doublesByColumn)));
            }
        }

        /**
         * Copies each set into a list. {@code null} entries (columns that exceeded {@code
         * maxCategories}) are preserved, since {@code new ArrayList<>(null)} would throw.
         */
        private static List<Double>[] convertToListArray(HashSet<Double>[] array) {
            @SuppressWarnings("unchecked")
            List<Double>[] results = new List[array.length];
            for (int i = 0; i < array.length; i++) {
                results[i] = array[i] == null ? null : new ArrayList<>(array[i]);
            }
            return results;
        }

        /** Inverse of {@link #convertToListArray}; {@code null} entries are preserved. */
        private static HashSet<Double>[] convertToHashSetArray(List<Double>[] array) {
            @SuppressWarnings("unchecked")
            HashSet<Double>[] results = new HashSet[array.length];
            for (int i = 0; i < array.length; i++) {
                results[i] = array[i] == null ? null : new HashSet<>(array[i]);
            }
            return results;
        }
    }
}

