/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.codegen;

import java.util.ArrayList;
import java.util.Objects;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.MemoTable;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.runtime.codegen.SpoofRowwise;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/**
 * Hop for an operator fused by the code generator ({@link SpoofCompiler}).
 *
 * <p>The hop wraps a generated operator class and is lowered to a {@link SpoofFused} lop in
 * {@link #constructLops()}. How its output dimensions derive from its inputs is described by
 * {@link SpoofOutputDimsType}.
 */
public class SpoofFusedOp
extends MultiThreadedHop {
    // Generated operator class; may be null (getOpString/getClassName fall back to the hop name).
    private Class<?> _class = null;
    // Returned by allowsAllExecTypes(): whether execution types other than CP are permitted.
    private boolean _distSupported = false;
    // Constant output column count for INPUT_DIMS_CONST2 / VECT_CONST2; -1 until setConstDim2 is called.
    private long _constDim2 = -1L;
    // Rule mapping input dimensions to output dimensions.
    private SpoofOutputDimsType _dimsType;
    // Code generation backend; CUDA marks this hop as GPU-enabled.
    private SpoofCompiler.GeneratorAPI _api = SpoofCompiler.GeneratorAPI.JAVA;
    // Name handed to the SpoofFused lop; presumably identifies the generated operator — TODO confirm.
    private String _genVarName;

    /** Creates an empty instance; used by {@link #clone()} before attributes are copied over. */
    public SpoofFusedOp() {
    }

    /**
     * Creates a fused-operator hop.
     *
     * @param name       hop name
     * @param dt         output data type
     * @param vt         output value type
     * @param cla        generated operator class (may be null)
     * @param api        code generation backend
     * @param genVarName generated variable name forwarded to the lop
     * @param dist       whether all execution types are allowed
     * @param type       rule deriving the output dimensions
     */
    public SpoofFusedOp(String name, Types.DataType dt, Types.ValueType vt, Class<?> cla, SpoofCompiler.GeneratorAPI api, String genVarName, boolean dist, SpoofOutputDimsType type) {
        super(name, dt, vt);
        this._class = cla;
        this._distSupported = dist;
        this._dimsType = type;
        this._api = api;
        this._genVarName = genVarName;
    }

    /** Fused operators accept a variable number of inputs, so no arity check is performed. */
    @Override
    public void checkArity() {
    }

    @Override
    public boolean allowsAllExecTypes() {
        return this._distSupported;
    }

    /**
     * Sets the constant output column count used by {@code INPUT_DIMS_CONST2} and {@code VECT_CONST2}.
     *
     * @param constDim2 number of output columns
     */
    public void setConstDim2(long constDim2) {
        this._constDim2 = constDim2;
    }

    /** @return true iff this operator was generated for the CUDA backend */
    @Override
    public boolean isGPUEnabled() {
        return this._api == SpoofCompiler.GeneratorAPI.CUDA;
    }

    @Override
    public boolean isMultiThreadedOpType() {
        return true;
    }

    /**
     * Estimates the output size in bytes.
     *
     * <p>Java operators whose generated class directly extends {@link SpoofRowwise} are estimated as
     * dense output. All other operators use a blocked estimate with the given sparsity. A null
     * {@code _class} falls back to the blocked estimate instead of throwing.
     */
    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        boolean onlyDenseOut = this._api == SpoofCompiler.GeneratorAPI.JAVA
            && this._class != null
            // getGenericSuperclass() may return null, so call equals on the constant side
            && SpoofRowwise.class.equals(this._class.getGenericSuperclass());
        // fall back to the configured block size if this hop has none set
        int blen = this.getBlocksize() > 0 ? this.getBlocksize() : ConfigurationManager.getBlocksize();
        return onlyDenseOut
            ? (double) OptimizerUtils.estimateSize(dim1, dim2)
            : (double) OptimizerUtils.estimatePartitionedSizeExactSparsity(dim1, dim2, (long) blen, nnz);
    }

    /** Fused operators report no intermediate memory. */
    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        return 0.0;
    }

    /**
     * Builds (once) and returns the {@link SpoofFused} lop for this hop.
     *
     * @return the cached lop if it already exists, otherwise a newly constructed one
     */
    @Override
    public Lop constructLops() {
        if (this.getLops() != null) {
            return this.getLops();
        }
        Types.ExecType et = this.optFindExecType();
        ArrayList<Lop> inputs = new ArrayList<>(this.getInput().size());
        for (Hop c : this.getInput()) {
            inputs.add(c.constructLops());
        }
        int k = OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads);
        SpoofFused lop = new SpoofFused(inputs, this.getDataType(), this.getValueType(), this._class, this._api, this._genVarName, k, et);
        this.setOutputDimensions(lop);
        this.setLineNumbers(lop);
        this.setLops(lop);
        return lop;
    }

    /**
     * Chooses the execution type: a forced platform wins; otherwise the memory estimate decides,
     * followed by the check for invalid CP dimensions/sizes.
     */
    @Override
    protected Types.ExecType optFindExecType(boolean transitive) {
        this.checkAndSetForcedPlatform();
        if (this._etypeForced != null) {
            this._etype = this._etypeForced;
        } else {
            this._etype = this.findExecTypeByMemEstimate();
            this.checkAndSetInvalidCPDimsAndSize();
        }
        return this._etype;
    }

    @Override
    public String getOpString() {
        if (this._class != null) {
            return "spoof(" + this._class.getSimpleName() + ")";
        }
        return "spoof(" + this.getName() + ")";
    }

    /** @return the generated class's fully qualified name, or "spoof" + hop name if there is no class */
    public String getClassName() {
        if (this._class != null) {
            return this._class.getName();
        }
        return "spoof" + this.getName();
    }

    /**
     * Infers worst-case output dimensions from the memo-table statistics of the inputs.
     *
     * @param memo memo table providing input statistics
     * @return output characteristics without sparsity information, or null if the dimensions of
     *         input 0 (or, for rank-based types, input 1) are unknown
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        DataCharacteristics mc = memo.getAllInputStats(this.getInput().get(0));
        MatrixCharacteristics ret = null;
        if (mc.dimsKnown()) {
            switch (this._dimsType) {
                case ROW_DIMS: {
                    ret = new MatrixCharacteristics(mc.getRows(), 1L, -1, -1L);
                    break;
                }
                case COLUMN_DIMS_ROWS: {
                    ret = new MatrixCharacteristics(mc.getCols(), 1L, -1, -1L);
                    break;
                }
                case COLUMN_DIMS_COLS: {
                    ret = new MatrixCharacteristics(1L, mc.getCols(), -1, -1L);
                    break;
                }
                case RANK_DIMS_COLS: {
                    DataCharacteristics dc2 = memo.getAllInputStats(this.getInput().get(1));
                    if (!dc2.dimsKnown()) break;
                    ret = new MatrixCharacteristics(1L, dc2.getCols(), -1, -1L);
                    break;
                }
                case INPUT_DIMS: {
                    ret = new MatrixCharacteristics(mc.getRows(), mc.getCols(), -1, -1L);
                    break;
                }
                case INPUT_DIMS_CONST2: {
                    ret = new MatrixCharacteristics(mc.getRows(), this._constDim2, -1, -1L);
                    break;
                }
                case VECT_CONST2: {
                    ret = new MatrixCharacteristics(1L, this._constDim2, -1, -1L);
                    break;
                }
                case SCALAR: {
                    ret = new MatrixCharacteristics(0L, 0L, -1, -1L);
                    break;
                }
                case MULTI_SCALAR: {
                    // column count is this hop's own current value (not derived from inputs)
                    ret = new MatrixCharacteristics(1L, this._dc.getCols(), -1, -1L);
                    break;
                }
                case ROW_RANK_DIMS: {
                    DataCharacteristics dc2 = memo.getAllInputStats(this.getInput().get(1));
                    if (!dc2.dimsKnown()) break;
                    ret = new MatrixCharacteristics(mc.getRows(), dc2.getCols(), -1, -1L);
                    break;
                }
                case COLUMN_RANK_DIMS: {
                    DataCharacteristics dc2 = memo.getAllInputStats(this.getInput().get(1));
                    if (!dc2.dimsKnown()) break;
                    ret = new MatrixCharacteristics(mc.getCols(), dc2.getCols(), -1, -1L);
                    break;
                }
                case COLUMN_RANK_DIMS_T: {
                    DataCharacteristics dc2 = memo.getAllInputStats(this.getInput().get(1));
                    if (!dc2.dimsKnown()) break;
                    ret = new MatrixCharacteristics(dc2.getCols(), mc.getCols(), -1, -1L);
                    break;
                }
                default: {
                    throw new RuntimeException("Failed to infer worst-case size information for type: " + this._dimsType.toString());
                }
            }
        }
        return ret;
    }

    /**
     * Sets this hop's dimensions from the current input dimensions according to {@code _dimsType}.
     *
     * @throws RuntimeException for an unhandled dims type
     */
    @Override
    public void refreshSizeInformation() {
        switch (this._dimsType) {
            case ROW_DIMS: {
                this.setDim1(this.getInput().get(0).getDim1());
                this.setDim2(1L);
                break;
            }
            case COLUMN_DIMS_ROWS: {
                this.setDim1(this.getInput().get(0).getDim2());
                this.setDim2(1L);
                break;
            }
            case COLUMN_DIMS_COLS: {
                this.setDim1(1L);
                this.setDim2(this.getInput().get(0).getDim2());
                break;
            }
            case RANK_DIMS_COLS: {
                this.setDim1(1L);
                this.setDim2(this.getInput().get(1).getDim2());
                break;
            }
            case INPUT_DIMS: {
                this.setDim1(this.getInput().get(0).getDim1());
                this.setDim2(this.getInput().get(0).getDim2());
                break;
            }
            case INPUT_DIMS_CONST2: {
                this.setDim1(this.getInput().get(0).getDim1());
                this.setDim2(this._constDim2);
                break;
            }
            case VECT_CONST2: {
                this.setDim1(1L);
                this.setDim2(this._constDim2);
                break;
            }
            case SCALAR: {
                this.setDim1(0L);
                this.setDim2(0L);
                break;
            }
            case MULTI_SCALAR: {
                // only the row count is fixed here; dim2 keeps its existing value
                this.setDim1(1L);
                break;
            }
            case ROW_RANK_DIMS: {
                this.setDim1(this.getInput().get(0).getDim1());
                this.setDim2(this.getInput().get(1).getDim2());
                break;
            }
            case COLUMN_RANK_DIMS: {
                this.setDim1(this.getInput().get(0).getDim2());
                this.setDim2(this.getInput().get(1).getDim2());
                break;
            }
            case COLUMN_RANK_DIMS_T: {
                this.setDim1(this.getInput().get(1).getDim2());
                this.setDim2(this.getInput().get(0).getDim2());
                break;
            }
            default: {
                throw new RuntimeException("Failed to refresh size information for type: " + this._dimsType.toString());
            }
        }
    }

    /**
     * Creates a copy carrying the generic hop attributes and every fused-operator attribute.
     * The backend and generated variable name are copied as well. Without them the copy would
     * revert to the JAVA backend and a null variable name.
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        SpoofFusedOp ret = new SpoofFusedOp();
        ret.clone(this, false);
        ret._class = this._class;
        ret._distSupported = this._distSupported;
        ret._maxNumThreads = this._maxNumThreads;
        ret._constDim2 = this._constDim2;
        ret._dimsType = this._dimsType;
        ret._api = this._api;
        ret._genVarName = this._genVarName;
        return ret;
    }

    /**
     * Checks whether {@code that} computes the same fused operation over the very same input hops.
     * Every attribute that influences the produced lop or the output size must match.
     *
     * @param that hop to compare against
     * @return true iff both hops are equivalent fused operators over identical inputs
     */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof SpoofFusedOp)) {
            return false;
        }
        SpoofFusedOp that2 = (SpoofFusedOp) that;
        boolean ret = Objects.equals(this._class, that2._class)
            && this._distSupported == that2._distSupported
            && this._maxNumThreads == that2._maxNumThreads
            && this._constDim2 == that2._constDim2
            && this._dimsType == that2._dimsType
            && this._api == that2._api
            && Objects.equals(this._genVarName, that2._genVarName)
            && this.getInput().size() == that2.getInput().size();
        // inputs must be the same hop objects (reference identity), not merely equal ones
        for (int i = 0; ret && i < this.getInput().size(); ++i) {
            ret = this.getInput().get(i) == that2.getInput().get(i);
        }
        return ret;
    }

    /**
     * Rules for deriving output dimensions (rows x cols) from the inputs, as applied in
     * {@code refreshSizeInformation}. "in0"/"in1" denote the first/second input.
     */
    public static enum SpoofOutputDimsType {
        /** rows(in0) x cols(in0) */
        INPUT_DIMS,
        /** rows(in0) x constDim2 */
        INPUT_DIMS_CONST2,
        /** rows(in0) x 1 */
        ROW_DIMS,
        /** cols(in0) x 1 */
        COLUMN_DIMS_ROWS,
        /** 1 x cols(in0) */
        COLUMN_DIMS_COLS,
        /** 1 x cols(in1) */
        RANK_DIMS_COLS,
        /** 0 x 0 */
        SCALAR,
        /** 1 x (existing column count of this hop) */
        MULTI_SCALAR,
        /** rows(in0) x cols(in1) */
        ROW_RANK_DIMS,
        /** cols(in0) x cols(in1) */
        COLUMN_RANK_DIMS,
        /** cols(in1) x cols(in0) */
        COLUMN_RANK_DIMS_T,
        /** 1 x constDim2 */
        VECT_CONST2;

    }
}

