package ml.combust.mleap.bundle.ops.classification;

import ml.combust.bundle.BundleContext;
import ml.combust.bundle.dsl.Bundle$BuiltinOps$classification$;
import ml.combust.bundle.dsl.HasAttributes;
import ml.combust.bundle.dsl.Model;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.op.OpModel;
import ml.combust.mleap.bundle.ops.MleapOp;
import ml.combust.mleap.core.classification.BinaryLogisticRegressionModel;
import ml.combust.mleap.core.classification.LogisticRegressionModel;
import ml.combust.mleap.core.classification.ProbabilisticLogisticsRegressionModel;
import ml.combust.mleap.runtime.MleapContext;
import ml.combust.mleap.runtime.transformer.classification.LogisticRegression;
import ml.combust.mleap.tensor.DenseTensor;
import ml.combust.mleap.tensor.Tensor;
import org.apache.spark.ml.linalg.Matrices$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vectors$;
import scala.Predef$;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: LogisticRegressionOp.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00153Aa\u0002\u0005\u0001+!)!\u0006\u0001C\u0001W!9a\u0006\u0001b\u0001\n\u001by\u0003B\u0002\u001a\u0001A\u00035\u0001\u0007C\u00044\u0001\t\u0007I\u0011\t\u001b\t\r\u0001\u0003\u0001\u0015!\u00036\u0011\u0015\t\u0005\u0001\"\u0011C\u0005QaunZ5ti&\u001c'+Z4sKN\u001c\u0018n\u001c8Pa*\u0011\u0011BC\u0001\u000fG2\f7o]5gS\u000e\fG/[8o\u0015\tYA\"A\u0002paNT!!\u0004\b\u0002\r\t,h\u000e\u001a7f\u0015\ty\u0001#A\u0003nY\u0016\f\u0007O\u0003\u0002\u0012%\u000591m\\7ckN$(\"A\n\u0002\u00055d7\u0001A\n\u0003\u0001Y\u0001Ba\u0006\r\u001bG5\t!\"\u0003\u0002\u001a\u0015\t9Q\n\\3ba>\u0003\bCA\u000e\"\u001b\u0005a\"BA\u0005\u001e\u0015\tqr$A\u0006ue\u0006t7OZ8s[\u0016\u0014(B\u0001\u0011\u000f\u0003\u001d\u0011XO\u001c;j[\u0016L!A\t\u000f\u0003%1{w-[:uS\u000e\u0014Vm\u001a:fgNLwN\u001c\t\u0003I!j\u0011!\n\u0006\u0003\u0013\u0019R!a\n\b\u0002\t\r|'/Z\u0005\u0003S\u0015\u0012q\u0003T8hSN$\u0018n\u0019*fOJ,7o]5p]6{G-\u001a7\u0002\rqJg.\u001b;?)\u0005a\u0003CA\u0017\u0001\u001b\u0005A\u0011!\n'P\u000f&\u001bF+S\"`%\u0016;%+R*T\u0013>su\fR#G\u0003VcEk\u0018+I%\u0016\u001b\u0006j\u0014'E+\u0005\u0001t\"A\u0019!\u0011}\u0002\u000f\u0001\u0001\u0001\u0001\u0001\u0001\ta\u0005T(H\u0013N#\u0016jQ0S\u000b\u001e\u0013ViU*J\u001f:{F)\u0012$B+2#v\f\u0016%S\u000bNCu\n\u0014#!\u0003\u0015iu\u000eZ3m+\u0005)\u0004\u0003\u0002\u001c;y\rj\u0011a\u000e\u0006\u0003qe\n!a\u001c9\u000b\u00055\u0001\u0012BA\u001e8\u0005\u001dy\u0005/T8eK2\u0004\"!\u0010 \u000e\u0003}I!aP\u0010\u0003\u00195cW-\u00199D_:$X\r\u001f;\u0002\r5{G-\u001a7!\u0003\u0015iw\u000eZ3m)\t\u00193\tC\u0003E\r\u0001\u0007!$\u0001\u0003o_\u0012,\u0007")
/* loaded from: input_file:ml/combust/mleap/bundle/ops/classification/LogisticRegressionOp.class */
// MLeap Bundle op for the LogisticRegression transformer. It maps the runtime
// transformer to its core LogisticRegressionModel and defines how that model is
// stored to and loaded from a bundle (see the anonymous OpModel built in the
// constructor).
//
// NOTE(review): this file is decompiler output of LogisticRegressionOp.scala.
// Names such as mo128model and m19load are decompiler renames. Calls such as
// Value$.MODULE$.long(...) and Value$.MODULE$.double(...) use Java reserved
// words as method names, so this text is not compilable Java as-is. Treat the
// Scala source as authoritative.
public class LogisticRegressionOp extends MleapOp<LogisticRegression, LogisticRegressionModel> {
    // Serializer/deserializer for the core model; assigned once in the constructor.
    private final OpModel<MleapContext, LogisticRegressionModel> Model;

    /**
     * Default binary decision threshold (0.5).
     *
     * Not called anywhere in this class. load() uses the literal 0.5d when the
     * bundle has no "threshold" value; presumably the Scala compiler inlined this
     * constant there.
     */
    private final double LOGISTIC_REGRESSION_DEFAULT_THRESHOLD() {
        return 0.5d;
    }

    /** Returns the OpModel that stores and loads LogisticRegressionModel instances. */
    public OpModel<MleapContext, LogisticRegressionModel> Model() {
        return this.Model;
    }

    /**
     * Extracts the core model from a runtime LogisticRegression transformer.
     *
     * @param logisticRegression the transformer node
     * @return the LogisticRegressionModel held by the transformer
     */
    public LogisticRegressionModel model(LogisticRegression logisticRegression) {
        // mo128model() is a decompiler-renamed accessor (presumably the Scala `model` member).
        return logisticRegression.mo128model();
    }

    /**
     * Passes the LogisticRegression ClassTag to the MleapOp superclass and builds
     * the OpModel that handles bundle (de)serialization.
     */
    public LogisticRegressionOp() {
        super(ClassTag$.MODULE$.apply(LogisticRegression.class));
        // NOTE(review): this null "outer" reference, handed to the anonymous class
        // below, appears to be a decompiler artifact, not real Scala logic.
        final LogisticRegressionOp logisticRegressionOp = null;
        this.Model = new OpModel<MleapContext, LogisticRegressionModel>(logisticRegressionOp) { // from class: ml.combust.mleap.bundle.ops.classification.LogisticRegressionOp$$anon$1
            // Runtime class of the model handled by this OpModel (set in the instance initializer).
            private final Class<LogisticRegressionModel> klazz;

            // Delegates to the default implementation provided by the OpModel trait.
            public String modelOpName(Object obj, BundleContext bundleContext) {
                return OpModel.modelOpName$(this, obj, bundleContext);
            }

            public Class<LogisticRegressionModel> klazz() {
                return this.klazz;
            }

            // Bundle op name, taken from the built-in classification op names.
            public String opName() {
                return Bundle$BuiltinOps$classification$.MODULE$.logistic_regression();
            }

            /**
             * Writes a LogisticRegressionModel into a bundle Model.
             *
             * Always writes "num_classes". Which other values are written depends on
             * the model kind:
             * - Non-multinomial: "coefficients" (vector), "intercept" and "threshold"
             *   (doubles), all taken from binaryModel().
             * - Multinomial: "coefficient_matrix" (tensor with dimensions
             *   [numRows, numCols]), "intercept_vector" (vector) and "thresholds"
             *   (double list).
             *
             * @param model the bundle model to add values to
             * @param logisticRegressionModel the model being serialized
             * @param bundleContext bundle context (not read by this method)
             * @return the bundle model with the values above added
             */
            public Model store(Model model, LogisticRegressionModel logisticRegressionModel, BundleContext<MleapContext> bundleContext) {
                Model model2 = (Model) model.withValue("num_classes", Value$.MODULE$.long(logisticRegressionModel.numClasses()));
                if (!logisticRegressionModel.isMultinomial()) {
                    return (Model) ((HasAttributes) ((HasAttributes) model2.withValue("coefficients", Value$.MODULE$.vector(logisticRegressionModel.binaryModel().coefficients().toArray(), ClassTag$.MODULE$.Double()))).withValue("intercept", Value$.MODULE$.double(logisticRegressionModel.binaryModel().intercept()))).withValue("threshold", Value$.MODULE$.double(logisticRegressionModel.binaryModel().threshold()));
                }
                ProbabilisticLogisticsRegressionModel multinomialModel = logisticRegressionModel.multinomialModel();
                Matrix coefficientMatrix = multinomialModel.coefficientMatrix();
                // NOTE(review): the matrix is flattened with Matrix.toArray() here and
                // rebuilt with Matrices.dense(rows, cols, values) in load(). The round
                // trip relies on both using the same element order (column-major in
                // Spark). Confirm before changing either side.
                // thresholds() is mapped to a double-list Value before withValue;
                // presumably it is a Scala Option, so "thresholds" is written only when
                // present. Confirm.
                return (Model) ((HasAttributes) ((HasAttributes) model2.withValue("coefficient_matrix", Value$.MODULE$.tensor(new DenseTensor(coefficientMatrix.toArray(), Seq$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{coefficientMatrix.numRows(), coefficientMatrix.numCols()})), ClassTag$.MODULE$.Double())))).withValue("intercept_vector", Value$.MODULE$.vector(multinomialModel.interceptVector().toArray(), ClassTag$.MODULE$.Double()))).withValue("thresholds", multinomialModel.thresholds().map(dArr -> {
                    return new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(dArr)).toSeq();
                }).map(seq -> {
                    return Value$.MODULE$.doubleList(seq);
                }));
            }

            /**
             * Reads a LogisticRegressionModel back from a bundle Model.
             *
             * If "num_classes" is greater than 2, the multinomial form is read:
             * "coefficient_matrix" is rebuilt from its first two tensor dimensions as
             * rows and columns, followed by "intercept_vector" and "thresholds" (read
             * via getValue, presumably optional). Otherwise the binary form is read:
             * "coefficients", "intercept", and "threshold", which defaults to 0.5 when
             * absent.
             *
             * NOTE(review): store() chooses the layout with isMultinomial(), while this
             * method chooses it with num_classes > 2. If a multinomial model could have
             * exactly 2 classes, its bundle would go through the binary branch here,
             * and "coefficients" (never written for it) would be missing. Confirm that
             * isMultinomial() is equivalent to numClasses() > 2.
             *
             * @param model the bundle model to read values from
             * @param bundleContext bundle context (not read by this method)
             * @return a LogisticRegressionModel wrapping the reconstructed implementation
             */
            public LogisticRegressionModel load(Model model, BundleContext<MleapContext> bundleContext) {
                // NOTE(review): the declared type looks like a decompiler artifact. The
                // else-branch assigns a BinaryLogisticRegressionModel; the Scala type is
                // presumably a common supertype of both implementations.
                ProbabilisticLogisticsRegressionModel binaryLogisticRegressionModel;
                if (model.value("num_classes").getLong() > 2) {
                    Tensor tensor = model.value("coefficient_matrix").getTensor();
                    binaryLogisticRegressionModel = new ProbabilisticLogisticsRegressionModel(Matrices$.MODULE$.dense(BoxesRunTime.unboxToInt(tensor.dimensions().head()), BoxesRunTime.unboxToInt(tensor.dimensions().apply(1)), (double[]) tensor.toArray()), Vectors$.MODULE$.dense((double[]) model.value("intercept_vector").getTensor().toArray()), model.getValue("thresholds").map(value -> {
                        return (double[]) value.getDoubleList().toArray(ClassTag$.MODULE$.Double());
                    }));
                } else {
                    // The 0.5d fallback below is the inlined LOGISTIC_REGRESSION_DEFAULT_THRESHOLD value.
                    binaryLogisticRegressionModel = new BinaryLogisticRegressionModel(Vectors$.MODULE$.dense((double[]) model.value("coefficients").getTensor().toArray()), model.value("intercept").getDouble(), BoxesRunTime.unboxToDouble(model.getValue("threshold").map(value2 -> {
                        return BoxesRunTime.boxToDouble(value2.getDouble());
                    }).getOrElse(() -> {
                        return 0.5d;
                    })));
                }
                return new LogisticRegressionModel(binaryLogisticRegressionModel);
            }

            // Synthetic bridge for the erased OpModel.load signature; delegates to the typed load above.
            /* renamed from: load, reason: collision with other method in class */
            public /* bridge */ /* synthetic */ Object m19load(Model model, BundleContext bundleContext) {
                return load(model, (BundleContext<MleapContext>) bundleContext);
            }

            // Synthetic bridge for the erased OpModel.store signature; delegates to the typed store above.
            public /* bridge */ /* synthetic */ Model store(Model model, Object obj, BundleContext bundleContext) {
                return store(model, (LogisticRegressionModel) obj, (BundleContext<MleapContext>) bundleContext);
            }

            // Instance initializer: runs the OpModel trait initializer, then records the handled model class.
            {
                OpModel.$init$(this);
                this.klazz = LogisticRegressionModel.class;
            }
        };
    }
}
}
