/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.api.ml;

import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext$;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.sysml.api.ml.BaseSystemMLRegressorModel;
import org.apache.sysml.api.ml.PredictionUtils$;
import org.apache.sysml.api.mlcontext.MLContext;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.MatrixMetadata;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.api.mlcontext.ScriptFactory;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtils;
import org.apache.sysml.runtime.instructions.spark.utils.RDDConverterUtilsExt;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import scala.Predef$;
import scala.Tuple2;
import scala.collection.Seq;

/**
 * Static implementations of the {@code baseTransform} overloads that operate on a
 * {@link BaseSystemMLRegressorModel} instance passed explicitly as {@code $this}.
 *
 * <p>NOTE(review): the {@code $class} suffix, the {@code $this} parameter and the empty
 * {@code $init$} hook look like compiler output for a Scala trait (this file is decompiled);
 * confirm before editing by hand.
 *
 * <p>Every overload follows the same outline: build an {@link MLContext} on the supplied
 * {@link SparkContext}, let the model configure it through {@code updateML}, fetch the
 * prediction script together with the name of its input variable via
 * {@code getPredictionScript}, bind the input, execute, and read back {@code predictionVar}.
 */
public abstract class BaseSystemMLRegressorModel$class {
    /**
     * Runs prediction on a matrix that is read through a DML read script built from
     * {@code X_file}, then executes the model's DML write script on the prediction.
     *
     * @param $this         model supplying the DML snippets and the prediction script
     * @param X_file        value handed to {@code dmlRead} (presumably a path - confirm)
     * @param sc            Spark context the {@link MLContext} is created on
     * @param predictionVar name of the output variable holding the prediction
     * @return the literal {@code "output.mtx"}; presumably the location targeted by
     *         {@code dmlWrite} - TODO confirm against that method
     */
    public static String baseTransform(BaseSystemMLRegressorModel $this, String X_file, SparkContext sc, String predictionVar) {
        MLContext context = newConfiguredContext($this, sc);

        // Load the input as variable "X".
        Script loader = ScriptFactory.dml($this.dmlRead("X", X_file)).out("X");
        MLResults loaded = context.execute(loader);

        // The prediction script is requested with the single-node flag off.
        Tuple2<Script, String> scriptAndInput = $this.getPredictionScript(false);
        Script predictor = scriptAndInput._1();
        String inputName = scriptAndInput._2();
        MLResults predicted = context.execute(predictor.in(inputName, loaded.getMatrix("X")));

        // Feed the prediction to the write script, again under the name "X".
        Script writer = ScriptFactory.dml($this.dmlWrite("X"));
        context.execute(writer.in("X", predicted.getMatrix(predictionVar)));
        return "output.mtx";
    }

    /**
     * Runs prediction directly on an in-memory {@link MatrixBlock}.
     *
     * @param $this         model supplying the prediction script
     * @param X             input bound to the script's input variable
     * @param sc            Spark context the {@link MLContext} is created on
     * @param predictionVar name of the output variable holding the prediction
     * @return the prediction converted to a {@link MatrixBlock}
     * @throws RuntimeException if the prediction does not have exactly one column
     */
    public static MatrixBlock baseTransform(BaseSystemMLRegressorModel $this, MatrixBlock X, SparkContext sc, String predictionVar) {
        MLContext context = newConfiguredContext($this, sc);

        // The prediction script is requested with the single-node flag on.
        Tuple2<Script, String> scriptAndInput = $this.getPredictionScript(true);
        Script boundScript = scriptAndInput._1().in(scriptAndInput._2(), X);

        MatrixBlock prediction = context.execute(boundScript).getMatrix(predictionVar).toMatrixBlock();
        if (prediction.getNumColumns() == 1) {
            return prediction;
        }
        throw new RuntimeException("Expected prediction to be a column vector");
    }

    /**
     * Runs prediction on a data frame and joins the result back onto the input.
     *
     * <p>The data frame is converted with {@code RDDConverterUtils.dataFrameToBinaryBlock},
     * wrapped in a {@link Matrix} and bound to the script's input variable. From the
     * prediction data frame only the ID column and {@code "C1"} are selected, with
     * {@code "C1"} renamed to {@code "prediction"}. The input gets an ID column via
     * {@code RDDConverterUtilsExt.addIDToDataFrame} and both sides are combined with
     * {@code PredictionUtils.joinUsingID}.
     *
     * @param $this         model supplying the prediction script
     * @param df            input data frame
     * @param sc            Spark context the {@link MLContext} is created on
     * @param predictionVar name of the output variable holding the prediction
     * @return whatever {@code joinUsingID} produces for the ID-augmented input and the
     *         renamed prediction column
     */
    public static Dataset baseTransform(BaseSystemMLRegressorModel $this, Dataset df, SparkContext sc, String predictionVar) {
        MLContext context = newConfiguredContext($this, sc);

        // The characteristics object is handed to the converter and reused for the metadata below.
        MatrixCharacteristics inputDims = new MatrixCharacteristics();
        JavaPairRDD<MatrixIndexes, MatrixBlock> inputBlocks = RDDConverterUtils.dataFrameToBinaryBlock(
                JavaSparkContext$.MODULE$.fromSparkContext(df.rdd().sparkContext()),
                (Dataset<Row>)df,
                inputDims,
                false,
                true);

        // The prediction script is requested with the single-node flag off.
        Tuple2<Script, String> scriptAndInput = $this.getPredictionScript(false);
        Matrix inputMatrix = new Matrix(inputBlocks, new MatrixMetadata(inputDims));
        Script boundScript = scriptAndInput._1().in(scriptAndInput._2(), inputMatrix);
        MLResults predicted = context.execute(boundScript);

        // Keep just the ID column and "C1", exposing the latter as "prediction".
        Dataset predictionColumn = predicted.getDataFrame(predictionVar)
                .select(RDDConverterUtils.DF_ID_COLUMN, (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"C1"}))
                .withColumnRenamed("C1", "prediction");

        Dataset<Row> inputWithIds = RDDConverterUtilsExt.addIDToDataFrame((Dataset<Row>)df, df.sparkSession(), RDDConverterUtils.DF_ID_COLUMN);
        return PredictionUtils$.MODULE$.joinUsingID(inputWithIds, (Dataset<Row>)predictionColumn);
    }

    /**
     * Initialisation hook; intentionally empty.
     *
     * @param $this model instance being initialised (unused)
     */
    public static void $init$(BaseSystemMLRegressorModel $this) {
    }

    /** Creates an {@link MLContext} on {@code sc} and passes it to the model's {@code updateML}. */
    private static MLContext newConfiguredContext(BaseSystemMLRegressorModel $this, SparkContext sc) {
        MLContext context = new MLContext(sc);
        $this.updateML(context);
        return context;
    }
}

