/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.estim.ComEstSample;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderDummycode;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderUDF;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderWordEmbedding;
import org.apache.sysds.runtime.transform.encode.CompressedEncode;
import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderMVImpute;
import org.apache.sysds.runtime.transform.encode.EncoderOmit;
import org.apache.sysds.runtime.transform.encode.LegacyEncoder;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.DependencyThreadPool;
import org.apache.sysds.runtime.util.DependencyWrapperTask;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.utils.MemoryEstimates;
import org.apache.sysds.utils.stats.TransformStatistics;

public class MultiColumnEncoder
implements Encoder {
    /** Logger for this class. */
    protected static final Log LOG = LogFactory.getLog((String)MultiColumnEncoder.class.getName());
    /**
     * When true, encode() with k > 1 uses the staged build-then-apply path instead of the
     * single fused dependency DAG built by getEncodeTasks. Initialized from
     * ConfigurationManager.isStagedParallelTransform().
     */
    public static boolean MULTI_THREADED_STAGES = ConfigurationManager.isStagedParallelTransform();
    /**
     * When true, the multi-threaded apply submits and waits for each column's apply tasks one
     * column at a time instead of submitting the tasks of all columns together.
     */
    public static boolean APPLY_ENCODER_SEPARATE_STAGES = false;
    /** Composite encoders; apply() requires exactly one per input column. */
    private List<ColumnEncoderComposite> _columnEncoders;
    /** Optional legacy EncoderMVImpute, built/applied after the column encoders; null if unused. */
    private EncoderMVImpute _legacyMVImpute = null;
    /** Optional legacy EncoderOmit, built/applied after the column encoders; null if unused. */
    private EncoderOmit _legacyOmit = null;
    // NOTE(review): not referenced in the portion of this file reviewed here.
    private int _colOffset = 0;
    /** Metadata frame created during encode (fused DAG path, or staged path with MV-impute). */
    private FrameBlock _meta = null;
    /** Set once deriveNumRowPartitions has pushed partition counts into the encoders. */
    private boolean _partitionDone = false;

    /**
     * Creates an encoder over the given per-column composite encoders.
     *
     * @param columnEncoders composite encoders; stored by reference, not copied
     */
    public MultiColumnEncoder(List<ColumnEncoderComposite> columnEncoders) {
        _columnEncoders = columnEncoders;
    }

    /**
     * Copy constructor.
     * <p>
     * For every column of {@code menc}, a new ColumnEncoderComposite with the same column id
     * is created. Bag-of-words encoders are duplicated through their copy constructor. Every
     * other encoder instance is shared with {@code menc}.
     *
     * @param menc the encoder to copy
     */
    public MultiColumnEncoder(MultiColumnEncoder menc) {
        _columnEncoders = new ArrayList<>();
        for (ColumnEncoderComposite source : menc._columnEncoders) {
            ArrayList<ColumnEncoder> copies = new ArrayList<>();
            // The composite is created over the (still empty) list, which is filled afterwards.
            _columnEncoders.add(new ColumnEncoderComposite(copies, source._colID));
            for (ColumnEncoder enc : source.getEncoders()) {
                ColumnEncoder copy = enc;
                if (enc instanceof ColumnEncoderBagOfWords) {
                    copy = new ColumnEncoderBagOfWords((ColumnEncoderBagOfWords) enc);
                }
                copies.add(copy);
            }
        }
    }

    /** Creates an encoder with an empty, mutable list of column encoders. */
    public MultiColumnEncoder() {
        _columnEncoders = new ArrayList<>();
    }

    /** Encodes {@code in} with a single thread and uncompressed output. */
    public MatrixBlock encode(CacheBlock<?> in) {
        final int singleThread = 1;
        return encode(in, singleThread);
    }

    /** Encodes {@code in} with {@code k} threads and uncompressed output. */
    public MatrixBlock encode(CacheBlock<?> in, int k) {
        final boolean uncompressed = false;
        return encode(in, k, uncompressed);
    }

    /** Encodes {@code in} with a single thread; {@code compressedOut} requests compressed output. */
    public MatrixBlock encode(CacheBlock<?> in, boolean compressedOut) {
        final int singleThread = 1;
        return encode(in, singleThread, compressedOut);
    }

    /**
     * Builds all column encoders on {@code in} and applies them, producing the encoded matrix.
     * <p>
     * Three paths are taken:
     * <ul>
     * <li>Compressed output via CompressedEncode when isCompressedTransformEncode(in, compressedOut)
     * holds.</li>
     * <li>For k &gt; 1, with MULTI_THREADED_STAGES off and no legacy encoder: one fused
     * build+apply dependency DAG (getEncodeTasks) run on a DependencyThreadPool.</li>
     * <li>Otherwise a staged build(in, k) followed by apply(in, k).</li>
     * </ul>
     * Any exception is wrapped in a DMLRuntimeException whose message contains at most the
     * first 1000 characters of this encoder's toString().
     * NOTE(decompiler): CFR reported removing a self-catching try here; the exact try/catch
     * structure of the original source may differ.
     *
     * @param in            input block
     * @param k             degree of parallelism
     * @param compressedOut whether compressed output is requested
     * @return the encoded matrix
     */
    public MatrixBlock encode(CacheBlock<?> in, int k, boolean compressedOut) {
        try {
            if (this.isCompressedTransformEncode(in, compressedOut)) {
                return CompressedEncode.encode(this, (FrameBlock)in, k);
            }
            this.deriveNumRowPartitions(in, k);
            if (k > 1 && !MULTI_THREADED_STAGES && !this.hasLegacyEncoder()) {
                // Fused path: build and apply tasks run as one DAG; the pool is always shut down.
                MatrixBlock out = new MatrixBlock();
                DependencyThreadPool pool = new DependencyThreadPool(k);
                LOG.debug((Object)("Encoding with full DAG on " + k + " Threads"));
                try {
                    List<DependencyTask<?>> tasks = this.getEncodeTasks(in, out, pool);
                    pool.submitAllAndWait(tasks);
                }
                finally {
                    pool.shutdown();
                }
                this.outputMatrixPostProcessing(out, k);
                this.outputLogging(out);
                return out;
            }
            // Staged path: build everything first, then apply.
            LOG.debug((Object)("Encoding with staged approach on: " + k + " Threads"));
            long t0 = System.nanoTime();
            this.build(in, k);
            long t1 = System.nanoTime();
            LOG.debug((Object)("Elapsed time for build phase: " + ((double)t1 - (double)t0) / 1000000.0 + " ms"));
            if (this._legacyMVImpute != null) {
                // The legacy MV-impute encoder needs meta data initialized before apply.
                this._meta = this.getMetaData(new FrameBlock(in.getNumColumns(), Types.ValueType.STRING));
                this.initMetaData(this._meta);
            }
            t0 = System.nanoTime();
            MatrixBlock out = this.apply(in, k);
            t1 = System.nanoTime();
            LOG.debug((Object)("Elapsed time for apply phase: " + ((double)t1 - (double)t0) / 1000000.0 + " ms"));
            this.outputLogging(out);
            return out;
        }
        catch (Exception ex) {
            // Truncate the encoder description so the exception message stays bounded.
            String st = this.toString();
            st = st.substring(0, Math.min(st.length(), 1000));
            throw new DMLRuntimeException("Failed transform-encode frame with encoder:\n" + st, ex);
        }
    }

    /**
     * Logs memory size, shape, sparsity and non-zero count of the encoded output.
     * Does nothing unless DEBUG logging is enabled.
     */
    private void outputLogging(MatrixBlock out) {
        if (!LOG.isDebugEnabled()) {
            return;
        }
        LOG.debug("Transform Encode output mem size: " + out.getInMemorySize());
        LOG.debug(String.format("Transform Encode output rows     : %10d", out.getNumRows()));
        LOG.debug(String.format("Transform Encode output cols     : %10d", out.getNumColumns()));
        LOG.debug(String.format("Transform Encode output sparsity : %10.5f", out.getSparsity()));
        LOG.debug(String.format("Transform Encode output nnz      : %10d", out.getNonZeros()));
    }

    /**
     * Returns the per-column composite encoders.
     *
     * @return the live internal list (not a copy)
     */
    protected List<ColumnEncoderComposite> getEncoders() {
        final List<ColumnEncoderComposite> encoders = _columnEncoders;
        return encoders;
    }

    /**
     * Builds the fused build+apply dependency DAG that encode() uses for k &gt; 1.
     * <p>
     * Task layout:
     * <ul>
     * <li>Index 0: InitOutputMatrixTask.</li>
     * <li>Index 1: AllocMetaTask.</li>
     * <li>Per column composite: its build tasks, one ApplyTasksWrapperTask and one
     * ColumnMetaDataTask.</li>
     * <li>Finally, when any DC or BOW encoder exists: an UpdateOutputColTask over the apply
     * wrappers collected in applyTAgg.</li>
     * </ul>
     * <p>
     * depMap maps a task-index range (key) to the index range it depends on (value). The
     * ranges appear to be half-open [from, to). The negative range {-2, -1} presumably denotes
     * the trailing UpdateOutputColTask — TODO confirm against
     * DependencyThreadPool.createDependencyList. Integer[] keys hash by identity, so every put
     * adds a new entry.
     *
     * @param in   input block
     * @param out  output matrix (initialized by InitOutputMatrixTask)
     * @param pool pool handed to each ApplyTasksWrapperTask
     * @return the tasks with dependencies wired
     */
    private List<DependencyTask<?>> getEncodeTasks(CacheBlock<?> in, MatrixBlock out, DependencyThreadPool pool) {
        ArrayList tasks = new ArrayList();
        ArrayList<ApplyTasksWrapperTask> applyTAgg = null;
        HashMap<Integer[], Integer[]> depMap = new HashMap<Integer[], Integer[]>();
        boolean hasDC = !this.getColumnEncoders(ColumnEncoderDummycode.class).isEmpty();
        boolean hasBOW = !this.getColumnEncoders(ColumnEncoderBagOfWords.class).isEmpty();
        boolean applyOffsetDep = false;
        boolean independentUpdateDC = false;
        // Fresh meta frame, passed to AllocMetaTask and to every ColumnMetaDataTask.
        this._meta = new FrameBlock(in.getNumColumns(), Types.ValueType.STRING);
        tasks.add(DependencyThreadPool.createDependencyTask(new InitOutputMatrixTask(this, in, out)));
        tasks.add(DependencyThreadPool.createDependencyTask(new AllocMetaTask(this, this._meta)));
        for (ColumnEncoderComposite e : this._columnEncoders) {
            List<DependencyTask<?>> buildTasks = e.getBuildTasks(in);
            tasks.addAll(buildTasks);
            // From here on, tasks.size() is the index the apply wrapper of this composite will get,
            // and tasks.size() + 1 the index of its ColumnMetaDataTask.
            boolean compositeHasDC = e.hasEncoder(ColumnEncoderDummycode.class);
            boolean compositeHasBOW = e.hasEncoder(ColumnEncoderBagOfWords.class);
            if (!buildTasks.isEmpty()) {
                // The DC update task is "independent" if the second-to-last build task does not depend on the last one.
                // NOTE(review): this flag is never reset per composite. Once set, later composites also take the
                // size()-2 branch, which for a composite with a single build task does not point at its own
                // build task — confirm this is intended.
                if (compositeHasDC && buildTasks.size() > 1 && !buildTasks.get(buildTasks.size() - 2).hasDependency(buildTasks.get(buildTasks.size() - 1))) {
                    independentUpdateDC = true;
                }
                if (independentUpdateDC) {
                    // apply wrapper and meta task wait on the second-to-last task added so far
                    depMap.put(new Integer[]{tasks.size(), tasks.size() + 1}, new Integer[]{tasks.size() - 2, tasks.size() - 1});
                    depMap.put(new Integer[]{tasks.size() + 1, tasks.size() + 2}, new Integer[]{tasks.size() - 2, tasks.size() - 1});
                } else {
                    // apply wrapper and meta task wait on the last build task
                    depMap.put(new Integer[]{tasks.size(), tasks.size() + 1}, new Integer[]{tasks.size() - 1, tasks.size()});
                    depMap.put(new Integer[]{tasks.size() + 1, tasks.size() + 2}, new Integer[]{tasks.size() - 1, tasks.size()});
                }
                // AllocMetaTask (index 1) waits on this composite's build
                if (compositeHasDC && buildTasks.size() > 1) {
                    depMap.put(new Integer[]{1, 2}, new Integer[]{tasks.size() - 2, tasks.size() - 1});
                } else {
                    depMap.put(new Integer[]{1, 2}, new Integer[]{tasks.size() - 1, tasks.size()});
                }
            }
            // meta task waits on AllocMetaTask; apply wrapper waits on InitOutputMatrixTask
            depMap.put(new Integer[]{tasks.size() + 1, tasks.size() + 2}, new Integer[]{1, 2});
            depMap.put(new Integer[]{tasks.size(), tasks.size() + 1}, new Integer[]{0, 1});
            ApplyTasksWrapperTask applyTaskWrapper = new ApplyTasksWrapperTask(e, in, out, pool);
            if (compositeHasDC || compositeHasBOW) {
                // InitOutputMatrixTask waits on this composite's last build task (presumably because
                // the output width depends on the DC/BOW domain sizes)
                depMap.put(new Integer[]{0, 1}, new Integer[]{tasks.size() - 1, tasks.size()});
            }
            if (compositeHasDC || compositeHasBOW) {
                depMap.put(new Integer[]{-2, -1}, new Integer[]{tasks.size() - 1, tasks.size()});
                buildTasks.forEach(t -> t.setPriority(5));
                applyOffsetDep = true;
            }
            // applyOffsetDep is sticky: this and all later apply wrappers wait on {-2, -1} and get their
            // offset from UpdateOutputColTask; earlier wrappers get a fixed offset of 0.
            if ((hasDC || hasBOW) && applyOffsetDep) {
                depMap.put(new Integer[]{tasks.size(), tasks.size() + 1}, new Integer[]{-2, -1});
                applyTAgg = applyTAgg == null ? new ArrayList<ApplyTasksWrapperTask>() : applyTAgg;
                applyTAgg.add(applyTaskWrapper);
            } else {
                applyTaskWrapper.setOffset(0);
            }
            tasks.add(applyTaskWrapper);
            tasks.add(DependencyThreadPool.createDependencyTask(new ColumnMetaDataTask<ColumnEncoderComposite>(e, this._meta)));
        }
        if (hasDC || hasBOW) {
            tasks.add(DependencyThreadPool.createDependencyTask(new UpdateOutputColTask(this, applyTAgg)));
        }
        // Translate the range map into a per-task dependency list, then wrap the tasks.
        ArrayList<Object> deps = new ArrayList<Object>(Collections.nCopies(tasks.size(), null));
        DependencyThreadPool.createDependencyList(tasks, depMap, deps);
        return DependencyThreadPool.createDependencyTasks(tasks, deps);
    }

    /** Builds all encoders on {@code in} with a single thread. */
    @Override
    public void build(CacheBlock<?> in) {
        final int singleThread = 1;
        build(in, singleThread);
    }

    /** Builds all encoders on {@code in} with {@code k} threads and no equi-height bin maxima. */
    public void build(CacheBlock<?> in, int k) {
        final Map<Integer, double[]> noBinMaxs = null;
        build(in, k, noBinMaxs);
    }

    /**
     * Builds all column encoders, and afterwards any legacy encoders, on {@code in}.
     * <p>
     * Row partitions are derived first if that has not happened yet. With k &gt; 1 the build
     * runs as tasks via buildMT. Note that {@code equiHeightBinMaxs} is only forwarded in
     * the single-threaded path, where each composite's DC encoders are also updated right
     * after its build.
     *
     * @param in                input block; must be a FrameBlock when legacy encoders exist
     * @param k                 degree of parallelism
     * @param equiHeightBinMaxs optional per-column bin maxima (may be null)
     * @throws DMLRuntimeException if legacy encoders are present and {@code in} is not a FrameBlock
     */
    public void build(CacheBlock<?> in, int k, Map<Integer, double[]> equiHeightBinMaxs) {
        if (hasLegacyEncoder() && !(in instanceof FrameBlock)) {
            throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
        }
        if (!_partitionDone) {
            deriveNumRowPartitions(in, k);
        }
        if (k <= 1) {
            for (ColumnEncoderComposite composite : _columnEncoders) {
                composite.build(in, equiHeightBinMaxs);
                composite.updateAllDCEncoders();
            }
        } else {
            buildMT(in, k);
        }
        if (hasLegacyEncoder()) {
            legacyBuild((FrameBlock) in);
        }
    }

    /**
     * Collects the build tasks of every column encoder, in column order.
     *
     * @param in input block
     * @return all build tasks (a new, mutable list)
     */
    private List<DependencyTask<?>> getBuildTasks(CacheBlock<?> in) {
        // Typed list instead of a raw ArrayList, so no unchecked conversion is needed on return.
        List<DependencyTask<?>> tasks = new ArrayList<>();
        for (ColumnEncoderComposite columnEncoder : this._columnEncoders) {
            tasks.addAll(columnEncoder.getBuildTasks(in));
        }
        return tasks;
    }

    /**
     * Runs all build tasks on a DependencyThreadPool with {@code k} threads and waits for them.
     * The pool is always shut down.
     *
     * @param in input block
     * @param k  number of threads
     * @throws DMLRuntimeException if a task fails or the wait is interrupted; in the latter case
     *         the thread's interrupt status is restored
     */
    private void buildMT(CacheBlock<?> in, int k) {
        DependencyThreadPool pool = new DependencyThreadPool(k);
        try {
            pool.submitAllAndWait(this.getBuildTasks(in));
        }
        catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Builds the legacy omit encoder and then the legacy MV-impute encoder on {@code in},
     * skipping whichever is not configured.
     *
     * @param in input frame
     */
    public void legacyBuild(FrameBlock in) {
        EncoderOmit omit = _legacyOmit;
        if (omit != null) {
            omit.build(in);
        }
        EncoderMVImpute impute = _legacyMVImpute;
        if (impute != null) {
            impute.build(in);
        }
    }

    /** Applies the built encoders to {@code in} with a single thread. */
    public MatrixBlock apply(CacheBlock<?> in) {
        final int singleThread = 1;
        return apply(in, singleThread);
    }

    /**
     * Applies the built encoders to {@code in}, allocating a fresh output matrix.
     * <p>
     * The non-zero estimate per row is:
     * <ul>
     * <li>the full output width when word-embedding or UDF encoders exist;</li>
     * <li>otherwise one entry per non-BOW input column plus {@code encm.nnzBOW}.</li>
     * </ul>
     * The output is sparse only if MatrixBlock considers that format cheaper and no UDF
     * encoder is present.
     *
     * @param in input block
     * @param k  degree of parallelism
     * @return the encoded matrix
     */
    public MatrixBlock apply(CacheBlock<?> in, int k) {
        final EncoderMeta meta = MultiColumnEncoder.getEncMeta(_columnEncoders, true, k, in);
        updateAllDCEncoders();
        final int nOutCols = getNumOutCols();
        final int nRows = in.getNumRows();
        final long nnzPerRow;
        if (meta.hasWE || meta.hasUDF) {
            nnzPerRow = (long) nOutCols;
        } else {
            nnzPerRow = (long) (in.getNumColumns() - meta.numBOWEnc) + meta.nnzBOW;
        }
        final long estNnz = (long) nRows * nnzPerRow;
        final boolean sparse = MatrixBlock.evalSparseFormatInMemory(nRows, nOutCols, estNnz) && !meta.hasUDF;
        final MatrixBlock result = new MatrixBlock(nRows, nOutCols, sparse, estNnz);
        return apply(in, result, 0, k, meta, estNnz);
    }

    /** Calls updateAllDCEncoders on every column composite, in column order. */
    public void updateAllDCEncoders() {
        _columnEncoders.forEach(ColumnEncoderComposite::updateAllDCEncoders);
    }

    /**
     * Unsupported variant from {@link Encoder}: a multi-column apply needs encoder metadata.
     * Always throws; use apply(in, out, outputCol, k, encm, nnz) instead.
     *
     * @throws DMLRuntimeException always
     */
    @Override
    public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol) {
        final String msg = "MultiColumnEncoder apply without Encoder Characteristics should not be called directly";
        throw new DMLRuntimeException(msg);
    }

    /**
     * Applies all column encoders to {@code in}, writing into {@code out}.
     * <p>
     * The steps are:
     * <ol>
     * <li>Validate the input.</li>
     * <li>Pre-allocate {@code out}.</li>
     * <li>Run the per-column apply, multi-threaded for k &gt; 1, otherwise sequentially.</li>
     * <li>Post-process the output.</li>
     * <li>Apply the legacy omit and MV-impute encoders when configured.</li>
     * </ol>
     *
     * @param in        input block; must be a FrameBlock when legacy encoders are present
     * @param out       output matrix to fill
     * @param outputCol output column at which the first encoded column is written
     * @param k         degree of parallelism
     * @param encm      encoder metadata used to allocate the output
     * @param nnz       estimated non-zeros; a negative value makes the pre-processing fall back
     *                  to rows * input columns
     * @return the encoded matrix; may be a different object than {@code out} once legacy
     *         encoders have been applied
     * @throws DMLRuntimeException on legacy encoders with a non-frame input, when the number of
     *         input columns differs from the number of column encoders, or for an input without rows
     */
    public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int k, EncoderMeta encm, long nnz) {
        if (this.hasLegacyEncoder() && !(in instanceof FrameBlock)) {
            throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
        }
        int numEncoders = this.getEncoders().size();
        if (in.getNumColumns() != numEncoders) {
            // Report the encoder count, not the toString() of the whole encoder list.
            throw new DMLRuntimeException("Not every column in has a CompositeEncoder. Please make sure every column has a encoder or slice the input accordingly: num encoders:  " + numEncoders + " vs columns " + in.getNumColumns());
        }
        if (in.getNumRows() == 0) {
            throw new DMLRuntimeException("Invalid input with wrong number or rows");
        }
        ArrayList<int[]> nnzOffsets = MultiColumnEncoder.outputMatrixPreProcessing(out, in, encm, nnz, k);
        if (k > 1) {
            if (!this._partitionDone) {
                this.deriveNumRowPartitions(in, k);
            }
            this.applyMT(in, out, outputCol, k, nnzOffsets);
        } else {
            int offset = outputCol;
            int i = 0;
            // The first encoder gets no BOW nnz offsets; encoder i + 1 gets nnzOffsets.get(i).
            int[] nnzOffset = null;
            for (ColumnEncoderComposite columnEncoder : this._columnEncoders) {
                columnEncoder.sparseRowPointerOffset = nnzOffset;
                columnEncoder.apply(in, out, columnEncoder._colID - 1 + offset);
                offset = this.getOutputColOffset(offset, columnEncoder);
                nnzOffset = nnzOffsets != null ? nnzOffsets.get(i++) : null;
            }
        }
        this.outputMatrixPostProcessing(out, k);
        if (this._legacyOmit != null) {
            out = this._legacyOmit.apply((FrameBlock)in, out);
        }
        if (this._legacyMVImpute != null) {
            out = this._legacyMVImpute.apply((FrameBlock)in, out);
        }
        return out;
    }

    /**
     * Collects the apply tasks of every column encoder, in column order.
     * <p>
     * Each encoder writes starting at {@code _colID - 1} plus a running offset. After each
     * encoder, the offset grows by the extra columns of its DC, word-embedding and BOW
     * encoders. The first encoder receives no nnz offsets; encoder {@code i + 1} receives
     * {@code nnzOffsets.get(i)}.
     *
     * @param in         input block
     * @param out        pre-allocated output matrix
     * @param outputCol  base output column
     * @param nnzOffsets per-encoder BOW row offsets, or null
     * @return all apply tasks
     */
    private List<DependencyTask<?>> getApplyTasks(CacheBlock<?> in, MatrixBlock out, int outputCol, ArrayList<int[]> nnzOffsets) {
        // Typed list instead of a raw ArrayList.
        List<DependencyTask<?>> tasks = new ArrayList<>();
        int offset = outputCol;
        int i = 0;
        int[] currentNnzOffsets = null;
        for (ColumnEncoderComposite e : this._columnEncoders) {
            tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + offset, currentNnzOffsets));
            currentNnzOffsets = nnzOffsets != null ? nnzOffsets.get(i++) : null;
            offset = this.getOutputColOffset(offset, e);
        }
        return tasks;
    }

    /**
     * Advances an output-column offset past the extra columns produced by {@code e}.
     * <p>
     * For each dummy-code, word-embedding and bag-of-words encoder present in {@code e}, its
     * domain size minus one is added. Those encoders expand one input column into
     * domain-size output columns.
     *
     * @param offset current offset
     * @param e      composite encoder of the column just placed
     * @return the new offset
     */
    private int getOutputColOffset(int offset, ColumnEncoderComposite e) {
        int extraCols = 0;
        if (e.hasEncoder(ColumnEncoderDummycode.class)) {
            extraCols += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
        }
        if (e.hasEncoder(ColumnEncoderWordEmbedding.class)) {
            extraCols += e.getEncoder(ColumnEncoderWordEmbedding.class).getDomainSize() - 1;
        }
        if (e.hasEncoder(ColumnEncoderBagOfWords.class)) {
            extraCols += e.getEncoder(ColumnEncoderBagOfWords.class).getDomainSize() - 1;
        }
        return offset + extraCols;
    }

    /**
     * Runs the apply tasks on a DependencyThreadPool with {@code k} threads.
     * <p>
     * With APPLY_ENCODER_SEPARATE_STAGES, each column's tasks are submitted and awaited one
     * column at a time; otherwise all tasks are submitted together. The pool is always shut down.
     *
     * @param in         input block
     * @param out        pre-allocated output matrix
     * @param outputCol  base output column
     * @param k          number of threads
     * @param nnzOffsets per-encoder BOW row offsets, or null
     * @throws DMLRuntimeException if a task fails or the wait is interrupted; in the latter case
     *         the thread's interrupt status is restored
     */
    private void applyMT(CacheBlock<?> in, MatrixBlock out, int outputCol, int k, ArrayList<int[]> nnzOffsets) {
        DependencyThreadPool pool = new DependencyThreadPool(k);
        try {
            if (APPLY_ENCODER_SEPARATE_STAGES) {
                int offset = outputCol;
                int i = 0;
                int[] currentNnzOffsets = null;
                for (ColumnEncoderComposite e : this._columnEncoders) {
                    pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset, currentNnzOffsets));
                    offset = this.getOutputColOffset(offset, e);
                    currentNnzOffsets = nnzOffsets != null ? nnzOffsets.get(i) : null;
                    ++i;
                }
            } else {
                pool.submitAllAndWait(this.getApplyTasks(in, out, outputCol, nnzOffsets));
            }
        }
        catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException(e);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Decides how many row partitions each column encoder uses for build (numBlocks[0]) and
     * apply (numBlocks[1]), and pushes the result into the encoders via setNumPartitions.
     * <p>
     * For k == 1 every encoder gets 1/1. Otherwise each count comes from the first source
     * that is set:
     * <ol>
     * <li>ColumnEncoder.BUILD/APPLY_ROW_BLOCKS_PER_COLUMN (&gt; 0);</li>
     * <li>ConfigurationManager.getParallelBuild/ApplyBlocks() (&gt; 0);</li>
     * <li>a heuristic based on OptimizerUtils.getTransformNumThreads().</li>
     * </ol>
     * The counts are then capped so a partition has about at least 16000 rows; for
     * bag-of-words (BOW) encoders the cap is 1000 rows. The build counts of recode and BOW
     * encoders may be reduced further to fit the local memory budget. Counts still 0 at the
     * end become 1.
     *
     * @param in input block
     * @param k  degree of parallelism requested by the caller
     */
    private void deriveNumRowPartitions(CacheBlock<?> in, int k) {
        int[] numBlocks = new int[2];
        if (k == 1) {
            numBlocks[0] = 1;
            numBlocks[1] = 1;
            this._columnEncoders.forEach(e -> e.setNumPartitions(1, 1));
            this._partitionDone = true;
            return;
        }
        // 1) explicit static overrides
        if (ColumnEncoder.BUILD_ROW_BLOCKS_PER_COLUMN > 0) {
            numBlocks[0] = ColumnEncoder.BUILD_ROW_BLOCKS_PER_COLUMN;
        }
        if (ColumnEncoder.APPLY_ROW_BLOCKS_PER_COLUMN > 0) {
            numBlocks[1] = ColumnEncoder.APPLY_ROW_BLOCKS_PER_COLUMN;
        }
        // 2) configuration values
        if (numBlocks[0] == 0 && ConfigurationManager.getParallelBuildBlocks() > 0) {
            numBlocks[0] = ConfigurationManager.getParallelBuildBlocks();
        }
        if (numBlocks[1] == 0 && ConfigurationManager.getParallelApplyBlocks() > 0) {
            numBlocks[1] = ConfigurationManager.getParallelApplyBlocks();
        }
        int nRow = in.getNumRows();
        // NOTE(review): the heuristic uses the global transform thread count, not the k argument.
        int nThread = OptimizerUtils.getTransformNumThreads();
        int minNumRows = 16000;
        ArrayList<ColumnEncoderComposite> recodeEncoders = new ArrayList<ColumnEncoderComposite>();
        ArrayList<ColumnEncoderComposite> bowEncoders = new ArrayList<ColumnEncoderComposite>();
        // Count composites with a build phase and collect the recode and BOW ones among them.
        int nBuild = 0;
        for (ColumnEncoderComposite e2 : this._columnEncoders) {
            if (!e2.hasBuild()) continue;
            ++nBuild;
            if (e2.hasEncoder(ColumnEncoderRecode.class)) {
                recodeEncoders.add(e2);
            }
            if (!e2.hasEncoder(ColumnEncoderBagOfWords.class)) continue;
            bowEncoders.add(e2);
        }
        int nApply = in.getNumColumns();
        // 3) heuristic: spread nThread build tasks / 2*nThread apply tasks over the columns
        if (numBlocks[0] == 0 && nBuild > 0 && nBuild < nThread) {
            numBlocks[0] = Math.round((float)nThread / (float)nBuild);
        }
        if (numBlocks[1] == 0 && nApply > 0 && nApply < nThread * 2) {
            numBlocks[1] = Math.round((float)nThread * 2.0f / (float)nApply);
        }
        // BOW counts are taken before the 16000-row cap and get their own, smaller cap below.
        int bowNumBuildBlks = numBlocks[0];
        int bowNumApplyBlks = numBlocks[1];
        int optimalPartitions = Math.max(1, nRow / minNumRows);
        numBlocks[0] = Math.min(numBlocks[0], optimalPartitions);
        numBlocks[1] = Math.min(numBlocks[1], optimalPartitions);
        int rcdNumBuildBlks = numBlocks[0];
        optimalPartitions = Math.max(1, nRow / (minNumRows / 16));
        bowNumBuildBlks = Math.min(bowNumBuildBlks, optimalPartitions);
        bowNumApplyBlks = Math.min(bowNumApplyBlks, optimalPartitions);
        // Reduce build partitions of recode / BOW encoders until the estimated maps fit in memory.
        if (numBlocks[0] > 1 && !recodeEncoders.isEmpty() && bowEncoders.isEmpty()) {
            rcdNumBuildBlks = this.getNumBuildBlksMemorySafe(in, recodeEncoders, rcdNumBuildBlks, false);
        } else if (bowNumBuildBlks > 1 && recodeEncoders.isEmpty() && !bowEncoders.isEmpty()) {
            bowNumBuildBlks = this.getNumBuildBlksMemorySafe(in, bowEncoders, bowNumBuildBlks, true);
        } else if (bowNumBuildBlks > 1 || rcdNumBuildBlks > 1) {
            ArrayList<List<ColumnEncoderComposite>> encoders = new ArrayList<List<ColumnEncoderComposite>>();
            encoders.add(recodeEncoders);
            encoders.add(bowEncoders);
            int[] bldBlks = new int[]{rcdNumBuildBlks, bowNumBuildBlks};
            this.getNumBuildBlksMixedEncMemorySafe(in, encoders, bldBlks);
            rcdNumBuildBlks = bldBlks[0];
            bowNumBuildBlks = bldBlks[1];
        }
        // Any count nobody set becomes 1.
        for (int i = 0; i < 2; ++i) {
            if (numBlocks[i] != 0) continue;
            numBlocks[i] = 1;
        }
        this._partitionDone = true;
        this._columnEncoders.forEach(e -> e.setNumPartitions(numBlocks[0], numBlocks[1]));
        // Recode / BOW encoders override the generic counts with their memory-adjusted ones.
        if (rcdNumBuildBlks > 0 && rcdNumBuildBlks != numBlocks[0]) {
            int rcdNumBlocks = rcdNumBuildBlks;
            recodeEncoders.forEach(e -> e.setNumPartitions(rcdNumBlocks, numBlocks[1]));
        }
        // NOTE(review): bowNumApplyBlks is copied before the 0 -> 1 fix above, so it can still be 0
        // here when no apply count was configured — confirm setNumPartitions tolerates 0.
        if (bowNumBuildBlks > 0) {
            int bowNumBlocks = bowNumBuildBlks;
            int bowApplyBlks = bowNumApplyBlks;
            bowEncoders.forEach(e -> e.setNumPartitions(bowNumBlocks, bowApplyBlks));
        }
    }

    /**
     * Reduces the number of build row blocks until the estimated total map overhead of
     * {@code encoders} fits the local memory budget, or only one block remains.
     * <p>
     * The budget is the local memory budget minus the input's in-memory size. With BOW
     * encoders, one int array of length numRows per encoder is also reserved.
     *
     * @param in         input block
     * @param encoders   encoders whose build maps are estimated (estimateMapSize is run first)
     * @param numBldBlks initial number of build blocks
     * @param hasBOW     whether {@code encoders} are bag-of-words encoders
     * @return the possibly reduced number of build blocks
     */
    private int getNumBuildBlksMemorySafe(CacheBlock<?> in, List<ColumnEncoderComposite> encoders, int numBldBlks, boolean hasBOW) {
        estimateMapSize(in, encoders);
        long budget = (long) (OptimizerUtils.getLocalMemBudget() - (double) in.getInMemorySize());
        if (hasBOW) {
            budget -= (long) encoders.size() * (long) MemoryEstimates.intArrayCost(in.getNumRows());
        }
        int blocks = numBldBlks;
        long overhead = getTotalMemOverhead(in, blocks, encoders);
        while (blocks > 1 && overhead > budget) {
            blocks -= 1;
            overhead = getTotalMemOverhead(in, blocks, encoders);
        }
        return blocks;
    }

    /**
     * Reduces build block counts for a mix of encoder types until the summed estimated map
     * overhead fits the local memory budget.
     * <p>
     * {@code encs.get(1)} is treated as the bag-of-words encoders: one int array of length
     * numRows per such encoder is reserved from the budget. Counts are decremented round-robin
     * across encoder types, starting with index 1 when its count exceeds one. The loop stops
     * once the estimate fits, or when every type is already at one block.
     *
     * @param in   input block
     * @param encs encoder lists per type (index 0 recode, index 1 BOW, as passed by the caller)
     * @param blks in/out: build block count per type, updated in place
     */
    private void getNumBuildBlksMixedEncMemorySafe(CacheBlock<?> in, List<List<ColumnEncoderComposite>> encs, int[] blks) {
        long memBudget = (long)(OptimizerUtils.getLocalMemBudget() - (double)in.getInMemorySize());
        memBudget -= (long)encs.get(1).size() * (long)MemoryEstimates.intArrayCost(in.getNumRows());
        int numOfEncTypes = encs.size();
        long[] totMemOverhead = new long[numOfEncTypes];
        for (int i = 0; i < numOfEncTypes; ++i) {
            this.estimateMapSize(in, encs.get(i));
            totMemOverhead[i] = this.getTotalMemOverhead(in, blks[i], encs.get(i));
        }
        int next = blks[1] > 1 ? 1 : 0;
        int skipped = 0;
        while (skipped < numOfEncTypes && Arrays.stream(totMemOverhead).sum() > memBudget) {
            if (blks[next] > 1) {
                blks[next]--;
                totMemOverhead[next] = this.getTotalMemOverhead(in, blks[next], encs.get(next));
                skipped = 0;
            } else {
                ++skipped;
            }
            // Advance on every iteration, including skips. Otherwise a type already at one block
            // would be re-examined until the loop ends, never reducing the other types.
            next = (next + 1) % numOfEncTypes;
        }
    }

    /**
     * Runs computeMapSizeEstimate on every encoder in {@code encList}.
     * <p>
     * The estimate uses a sorted 10% row sample of {@code in}, seeded from System.nanoTime().
     * The work is a parallel stream submitted to a CommonThreadPool with
     * getTransformNumThreads() threads and awaited; the pool is always shut down. When
     * statistics are enabled, the elapsed time is logged and recorded.
     *
     * @param in      input block
     * @param encList encoders to estimate
     * @throws DMLRuntimeException if the estimation fails or is interrupted; in the latter case
     *         the thread's interrupt status is restored
     */
    private void estimateMapSize(CacheBlock<?> in, List<ColumnEncoderComposite> encList) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        int k = OptimizerUtils.getTransformNumThreads();
        int[] sampleInds = MultiColumnEncoder.getSampleIndices(in, (int)(0.1 * (double)in.getNumRows()), (int)System.nanoTime(), 1);
        ExecutorService pool = CommonThreadPool.get(k);
        try {
            // Typed parallel stream; the former raw (Stream) cast left the lambda parameter as Object.
            pool.submit(() -> encList.parallelStream().forEach(e -> e.computeMapSizeEstimate(in, sampleInds))).get();
        }
        catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(ex);
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        finally {
            pool.shutdown();
        }
        if (DMLScript.STATISTICS) {
            long elapsed = System.nanoTime() - t0;
            LOG.debug("Elapsed time for encoder map size estimation: " + (double)elapsed / 1000000.0 + " ms");
            TransformStatistics.incMapSizeEstimationTime(elapsed);
        }
    }

    /**
     * Draws a sorted row sample of {@code in} via ComEstSample.getSortedSample.
     *
     * @param in         block to sample rows from
     * @param sampleSize number of rows to sample
     * @param seed       random seed
     * @param k          parallelism forwarded to the sampler
     * @return sorted sampled row indices
     */
    private static int[] getSampleIndices(CacheBlock<?> in, int sampleSize, int seed, int k) {
        final int numRows = in.getNumRows();
        return ComEstSample.getSortedSample(numRows, sampleSize, seed, k);
    }

    /**
     * Estimates the memory held by the build maps of {@code encoders} for a given number of
     * build partitions.
     * <p>
     * With a single partition, this is the sum of the encoders' estimated meta sizes. With
     * several partitions, each partition holds up to min(rows per partition, estimated
     * distincts) entries of the encoder's average entry size.
     *
     * @param in          input block
     * @param nBuildpart  number of build partitions
     * @param encoders    encoders to estimate
     * @return estimated bytes
     * @throws DMLRuntimeException if an encoder reports an average entry size of zero
     */
    private long getTotalMemOverhead(CacheBlock<?> in, int nBuildpart, List<ColumnEncoderComposite> encoders) {
        if (nBuildpart == 1) {
            return encoders.stream().mapToLong(ColumnEncoder::getEstMetaSize).sum();
        }
        long total = 0L;
        for (ColumnEncoderComposite enc : encoders) {
            int rowsPerPart = in.getNumRows() / nBuildpart;
            int distinctPerPart = Math.min(rowsPerPart, enc.getEstNumDistincts());
            long entrySize = enc.getAvgEntrySize();
            if (entrySize == 0L) {
                throw new DMLRuntimeException("Error while estimating entry size of encoder map");
            }
            total += (long) distinctPerPart * entrySize * (long) nBuildpart;
        }
        return total;
    }

    /**
     * Allocates the storage of {@code output} before the encoders write into it.
     * <p>
     * Sparse output: only CSR and MCSR default sparse block types are accepted. A CSR block is
     * installed. Without bag-of-words (BOW) encoders, its row pointers reserve one entry per
     * input column per row. With BOW encoders, the per-row counts come from
     * {@code encm.nnzPerRowBOW} or are computed by getNnzPerRowFromBOWEncoders.
     * <p>
     * Dense output: a dense block is allocated. When {@code encm.hasWE}, it is configured as a
     * DenseBlockFP64DEDUP with the word-embedding distinct set and embedding size.
     *
     * @param output the matrix to allocate
     * @param input  the block being encoded
     * @param encm   encoder metadata
     * @param nnz    estimated non-zeros; if negative, rows * input columns is used
     * @param k      parallelism for the BOW nnz computation
     * @return per-column-encoder BOW row offsets when computed here, otherwise null
     * @throws RuntimeException if the output is sparse and the default sparse block type is
     *         neither CSR nor MCSR
     */
    private static ArrayList<int[]> outputMatrixPreProcessing(MatrixBlock output, CacheBlock<?> input, EncoderMeta encm, long nnz, int k) {
        long t0;
        long l = t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (nnz < 0L) {
            nnz = (long)output.getNumRows() * (long)input.getNumColumns();
        }
        ArrayList<int[]> bowNnzRowOffsets = null;
        if (output.isInSparseFormat()) {
            if (MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.CSR && MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.MCSR) {
                throw new RuntimeException("Transformapply is only supported for MCSR and CSR output matrix");
            }
            // MCSR pre-allocation is disabled: mcsr is always false, so only the CSR branch runs.
            boolean mcsr = false;
            if (mcsr) {
                output.allocateBlock();
                SparseBlock block = output.getSparseBlock();
                if (encm.hasDC && OptimizerUtils.getTransformNumThreads() > 1) {
                    IntStream.range(0, output.getNumRows()).parallel().forEach(r -> {
                        block.allocate(r, input.getNumColumns());
                        ((SparseRowVector)block.get(r)).setSize(input.getNumColumns());
                    });
                } else {
                    for (int r2 = 0; r2 < output.getNumRows(); ++r2) {
                        block.allocate(r2, input.getNumColumns());
                        ((SparseRowVector)block.get(r2)).setSize(input.getNumColumns());
                    }
                }
            } else {
                int static_offset;
                // NOTE(review): narrowing cast, assumes nnz fits in an int. In the first two branches
                // below, the CSR arrays are sized from this estimate rather than from rptr's final
                // value — confirm the two always agree.
                int nnzInt = (int)nnz;
                int[] rptr = new int[output.getNumRows() + 1];
                if (encm.numBOWEnc <= 0) {
                    // no BOW: every row holds one entry per input column
                    for (int i = 0; i < rptr.length - 1; ++i) {
                        rptr[i + 1] = rptr[i] + input.getNumColumns();
                    }
                } else if (encm.nnzPerRowBOW != null) {
                    // known per-row BOW counts (presumably summed over all BOW encoders) plus one
                    // entry per non-BOW column
                    static_offset = input.getNumColumns() - encm.numBOWEnc;
                    for (int i = 0; i < rptr.length - 1; ++i) {
                        int nnzPerRow = static_offset + encm.nnzPerRowBOW[i];
                        rptr[i + 1] = rptr[i] + nnzPerRow;
                    }
                } else {
                    // compute BOW counts now; the last offsets array is the cumulative BOW adjustment
                    bowNnzRowOffsets = MultiColumnEncoder.getNnzPerRowFromBOWEncoders(input, encm, k);
                    static_offset = input.getNumColumns() - 1;
                    int[] aggOffsets = bowNnzRowOffsets.get(bowNnzRowOffsets.size() - 1);
                    for (int i = 0; i < rptr.length - 1; ++i) {
                        rptr[i + 1] = rptr[i] + static_offset + aggOffsets[i];
                    }
                    nnzInt = rptr[rptr.length - 1];
                }
                SparseBlockCSR csrblock = new SparseBlockCSR(rptr, new int[nnzInt], new double[nnzInt], nnzInt);
                output.setSparseBlock(csrblock);
            }
        } else {
            output.allocateDenseBlock(true, encm.hasWE);
            if (encm.hasWE) {
                DenseBlockFP64DEDUP dedup = (DenseBlockFP64DEDUP)output.getDenseBlock();
                dedup.setDistinct(encm.distinctWE);
                dedup.setEmbeddingSize(encm.sizeWE);
            }
        }
        if (DMLScript.STATISTICS) {
            LOG.debug((Object)("Elapsed time for allocation: " + ((double)System.nanoTime() - (double)t0) / 1000000.0 + " ms"));
            TransformStatistics.incOutMatrixPreProcessingTime(System.nanoTime() - t0);
        }
        return bowNnzRowOffsets;
    }

    /*
     * WARNING - void declaration
     */
    private static ArrayList<int[]> getNnzPerRowFromBOWEncoders(CacheBlock<?> input, EncoderMeta encm, int k) {
        ArrayList<int[]> bowNnzRowOffsets;
        int min_block_size = 1000;
        int num_blocks = input.getNumRows() / min_block_size;
        int num_blks1 = Math.min((k + encm.numBOWEnc - 1) / encm.numBOWEnc, Math.max(num_blocks, 1));
        int blk_len1 = (input.getNumRows() + num_blks1 - 1) / num_blks1;
        int num_blks2 = Math.min(k, Math.max(num_blocks, 1));
        int blk_len2 = (input.getNumRows() + num_blks2 - 1) / num_blks1;
        ExecutorService pool = CommonThreadPool.get(k);
        ArrayList<int[]> bowNnzRowOffsetsFinal = new ArrayList<int[]>();
        try {
            void var14_19;
            encm.bowEncoders.forEach(e -> {
                e._nnzPerRow = new int[input.getNumRows()];
            });
            ArrayList list = new ArrayList();
            for (int i = 0; i < num_blks1; ++i) {
                int n = i * blk_len1;
                int n2 = Math.min((i + 1) * blk_len1, input.getNumRows());
                list.add(pool.submit(() -> ((Stream)encm.bowEncoders.stream().parallel()).forEach(e -> e.computeNnzPerRow(input, start, end))));
            }
            for (Future future : list) {
                future.get();
            }
            list.clear();
            int[] previous = null;
            for (ColumnEncoderComposite columnEncoderComposite : encm.encs) {
                if (columnEncoderComposite.hasEncoder(ColumnEncoderBagOfWords.class)) {
                    previous = previous == null ? columnEncoderComposite.getEncoder(ColumnEncoderBagOfWords.class)._nnzPerRow : new int[input.getNumRows()];
                }
                bowNnzRowOffsetsFinal.add(previous);
            }
            boolean bl = false;
            while (var14_19 < num_blks2) {
                void var15_23 = var14_19 * blk_len1;
                list.add(pool.submit(() -> MultiColumnEncoder.lambda$getNnzPerRowFromBOWEncoders$11((int)var15_23, blk_len2, input, encm, bowNnzRowOffsetsFinal)));
                ++var14_19;
            }
            for (Future future : list) {
                future.get();
            }
            bowNnzRowOffsets = bowNnzRowOffsetsFinal;
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        finally {
            pool.shutdown();
        }
        return bowNnzRowOffsets;
    }

    /**
     * Turns the per-row nnz counts of the bag-of-words (BOW) encoders into cumulative
     * row offsets for the rows {@code [start, min(start + blk_len, numRows))}.
     * <p>
     * The first BOW encoder's offset array is left untouched (it is the base). Every
     * later BOW encoder's array is set to the previous BOW encoder's offsets plus its
     * own count minus one. The running total is carried forward so that each array
     * reflects all preceding BOW encoders, not just the first one. Row ranges of
     * different calls are independent, so calls on disjoint ranges may run concurrently.
     *
     * @param start            first row of the block (inclusive)
     * @param blk_len          block length in rows
     * @param numRows          total number of rows (upper bound for the block end)
     * @param encs             composite encoders, in output column order
     * @param bowNnzRowOffsets one offset array per entry of {@code encs} (same indexing)
     */
    private static void aggregateNnzPerRow(int start, int blk_len, int numRows, List<ColumnEncoderComposite> encs, ArrayList<int[]> bowNnzRowOffsets) {
        int end = Math.min(start + blk_len, numRows);
        int pos = 0;
        int[] aggRowOffsets = null;
        for (ColumnEncoderComposite enc : encs) {
            int[] currentOffsets = bowNnzRowOffsets.get(pos++);
            if (!enc.hasEncoder(ColumnEncoderBagOfWords.class)) {
                continue;
            }
            if (aggRowOffsets != null) {
                ColumnEncoderBagOfWords bow = enc.getEncoder(ColumnEncoderBagOfWords.class);
                for (int i = start; i < end; ++i) {
                    currentOffsets[i] = aggRowOffsets[i] + bow._nnzPerRow[i] - 1;
                }
            }
            // advance the running total so the next BOW encoder builds on all previous ones
            aggRowOffsets = currentOffsets;
        }
    }

    /**
     * Finalizes an encoded output matrix.
     * <p>
     * Sparse outputs are compacted when any column encoder reports
     * {@code containsZeroOut()}. The non-zero count is recomputed in every case. The
     * elapsed time is added to the transform statistics when statistics are enabled.
     *
     * @param output encoded output matrix, modified in place
     * @param k      degree of parallelism; exactly 1 selects the single-threaded compaction
     */
    private void outputMatrixPostProcessing(MatrixBlock output, int k) {
        final long startNanos = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        if (output.isInSparseFormat() && this.containsZeroOut()) {
            if (k != 1) {
                this.outputMatrixPostProcessingParallel(output, k);
            } else {
                this.outputMatrixPostProcessingSingleThread(output);
            }
        }
        output.recomputeNonZeros(k);
        if (DMLScript.STATISTICS) {
            final long elapsedNanos = System.nanoTime() - startNanos;
            TransformStatistics.incOutMatrixPostProcessingTime(elapsedNanos);
        }
    }

    /**
     * Compacts the sparse block of {@code output} on the calling thread.
     * MCSR blocks are compacted row by row. Any other block is cast to CSR and
     * compacted as a whole.
     *
     * @param output sparse output matrix
     */
    private void outputMatrixPostProcessingSingleThread(MatrixBlock output) {
        final SparseBlock block = output.getSparseBlock();
        if (!(block instanceof SparseBlockMCSR)) {
            // NOTE(review): assumes every non-MCSR block reaching here is CSR
            ((SparseBlockCSR) block).compact();
            return;
        }
        for (int r = 0; r < output.getNumRows(); r++) {
            block.compact(r);
        }
    }

    /**
     * Reports whether any column encoder reports zero-out entries.
     *
     * @return true if at least one column encoder's {@code containsZeroOut()} is true
     */
    private boolean containsZeroOut() {
        return this._columnEncoders.stream().anyMatch(ColumnEncoder::containsZeroOut);
    }

    /**
     * Compacts the sparse block of {@code output} using a thread pool of size {@code k}.
     * MCSR rows are compacted by a parallel stream submitted to the pool. Any other
     * block is cast to CSR and compacted on the calling thread.
     *
     * @param output sparse output matrix
     * @param k      pool size
     * @throws DMLRuntimeException wrapping any failure of the compaction task
     */
    private void outputMatrixPostProcessingParallel(MatrixBlock output, int k) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final SparseBlock block = output.getSparseBlock();
            if (!(block instanceof SparseBlockMCSR)) {
                ((SparseBlockCSR) block).compact();
            } else {
                Future<?> compaction = pool.submit(() -> IntStream.range(0, output.getNumRows())
                    .parallel()
                    .forEach(r -> block.compact(r)));
                compaction.get();
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * Lets every column encoder allocate its part of the meta data frame.
     *
     * @param meta meta data frame to allocate into
     */
    @Override
    public void allocateMetaData(FrameBlock meta) {
        for (ColumnEncoderComposite enc : this._columnEncoders)
            enc.allocateMetaData(meta);
    }

    /**
     * Single-threaded variant of {@code getMetaData(FrameBlock, int)}.
     *
     * @param meta meta data frame to fill, or null to allocate a new one
     * @return the filled meta data frame
     */
    @Override
    public FrameBlock getMetaData(FrameBlock meta) {
        final int singleThread = 1;
        return this.getMetaData(meta, singleThread);
    }

    /**
     * Builds the meta data frame of all encoders.
     * <p>
     * If a meta frame was already cached in {@code _meta}, that frame is returned as is.
     * Otherwise {@code meta} (or a new STRING frame with one column per column encoder
     * if null) is allocated and filled by every column encoder. The column encoders
     * run in parallel when {@code k > 1}. The legacy omit and MV-impute encoders run
     * afterwards.
     *
     * @param meta meta data frame to fill, or null to allocate a new one
     * @param k    degree of parallelism
     * @return the cached or freshly filled meta data frame
     * @throws DMLRuntimeException if a parallel task fails or the wait is interrupted
     */
    public FrameBlock getMetaData(FrameBlock meta, int k) {
        final long t0 = System.nanoTime();
        if (this._meta != null) {
            return this._meta;
        }
        if (meta == null) {
            meta = new FrameBlock(this._columnEncoders.size(), Types.ValueType.STRING);
        }
        this.allocateMetaData(meta);
        if (k > 1) {
            ExecutorService pool = CommonThreadPool.get(k);
            try {
                List<ColumnMetaDataTask<ColumnEncoder>> tasks = new ArrayList<>(this._columnEncoders.size());
                for (ColumnEncoder columnEncoder : this._columnEncoders) {
                    tasks.add(new ColumnMetaDataTask<>(columnEncoder, meta));
                }
                // get() rethrows failures of the individual tasks
                for (Future<Object> task : pool.invokeAll(tasks)) {
                    task.get();
                }
            }
            catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(ex);
            }
            catch (Exception ex) {
                throw new DMLRuntimeException(ex);
            }
            finally {
                pool.shutdown();
            }
        } else {
            for (ColumnEncoder columnEncoder : this._columnEncoders) {
                columnEncoder.getMetaData(meta);
            }
        }
        if (this._legacyOmit != null) {
            this._legacyOmit.getMetaData(meta);
        }
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.getMetaData(meta);
        }
        // avoid building the message when debug logging is off
        if (LOG.isDebugEnabled()) {
            LOG.debug("Time spent getting metadata " + ((double) System.nanoTime() - (double) t0) / 1000000.0 + " ms");
        }
        return meta;
    }

    /**
     * Initializes all column encoders, and the legacy omit/MV-impute encoders if
     * present, from an existing meta data frame.
     *
     * @param meta meta data frame to read from
     */
    @Override
    public void initMetaData(FrameBlock meta) {
        for (ColumnEncoderComposite enc : this._columnEncoders)
            enc.initMetaData(meta);
        if (this._legacyOmit != null)
            this._legacyOmit.initMetaData(meta);
        if (this._legacyMVImpute != null)
            this._legacyMVImpute.initMetaData(meta);
    }

    /**
     * Passes the embeddings matrix to every column encoder.
     *
     * @param embeddings embeddings matrix
     */
    public void initEmbeddings(MatrixBlock embeddings) {
        for (ColumnEncoderComposite enc : this._columnEncoders)
            enc.initEmbeddings(embeddings);
    }

    /**
     * Calls {@code prepareBuildPartial()} on every column encoder.
     */
    @Override
    public void prepareBuildPartial() {
        for (Encoder enc : this._columnEncoders)
            enc.prepareBuildPartial();
    }

    /**
     * Calls {@code buildPartial(in)} on every column encoder.
     *
     * @param in input frame
     */
    @Override
    public void buildPartial(FrameBlock in) {
        for (Encoder enc : this._columnEncoders)
            enc.buildPartial(in);
    }

    /**
     * Builds a column mapping matrix with one row per input column and three columns:
     * the 1-based input column ID, the first output column, and the last output column.
     * A column with a dummycode encoder spans {@code numDistinct} output columns of
     * its meta data. Any other column spans exactly one output column.
     *
     * @param meta meta data frame providing the number of columns and distinct counts
     * @return the (numColumns x 3) mapping matrix
     */
    public MatrixBlock getColMapping(FrameBlock meta) {
        MatrixBlock mapping = new MatrixBlock(meta.getNumColumns(), 3, false);
        List<ColumnEncoderDummycode> dummycoders = this.getColumnEncoders(ColumnEncoderDummycode.class);
        int lastOutCol = 0;
        for (int col = 0; col < mapping.getNumRows(); col++) {
            final int colID = col + 1;
            final int firstOutCol = lastOutCol + 1;
            long matches = dummycoders.stream().filter(dc -> dc.getColID() == colID).count();
            // at most one dummycode encoder is expected per column
            assert (matches <= 1);
            if (matches == 1) {
                lastOutCol = (int) ((long) lastOutCol + meta.getColumnMetadata(col).getNumDistinct());
            } else {
                lastOutCol++;
            }
            mapping.set(col, 0, colID);
            mapping.set(col, 1, firstOutCol);
            mapping.set(col, 2, lastOutCol);
        }
        return mapping;
    }

    /**
     * Forwards the index-range update to all column encoders (with {@code offset}) and
     * to the legacy omit/MV-impute encoders (without offset), if present.
     *
     * @param beginDims begin dimensions to update
     * @param endDims   end dimensions to update
     * @param offset    column offset passed to the column encoders
     */
    @Override
    public void updateIndexRanges(long[] beginDims, long[] endDims, int offset) {
        this._columnEncoders.forEach(enc -> enc.updateIndexRanges(beginDims, endDims, offset));
        if (this._legacyOmit != null)
            this._legacyOmit.updateIndexRanges(beginDims, endDims);
        if (this._legacyMVImpute != null)
            this._legacyMVImpute.updateIndexRanges(beginDims, endDims);
    }

    /**
     * Serializes this encoder. The layout written is, in order:
     * <ol>
     * <li>presence flag, then payload, of the legacy MV-impute encoder</li>
     * <li>presence flag, then payload, of the legacy omit encoder</li>
     * <li>the column offset</li>
     * <li>the number of column encoders, followed by a (column ID, encoder payload) pair per encoder</li>
     * <li>presence flag, then payload, of the cached meta data frame</li>
     * </ol>
     * This order must stay in sync with {@link #readExternal(ObjectInput)}.
     *
     * @param out target stream
     * @throws IOException if writing to {@code out} fails
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        out.writeBoolean(this._legacyMVImpute != null);
        if (this._legacyMVImpute != null) {
            this._legacyMVImpute.writeExternal(out);
        }
        out.writeBoolean(this._legacyOmit != null);
        if (this._legacyOmit != null) {
            this._legacyOmit.writeExternal(out);
        }
        out.writeInt(this._colOffset);
        out.writeInt(this._columnEncoders.size());
        for (ColumnEncoder columnEncoder : this._columnEncoders) {
            // the column ID is written separately, ahead of each encoder's own payload
            out.writeInt(columnEncoder._colID);
            columnEncoder.writeExternal(out);
        }
        out.writeBoolean(this._meta != null);
        if (this._meta != null) {
            this._meta.write(out);
        }
    }

    /**
     * Deserializes this encoder from the layout produced by
     * {@link #writeExternal(ObjectOutput)}. Legacy encoders and the meta data frame
     * are only created when their presence flag is set. The list of column encoders
     * is replaced by newly read {@code ColumnEncoderComposite} instances.
     *
     * @param in source stream
     * @throws IOException            if reading from {@code in} fails
     * @throws ClassNotFoundException propagated from nested deserialization
     */
    @Override
    public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        if (in.readBoolean()) {
            this._legacyMVImpute = new EncoderMVImpute();
            this._legacyMVImpute.readExternal(in);
        }
        if (in.readBoolean()) {
            this._legacyOmit = new EncoderOmit();
            this._legacyOmit.readExternal(in);
        }
        this._colOffset = in.readInt();
        int encodersSize = in.readInt();
        this._columnEncoders = new ArrayList<ColumnEncoderComposite>();
        for (int i = 0; i < encodersSize; ++i) {
            // the column ID precedes the encoder payload (see writeExternal)
            int colID = in.readInt();
            ColumnEncoderComposite columnEncoder = new ColumnEncoderComposite();
            columnEncoder.readExternal(in);
            // re-applied after the payload is read; presumably the composite's own
            // readExternal does not restore it - TODO confirm
            columnEncoder.setColID(colID);
            this._columnEncoders.add(columnEncoder);
        }
        if (in.readBoolean()) {
            FrameBlock meta = new FrameBlock();
            meta.readFields(in);
            this._meta = meta;
        }
    }

    /*
     * WARNING - void declaration
     */
    public <T extends ColumnEncoder> List<T> getColumnEncoders(Class<T> type) {
        ArrayList<ColumnEncoder> ret = new ArrayList<ColumnEncoder>();
        for (ColumnEncoder columnEncoder : this._columnEncoders) {
            void var4_4;
            if (columnEncoder.getClass().equals(ColumnEncoderComposite.class) && type != ColumnEncoderComposite.class) {
                T t = ((ColumnEncoderComposite)columnEncoder).getEncoder(type);
            }
            if (var4_4 == null || !var4_4.getClass().equals(type)) continue;
            ret.add((ColumnEncoder)type.cast(var4_4));
        }
        return ret;
    }

    /**
     * Finds the first encoder of exactly class {@code type} for the given column ID.
     *
     * @param colID column ID to look for
     * @param type  requested encoder class
     * @param <T>   encoder type
     * @return the matching encoder, or null if none exists
     */
    public <T extends ColumnEncoder> T getColumnEncoder(int colID, Class<T> type) {
        return this.getColumnEncoders(type).stream()
            .filter(enc -> enc._colID == colID)
            .findFirst()
            .orElse(null);
    }

    /**
     * Applies {@code mapper} to every encoder of class {@code type}.
     *
     * @param type   requested encoder class
     * @param mapper function extracting a value from each encoder
     * @return the extracted values, in encoder order
     */
    public <T extends ColumnEncoder, E> List<E> getFromAll(Class<T> type, Function<? super T, ? extends E> mapper) {
        List<T> encoders = this.getColumnEncoders(type);
        List<E> values = new ArrayList<>(encoders.size());
        for (T enc : encoders)
            values.add(mapper.apply(enc));
        return values;
    }

    /**
     * Like {@code getFromAll}, but unboxes the extracted values into an int array.
     *
     * @param type   requested encoder class
     * @param mapper function extracting an integer from each encoder
     * @return the extracted values, in encoder order
     */
    public <T extends ColumnEncoder> int[] getFromAllIntArray(Class<T> type, Function<? super T, ? extends Integer> mapper) {
        List<Integer> values = this.getFromAll(type, mapper);
        int[] out = new int[values.size()];
        for (int i = 0; i < out.length; i++)
            out[i] = values.get(i);
        return out;
    }

    /**
     * Like {@code getFromAll}, but unboxes the extracted values into a double array.
     *
     * @param type   requested encoder class
     * @param mapper function extracting a double from each encoder
     * @return the extracted values, in encoder order
     */
    public <T extends ColumnEncoder> double[] getFromAllDoubleArray(Class<T> type, Function<? super T, ? extends Double> mapper) {
        List<Double> values = this.getFromAll(type, mapper);
        double[] out = new double[values.size()];
        for (int i = 0; i < out.length; i++)
            out[i] = values.get(i);
        return out;
    }

    /**
     * Returns the composite column encoders.
     * <p>
     * Note: this is the internal list itself, not a copy.
     *
     * @return the backing list of composite encoders
     */
    public List<ColumnEncoderComposite> getColumnEncoders() {
        final List<ColumnEncoderComposite> encoders = this._columnEncoders;
        return encoders;
    }

    /**
     * Returns all composite encoders with the given column ID.
     *
     * @param colID column ID to match
     * @return newly allocated list of matching composites, in encoder order
     */
    public List<ColumnEncoderComposite> getCompositeEncodersForID(int colID) {
        List<ColumnEncoderComposite> matches = new ArrayList<>();
        for (ColumnEncoderComposite enc : this._columnEncoders) {
            if (enc._colID == colID)
                matches.add(enc);
        }
        return matches;
    }

    /**
     * Returns the distinct classes of the encoders nested in the composites.
     *
     * @param colID column ID to restrict to, or -1 for all columns
     * @return distinct encoder classes (order follows the backing hash set)
     */
    public List<Class<? extends ColumnEncoder>> getEncoderTypes(int colID) {
        final boolean allColumns = colID == -1;
        HashSet<Class<? extends ColumnEncoder>> types = new HashSet<>();
        for (ColumnEncoderComposite comp : this._columnEncoders) {
            if (!allColumns && comp._colID != colID)
                continue;
            for (ColumnEncoder enc : comp.getEncoders())
                types.add(enc.getClass());
        }
        return new ArrayList<>(types);
    }

    /**
     * Returns the distinct encoder classes over all columns.
     *
     * @return distinct encoder classes
     */
    public List<Class<? extends ColumnEncoder>> getEncoderTypes() {
        final int allColumns = -1; // -1 disables the column filter of getEncoderTypes(int)
        return this.getEncoderTypes(allColumns);
    }

    /**
     * Returns the number of output columns, i.e. the sum of the domain sizes of all
     * composite encoders.
     *
     * @return total output column count
     */
    public int getNumOutCols() {
        return this._columnEncoders.stream()
            .mapToInt(ColumnEncoderComposite::getDomainSize)
            .sum();
    }

    /**
     * Returns the number of additional output columns produced by dummycoding within
     * the given column range. Each dummycoded column contributes its domain size minus
     * the one column it replaces.
     *
     * @param ixRange index range whose column range selects the dummycode encoders
     * @return number of extra output columns (0 if no dummycode encoder is in range)
     */
    public int getNumExtraCols(IndexRange ixRange) {
        int extraCols = 0;
        for (ColumnEncoderDummycode dce : this.getColumnEncoders(ColumnEncoderDummycode.class)) {
            if (ixRange.inColRange(dce._colID)) {
                extraCols += dce.getDomainSize() - 1;
            }
        }
        return extraCols;
    }

    /**
     * Checks whether an encoder of exactly class {@code type} exists for the column.
     *
     * @param colID column ID to look for
     * @param type  requested encoder class
     * @return true if such an encoder exists
     */
    public <T extends ColumnEncoder> boolean containsEncoderForID(int colID, Class<T> type) {
        for (T enc : this.getColumnEncoders(type)) {
            if (enc.getColID() == colID)
                return true;
        }
        return false;
    }

    /**
     * Applies {@code function} to every encoder of exactly class {@code type}.
     *
     * @param type     requested encoder class
     * @param function action to apply
     */
    public <T extends ColumnEncoder, E> void applyToAll(Class<T> type, Consumer<? super T> function) {
        List<T> targets = this.getColumnEncoders(type);
        for (T enc : targets)
            function.accept(enc);
    }

    /**
     * Applies {@code function} to every composite column encoder.
     *
     * @param function action to apply
     */
    public <T extends ColumnEncoder, E> void applyToAll(Consumer<? super ColumnEncoderComposite> function) {
        List<ColumnEncoderComposite> composites = this.getColumnEncoders();
        composites.forEach(function);
    }

    /**
     * Creates an encoder restricted to the composites of the columns
     * {@code [colStart, colEnd)} of the given range. Legacy encoders are restricted
     * to the same range.
     *
     * @param ixRange index range selecting the columns
     * @return new multi-column encoder for the sub range
     */
    public MultiColumnEncoder subRangeEncoder(IndexRange ixRange) {
        List<ColumnEncoderComposite> selected = new ArrayList<>();
        for (long col = ixRange.colStart; col < ixRange.colEnd; col++)
            selected.addAll(this.getCompositeEncodersForID((int) col));
        MultiColumnEncoder sub = new MultiColumnEncoder(selected);
        // column offset: negated start column plus one
        sub._colOffset = (int) (-ixRange.colStart) + 1;
        if (this._legacyOmit != null)
            sub.addReplaceLegacyEncoder(this._legacyOmit.subRangeEncoder(ixRange));
        if (this._legacyMVImpute != null)
            sub.addReplaceLegacyEncoder(this._legacyMVImpute.subRangeEncoder(ixRange));
        return sub;
    }

    /**
     * Creates an encoder from the encoders of exactly class {@code type} of the columns
     * {@code [colStart, colEnd)}. Non-composite encoders are wrapped in a new composite.
     *
     * @param ixRange index range selecting the columns
     * @param type    requested encoder class
     * @return new multi-column encoder
     */
    public <T extends ColumnEncoder> MultiColumnEncoder subRangeEncoder(IndexRange ixRange, Class<T> type) {
        List<T> picked = new ArrayList<>();
        for (long col = ixRange.colStart; col < ixRange.colEnd; col++)
            picked.add(this.getColumnEncoder((int) col, type));
        List<ColumnEncoderComposite> composites;
        if (type.equals(ColumnEncoderComposite.class)) {
            composites = picked.stream().map(e -> (ColumnEncoderComposite) e).collect(Collectors.toList());
        } else {
            composites = picked.stream().map(ColumnEncoderComposite::new).collect(Collectors.toList());
        }
        return new MultiColumnEncoder(composites);
    }

    /**
     * Merges the composites of {@code multiEncoder} into this encoder. An existing
     * composite of the same class and column ID is removed before the incoming one
     * is appended.
     *
     * @param multiEncoder encoder whose composites replace/extend this one's
     */
    public void mergeReplace(MultiColumnEncoder multiEncoder) {
        for (ColumnEncoderComposite incoming : multiEncoder._columnEncoders) {
            ColumnEncoderComposite existing = this.getColumnEncoder(incoming._colID, incoming.getClass());
            if (existing != null)
                this._columnEncoders.remove(existing);
            this._columnEncoders.add(incoming);
        }
    }

    /**
     * Merges another encoder into this one at the given column offset.
     * For a {@code MultiColumnEncoder}, every column encoder is added and the legacy
     * encoders are merged as well. Any other encoder is added as a single column encoder.
     *
     * @param other        encoder to merge
     * @param columnOffset column shift applied to the merged encoders
     * @param row          row position passed to the legacy merge
     */
    public void mergeAt(Encoder other, int columnOffset, int row) {
        if (!(other instanceof MultiColumnEncoder)) {
            this.addEncoder((ColumnEncoder) other, columnOffset);
            return;
        }
        MultiColumnEncoder multi = (MultiColumnEncoder) other;
        for (ColumnEncoder enc : multi._columnEncoders)
            this.addEncoder(enc, columnOffset);
        this.legacyMergeAt(multi, row, columnOffset + 1);
    }

    /**
     * Merges the legacy omit and MV-impute encoders of {@code other} into this encoder.
     * The other side's legacy encoders are shifted by {@code col - 1} first. A missing
     * local omit encoder is created before merging. A missing local MV-impute encoder
     * simply takes over the other's instance.
     *
     * @param other source encoder
     * @param row   row position for the merge
     * @param col   1-based column position for the merge
     */
    private void legacyMergeAt(MultiColumnEncoder other, int row, int col) {
        final EncoderOmit otherOmit = other._legacyOmit;
        if (otherOmit != null) {
            otherOmit.shiftCols(col - 1);
            if (this._legacyOmit == null)
                this._legacyOmit = new EncoderOmit();
            this._legacyOmit.mergeAt(otherOmit, row, col);
        }
        final EncoderMVImpute otherImpute = other._legacyMVImpute;
        if (otherImpute != null)
            otherImpute.shiftCols(col - 1);
        if (this._legacyMVImpute == null)
            this._legacyMVImpute = otherImpute;
        else if (otherImpute != null)
            this._legacyMVImpute.mergeAt(otherImpute, row, col);
    }

    /**
     * Adds an encoder shifted by {@code columnOffset}.
     * <p>
     * The encoder is merged into an existing encoder of the same class at the target
     * column if one exists. Otherwise it is merged into an existing composite at that
     * column. Otherwise it is appended, wrapped in a new composite unless it already is one.
     *
     * @param encoder      encoder to add (its column is shifted in place)
     * @param columnOffset column shift
     */
    private void addEncoder(ColumnEncoder encoder, int columnOffset) {
        final int colId = encoder._colID + columnOffset;
        ColumnEncoder sameType = this.getColumnEncoder(colId, encoder.getClass());
        ColumnEncoderComposite composite = (sameType == null)
            ? this.getColumnEncoder(colId, ColumnEncoderComposite.class)
            : null;
        encoder.shiftCol(columnOffset);
        if (sameType != null) {
            sameType.mergeAt(encoder);
        } else if (composite != null) {
            composite.mergeAt(encoder);
        } else if (encoder instanceof ColumnEncoderComposite) {
            this._columnEncoders.add((ColumnEncoderComposite) encoder);
        } else {
            this._columnEncoders.add(new ColumnEncoderComposite(encoder));
        }
    }

    /**
     * Sets (replaces) the legacy MV-impute or omit encoder, chosen by the exact class.
     *
     * @param encoder legacy encoder to install
     * @throws DMLRuntimeException if the encoder is neither MV-impute nor omit
     */
    public <T extends LegacyEncoder> void addReplaceLegacyEncoder(T encoder) {
        final Class<?> cls = encoder.getClass();
        if (cls.equals(EncoderMVImpute.class))
            this._legacyMVImpute = (EncoderMVImpute) encoder;
        else if (cls.equals(EncoderOmit.class))
            this._legacyOmit = (EncoderOmit) encoder;
        else
            throw new DMLRuntimeException("Tried to add non legacy Encoder");
    }

    /**
     * Checks whether any legacy encoder (MV-impute or omit) is present.
     *
     * @return true if at least one legacy encoder is set
     */
    public <T extends LegacyEncoder> boolean hasLegacyEncoder() {
        final boolean hasImpute = this.hasLegacyEncoder(EncoderMVImpute.class);
        return hasImpute || this.hasLegacyEncoder(EncoderOmit.class);
    }

    /**
     * Decides whether the compressed transform-encode path applies. It must be
     * enabled (explicitly or via the {@code sysds.compressed.transformencode} config),
     * the input must be a frame, and the column offset must be zero.
     *
     * @param in      input block
     * @param enabled explicit enable flag; the config is only consulted when false
     * @return true if the compressed path should be used
     */
    public boolean isCompressedTransformEncode(CacheBlock<?> in, boolean enabled) {
        final boolean requested = enabled
            || ConfigurationManager.getDMLConfig().getBooleanValue("sysds.compressed.transformencode");
        if (!requested)
            return false;
        return in instanceof FrameBlock && this._colOffset == 0;
    }

    /**
     * Checks whether the legacy encoder of the given class is present.
     *
     * @param type {@code EncoderMVImpute.class} or {@code EncoderOmit.class}
     * @return true if that legacy encoder is set; false (after a failed assert when
     *         assertions are enabled) for any other type
     */
    public <T extends LegacyEncoder> boolean hasLegacyEncoder(Class<T> type) {
        final boolean isImpute = type.equals(EncoderMVImpute.class);
        if (isImpute || type.equals(EncoderOmit.class))
            return isImpute ? this._legacyMVImpute != null : this._legacyOmit != null;
        assert (false);
        return false;
    }

    /**
     * Returns the legacy encoder of the given class.
     *
     * @param type {@code EncoderMVImpute.class} or {@code EncoderOmit.class}
     * @return the legacy encoder (possibly null); null (after a failed assert when
     *         assertions are enabled) for any other type
     */
    public <T extends LegacyEncoder> T getLegacyEncoder(Class<T> type) {
        if (type.equals(EncoderMVImpute.class))
            return type.cast(this._legacyMVImpute);
        if (type.equals(EncoderOmit.class))
            return type.cast(this._legacyOmit);
        assert (false);
        return null;
    }

    /**
     * Shifts all column encoders and legacy encoders by this encoder's column offset.
     */
    public void applyColumnOffset() {
        final int shift = this._colOffset;
        this.applyToAll(enc -> enc.shiftCol(shift));
        if (this._legacyOmit != null)
            this._legacyOmit.shiftCols(shift);
        if (this._legacyMVImpute != null)
            this._legacyMVImpute.shiftCols(shift);
    }

    /**
     * Renders the class name, the legacy MV-impute encoder, and one line per column encoder.
     *
     * @return human-readable description
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(this.getClass().getSimpleName())
            .append("\nIs Legacy: ").append(this._legacyMVImpute)
            .append("\nEncoders:\n");
        for (ColumnEncoderComposite enc : this._columnEncoders)
            sb.append(enc).append('\n');
        return sb.toString();
    }

    /**
     * Summarizes the encoder mix for output allocation.
     * <p>
     * Each composite is classified by the first match among UDF, dummycode, bag-of-words
     * (BOW) and word-embedding (WE) encoders, checked in that order.
     * <ul>
     * <li>BOW: encoders are counted and their {@code _nnz} summed. With {@code noBuild},
     * the BOW encoders are collected and the total nnz is then replaced by an estimate
     * computed on a row sample. The sample is 10% of the rows when there are more than
     * 1000, otherwise all rows. Without {@code noBuild}, the per-row nnz arrays of all
     * BOW encoders are summed element-wise.</li>
     * <li>WE: {@code distinctWE}/{@code sizeWE} come from the last WE encoder seen.</li>
     * </ul>
     *
     * @param encoders composite encoders to inspect
     * @param noBuild  whether BOW nnz must be estimated from a sample
     * @param k        degree of parallelism for the sampling-based estimate
     * @param in       input block (sampled only when estimating)
     * @return the collected metadata
     * @throws DMLRuntimeException if the parallel estimate fails or is interrupted
     */
    private static EncoderMeta getEncMeta(List<ColumnEncoderComposite> encoders, boolean noBuild, int k, CacheBlock<?> in) {
        boolean hasUDF = false;
        boolean hasDC = false;
        boolean hasWE = false;
        int distinctWE = 0;
        int sizeWE = 0;
        long nnzBOW = 0L;
        int numBOWEncoder = 0;
        int[] nnzPerRowBOW = null;
        ArrayList<ColumnEncoderBagOfWords> bows = new ArrayList<>();
        for (ColumnEncoderComposite enc : encoders) {
            if (enc.hasEncoder(ColumnEncoderUDF.class)) {
                hasUDF = true;
            }
            else if (enc.hasEncoder(ColumnEncoderDummycode.class)) {
                hasDC = true;
            }
            else if (enc.hasEncoder(ColumnEncoderBagOfWords.class)) {
                ColumnEncoderBagOfWords bowEnc = enc.getEncoder(ColumnEncoderBagOfWords.class);
                ++numBOWEncoder;
                nnzBOW += bowEnc._nnz;
                if (noBuild) {
                    // nnz is estimated from a sample after the loop
                    bows.add(bowEnc);
                }
                else if (nnzPerRowBOW == null) {
                    nnzPerRowBOW = bowEnc._nnzPerRow.clone();
                }
                else {
                    for (int i = 0; i < bowEnc._nnzPerRow.length; ++i) {
                        nnzPerRowBOW[i] += bowEnc._nnzPerRow[i];
                    }
                }
            }
            else if (enc.hasEncoder(ColumnEncoderWordEmbedding.class)) {
                hasWE = true;
                distinctWE = enc.getEncoder(ColumnEncoderWordEmbedding.class).getNrDistinctEmbeddings();
                sizeWE = enc.getDomainSize();
            }
        }
        if (!bows.isEmpty()) {
            final int numRows = in.getNumRows();
            final int sampleSize = numRows > 1000 ? (int) (0.1 * (double) numRows) : numRows;
            final int[] sampleInds = MultiColumnEncoder.getSampleIndices(in, sampleSize, (int) System.nanoTime(), 1);
            ExecutorService pool = CommonThreadPool.get(k);
            try {
                double estimate = pool.submit(() -> bows.stream().parallel()
                    .mapToDouble(e -> e.computeNnzEstimate(in, sampleInds))
                    .sum()).get();
                nnzBOW = (long) Math.ceil(estimate);
            }
            catch (InterruptedException ex) {
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(ex);
            }
            catch (Exception ex) {
                throw new DMLRuntimeException(ex);
            }
            finally {
                pool.shutdown();
            }
        }
        return new EncoderMeta(hasUDF, hasDC, hasWE, distinctWE, sizeWE, nnzBOW, numBOWEncoder, nnzPerRowBOW, bows, encoders);
    }

    /**
     * Task that lets a single column encoder write its meta data into a shared frame.
     *
     * @param <T> column encoder type
     */
    private static class ColumnMetaDataTask<T extends ColumnEncoder> implements Callable<Object> {
        private final T _colEncoder;
        private final FrameBlock _out;

        protected ColumnMetaDataTask(T encoder, FrameBlock out) {
            _out = out;
            _colEncoder = encoder;
        }

        @Override
        public Object call() throws Exception {
            _colEncoder.getMetaData(_out);
            return null;
        }

        @Override
        public String toString() {
            return new StringBuilder(getClass().getSimpleName())
                .append("<ColId: ").append(_colEncoder._colID).append('>')
                .toString();
        }
    }

    /**
     * Task that calls {@code allocateMetaData} of a multi-column encoder on a given frame.
     */
    private static class AllocMetaTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final FrameBlock _meta;

        private AllocMetaTask(MultiColumnEncoder encoder, FrameBlock meta) {
            _meta = meta;
            _encoder = encoder;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName();
        }

        @Override
        public Object call() throws Exception {
            _encoder.allocateMetaData(_meta);
            return null;
        }
    }

    /**
     * Task that assigns output column offsets (and, after the first bag-of-words
     * column, sparse row pointer offsets) to the apply-task wrappers, in list order.
     * Dummycode and bag-of-words columns widen the output, so every later column is
     * shifted by the accumulated extra width.
     */
    private static class UpdateOutputColTask
    implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        // assumed to contain only ApplyTasksWrapperTask instances (each is cast below)
        private final List<DependencyTask<?>> _applyTasksWrappers;

        private UpdateOutputColTask(MultiColumnEncoder encoder, List<DependencyTask<?>> applyTasksWrappers) {
            this._encoder = encoder;
            this._applyTasksWrappers = applyTasksWrappers;
        }

        public String toString() {
            return this.getClass().getSimpleName();
        }

        @Override
        public Object call() throws Exception {
            int currentCol = -1;
            int currentOffset = 0;
            int[] sparseRowPointerOffsets = null;
            for (DependencyTask<?> dtask : this._applyTasksWrappers) {
                int nonOffsetCol;
                // each wrapper receives the offsets accumulated from the columns before it
                ((ApplyTasksWrapperTask)dtask).setOffset(currentOffset);
                if (sparseRowPointerOffsets != null) {
                    ((ApplyTasksWrapperTask)dtask).setSparseRowPointerOffsets(sparseRowPointerOffsets);
                }
                // several wrappers may share a column; only the first one advances the offsets
                if ((nonOffsetCol = ((ApplyTasksWrapperTask)dtask)._encoder._colID - 1) <= currentCol) continue;
                currentCol = nonOffsetCol;
                ColumnEncoderComposite enc = this._encoder._columnEncoders.get(nonOffsetCol);
                if (enc.hasEncoder(ColumnEncoderDummycode.class)) {
                    // a dummycoded column occupies _domainSize output columns instead of one
                    currentOffset += enc.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
                    continue;
                }
                if (!enc.hasEncoder(ColumnEncoderBagOfWords.class)) continue;
                ColumnEncoderBagOfWords bow = enc.getEncoder(ColumnEncoderBagOfWords.class);
                currentOffset += bow.getDomainSize() - 1;
                if (sparseRowPointerOffsets == null) {
                    // first bag-of-words column: its per-row nnz become the base row offsets
                    sparseRowPointerOffsets = bow._nnzPerRow;
                    continue;
                }
                // clone before modifying: earlier wrappers keep a reference to the previous array
                sparseRowPointerOffsets = (int[])sparseRowPointerOffsets.clone();
                for (int r = 0; r < sparseRowPointerOffsets.length; ++r) {
                    int n = r;
                    sparseRowPointerOffsets[n] = sparseRowPointerOffsets[n] + (bow._nnzPerRow[r] - 1);
                }
            }
            return null;
        }
    }

    /**
     * Dependency task that lazily creates the apply tasks of a single column encoder
     * once its output column offset is known.
     */
    private static class ApplyTasksWrapperTask extends DependencyWrapperTask<Object> {
        private final ColumnEncoder _encoder;
        private final MatrixBlock _out;
        private final CacheBlock<?> _in;
        // output column offset; stays -1 until setOffset is called
        private int _offset = -1;
        // optional sparse row pointer offsets; null unless set
        private int[] _sparseRowPointerOffsets = null;

        private ApplyTasksWrapperTask(ColumnEncoder encoder, CacheBlock<?> in, MatrixBlock out, DependencyThreadPool pool) {
            super(pool);
            _in = in;
            _out = out;
            _encoder = encoder;
        }

        public void setOffset(int offset) {
            _offset = offset;
        }

        public void setSparseRowPointerOffsets(int[] offsets) {
            _sparseRowPointerOffsets = offsets;
        }

        @Override
        public List<DependencyTask<?>> getWrappedTasks() {
            final int outputCol = _encoder._colID - 1 + _offset;
            return _encoder.getApplyTasks(_in, _out, outputCol, _sparseRowPointerOffsets);
        }

        @Override
        public Object call() throws Exception {
            final boolean offsetMissing = _offset == -1;
            if (offsetMissing)
                throw new DMLRuntimeException("OutputCol for apply task wrapper has not been updated!, Most likely some concurrency issues\n " + this);
            return super.call();
        }

        @Override
        public String toString() {
            return new StringBuilder(getClass().getSimpleName())
                .append("<ColId: ").append(_encoder._colID).append('>')
                .toString();
        }
    }

    /**
     * Task that sizes and allocates the output matrix from an nnz estimate and then
     * runs the output pre-processing single-threaded.
     */
    private static class InitOutputMatrixTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final CacheBlock<?> _input;
        private final MatrixBlock _output;

        private InitOutputMatrixTask(MultiColumnEncoder encoder, CacheBlock<?> input, MatrixBlock output) {
            _output = output;
            _input = input;
            _encoder = encoder;
        }

        @Override
        public Object call() {
            EncoderMeta meta = MultiColumnEncoder.getEncMeta(_encoder.getEncoders(), false, -1, _input);
            final int outCols = _encoder.getNumOutCols();
            final int numRows = _input.getNumRows();
            // with UDFs every output cell may be non-zero; otherwise one value per non-BOW
            // input column plus the BOW non-zeros
            final int valuesPerRow = meta.hasUDF ? outCols : _input.getNumColumns() - meta.numBOWEnc;
            final long estNNz = (long) numRows * (long) valuesPerRow + meta.nnzBOW;
            final boolean sparse = MatrixBlock.evalSparseFormatInMemory(numRows, outCols, estNNz) && !meta.hasUDF;
            _output.reset(numRows, outCols, sparse, estNNz);
            MultiColumnEncoder.outputMatrixPreProcessing(_output, _input, meta, estNNz, 1);
            return null;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName();
        }
    }

    /**
     * Task that builds the encoder's meta data into a new STRING frame (one column per
     * input column), caches it in {@code _meta}, and initializes the encoders from it.
     */
    private static class MultiColumnLegacyMVImputeMetaPrepareTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final FrameBlock _input;

        protected MultiColumnLegacyMVImputeMetaPrepareTask(MultiColumnEncoder encoder, FrameBlock input) {
            _input = input;
            _encoder = encoder;
        }

        @Override
        public Void call() throws Exception {
            FrameBlock freshMeta = new FrameBlock(_input.getNumColumns(), Types.ValueType.STRING);
            _encoder._meta = _encoder.getMetaData(freshMeta);
            _encoder.initMetaData(_encoder._meta);
            return null;
        }
    }

    /**
     * Task that runs {@code legacyBuild} of the encoder on the given input frame.
     */
    private static class MultiColumnLegacyBuildTask implements Callable<Object> {
        private final MultiColumnEncoder _encoder;
        private final FrameBlock _input;

        protected MultiColumnLegacyBuildTask(MultiColumnEncoder encoder, FrameBlock input) {
            _input = input;
            _encoder = encoder;
        }

        @Override
        public Void call() throws Exception {
            _encoder.legacyBuild(_input);
            return null;
        }
    }

    /**
     * Immutable bundle of encoder statistics used to size and pre-process the output.
     */
    private static class EncoderMeta {
        // presence flags for UDF, dummycode and word-embedding encoders
        public final boolean hasUDF;
        public final boolean hasDC;
        public final boolean hasWE;
        // word-embedding statistics as supplied by the caller
        public final int distinctWE;
        public final int sizeWE;
        // bag-of-words statistics: total nnz, encoder count, per-row nnz (may be null)
        public final long nnzBOW;
        public final int numBOWEnc;
        public final int[] nnzPerRowBOW;
        // the bag-of-words encoders and all composite encoders (stored by reference)
        public final ArrayList<ColumnEncoderBagOfWords> bowEncoders;
        public final List<ColumnEncoderComposite> encs;

        public EncoderMeta(boolean hasUDF, boolean hasDC, boolean hasWE, int distinctWE, int sizeWE, long nnzBOW, int numBOWEncoder, int[] nnzPerRowBOW, ArrayList<ColumnEncoderBagOfWords> bows, List<ColumnEncoderComposite> encoders) {
            this.encs = encoders;
            this.bowEncoders = bows;
            this.nnzPerRowBOW = nnzPerRowBOW;
            this.numBOWEnc = numBOWEncoder;
            this.nnzBOW = nnzBOW;
            this.sizeWE = sizeWE;
            this.distinctWE = distinctWE;
            this.hasWE = hasWE;
            this.hasDC = hasDC;
            this.hasUDF = hasUDF;
        }
    }
}

