package ml.combust.mleap.bundle.ops.classification;

import ml.combust.bundle.BundleContext;
import ml.combust.bundle.dsl.Bundle$BuiltinOps$classification$;
import ml.combust.bundle.dsl.HasAttributes;
import ml.combust.bundle.dsl.Model;
import ml.combust.bundle.dsl.Value$;
import ml.combust.bundle.op.OpModel;
import ml.combust.mleap.bundle.ops.MleapOp;
import ml.combust.mleap.core.classification.BinaryLogisticRegressionModel;
import ml.combust.mleap.core.classification.LogisticRegressionModel;
import ml.combust.mleap.core.classification.ProbabilisticLogisticsRegressionModel;
import ml.combust.mleap.runtime.MleapContext;
import ml.combust.mleap.runtime.transformer.classification.LogisticRegression;
import ml.combust.mleap.tensor.DenseTensor;
import ml.combust.mleap.tensor.Tensor;
import org.apache.spark.ml.linalg.Matrices$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.Vectors$;
import scala.Predef$;
import scala.collection.Seq$;
import scala.collection.mutable.ArrayOps;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: LogisticRegressionOp.scala */
@ScalaSignature(bytes = "\u0006\u0001y2A!\u0002\u0004\u0001'!)\u0001\u0006\u0001C\u0001S!9A\u0006\u0001b\u0001\n\u0003j\u0003BB\u001d\u0001A\u0003%a\u0006C\u0003;\u0001\u0011\u00053H\u0001\u000bM_\u001eL7\u000f^5d%\u0016<'/Z:tS>tw\n\u001d\u0006\u0003\u000f!\tab\u00197bgNLg-[2bi&|gN\u0003\u0002\n\u0015\u0005\u0019q\u000e]:\u000b\u0005-a\u0011A\u00022v]\u0012dWM\u0003\u0002\u000e\u001d\u0005)Q\u000e\\3ba*\u0011q\u0002E\u0001\bG>l'-^:u\u0015\u0005\t\u0012AA7m\u0007\u0001\u0019\"\u0001\u0001\u000b\u0011\tU1\u0002$I\u0007\u0002\u0011%\u0011q\u0003\u0003\u0002\b\u001b2,\u0017\r](q!\tIr$D\u0001\u001b\u0015\t91D\u0003\u0002\u001d;\u0005YAO]1og\u001a|'/\\3s\u0015\tqB\"A\u0004sk:$\u0018.\\3\n\u0005\u0001R\"A\u0005'pO&\u001cH/[2SK\u001e\u0014Xm]:j_:\u0004\"A\t\u0014\u000e\u0003\rR!a\u0002\u0013\u000b\u0005\u0015b\u0011\u0001B2pe\u0016L!aJ\u0012\u0003/1{w-[:uS\u000e\u0014Vm\u001a:fgNLwN\\'pI\u0016d\u0017A\u0002\u001fj]&$h\bF\u0001+!\tY\u0003!D\u0001\u0007\u0003\u0015iu\u000eZ3m+\u0005q\u0003\u0003B\u00184k\u0005j\u0011\u0001\r\u0006\u0003cI\n!a\u001c9\u000b\u0005-q\u0011B\u0001\u001b1\u0005\u001dy\u0005/T8eK2\u0004\"AN\u001c\u000e\u0003uI!\u0001O\u000f\u0003\u00195cW-\u00199D_:$X\r\u001f;\u0002\r5{G-\u001a7!\u0003\u0015iw\u000eZ3m)\t\tC\bC\u0003>\t\u0001\u0007\u0001$\u0001\u0003o_\u0012,\u0007")
/* loaded from: input_file:ml/combust/mleap/bundle/ops/classification/LogisticRegressionOp.class */
/**
 * MLeap bundle op that serializes ({@code store}) and deserializes ({@code load}) the
 * {@link LogisticRegressionModel} backing a {@link LogisticRegression} transformer.
 *
 * <p>Both binary and multinomial variants are handled:
 * <ul>
 *   <li>Binary models are written as {@code coefficients}, {@code intercept} and
 *       {@code threshold}.</li>
 *   <li>Multinomial models are written as {@code coefficient_matrix} (a 2-D tensor),
 *       {@code intercept_vector} and an optional {@code thresholds} list.</li>
 *   <li>{@code num_classes} is always written.</li>
 * </ul>
 *
 * <p>NOTE(review): this file is decompiler output of {@code LogisticRegressionOp.scala}, not
 * hand-written Java, and it does not compile as Java as written:
 * <ul>
 *   <li>{@code Value$.MODULE$.long(...)} and {@code Value$.MODULE$.double(...)} use Java
 *       reserved words as method names.</li>
 *   <li>{@code mo128model} and {@code m19load} are decompiler-renamed members.</li>
 *   <li>The {@code null} argument passed to the anonymous {@code OpModel} looks like a
 *       decompiler artifact.</li>
 * </ul>
 * Make behavioral changes in the Scala source, not here.
 */
public class LogisticRegressionOp extends MleapOp<LogisticRegression, LogisticRegressionModel> {
    // Serialization handler for this op; constructed once in the constructor below.
    private final OpModel<MleapContext, LogisticRegressionModel> Model;

    /**
     * Returns the {@code OpModel} that performs bundle store/load for this op.
     *
     * @return the handler created in the constructor
     */
    public OpModel<MleapContext, LogisticRegressionModel> Model() {
        return this.Model;
    }

    /**
     * Extracts the core model from a transformer instance.
     *
     * <p>NOTE(review): {@code mo128model()} appears to be the decompiler's renamed name for the
     * transformer's {@code model} accessor — TODO confirm against the Scala source.
     *
     * @param logisticRegression the runtime transformer
     * @return the {@link LogisticRegressionModel} it wraps
     */
    public LogisticRegressionModel model(LogisticRegression logisticRegression) {
        return logisticRegression.mo128model();
    }

    /**
     * Registers the {@link LogisticRegression} transformer class with {@code MleapOp} via a
     * {@code ClassTag}, then builds the anonymous {@code OpModel} that reads and writes the
     * model attributes.
     */
    public LogisticRegressionOp() {
        super(ClassTag$.MODULE$.apply(LogisticRegression.class));
        // NOTE(review): null placeholder passed to the anonymous class below; presumably a
        // decompiler artifact (outer-instance capture) rather than meaningful state.
        final LogisticRegressionOp logisticRegressionOp = null;
        this.Model = new OpModel<MleapContext, LogisticRegressionModel>(logisticRegressionOp) { // from class: ml.combust.mleap.bundle.ops.classification.LogisticRegressionOp$$anon$1
            // Runtime class of the model type handled here; assigned in the instance initializer.
            private final Class<LogisticRegressionModel> klazz;

            // Delegates to the OpModel's own implementation (the Scala trait's default method,
            // presumably — the static `modelOpName$` forwarder suggests so).
            public String modelOpName(Object obj, BundleContext bundleContext) {
                return OpModel.modelOpName$(this, obj, bundleContext);
            }

            /** Returns {@code LogisticRegressionModel.class}. */
            public Class<LogisticRegressionModel> klazz() {
                return this.klazz;
            }

            /** Returns the builtin bundle op name for logistic regression. */
            public String opName() {
                return Bundle$BuiltinOps$classification$.MODULE$.logistic_regression();
            }

            /**
             * Writes the model's parameters as attributes on {@code model}.
             *
             * @param model the bundle model to add attributes to
             * @param logisticRegressionModel the core model to serialize
             * @param bundleContext bundle context (unused in this body)
             * @return {@code model} with {@code num_classes} and either the binary attributes
             *     ({@code coefficients}, {@code intercept}, {@code threshold}) or the
             *     multinomial attributes ({@code coefficient_matrix}, {@code intercept_vector},
             *     {@code thresholds}) added
             */
            public Model store(Model model, LogisticRegressionModel logisticRegressionModel, BundleContext<MleapContext> bundleContext) {
                Model model2 = (Model) model.withValue("num_classes", Value$.MODULE$.long(logisticRegressionModel.numClasses()));
                // NOTE(review): the layout is chosen by isMultinomial(), but load() chooses by
                // num_classes > 2. If a multinomial model can have exactly 2 classes, the
                // coefficient_matrix written below would be read back via the binary branch,
                // which expects "coefficients" — confirm this cannot happen.
                if (!logisticRegressionModel.isMultinomial()) {
                    // Binary: flat coefficient vector plus scalar intercept and threshold.
                    return (Model) ((HasAttributes) ((HasAttributes) model2.withValue("coefficients", Value$.MODULE$.vector(logisticRegressionModel.binaryModel().coefficients().toArray(), ClassTag$.MODULE$.Double()))).withValue("intercept", Value$.MODULE$.double(logisticRegressionModel.binaryModel().intercept()))).withValue("threshold", Value$.MODULE$.double(logisticRegressionModel.binaryModel().threshold()));
                }
                ProbabilisticLogisticsRegressionModel multinomialModel = logisticRegressionModel.multinomialModel();
                Matrix coefficientMatrix = multinomialModel.coefficientMatrix();
                // Multinomial: the matrix is stored as a DenseTensor of the matrix's flat
                // toArray() values with dimensions [numRows, numCols]; load() reads the dims in
                // that order. The thresholds are passed as a mapped optional value — presumably
                // the attribute is omitted when absent (confirm withValue(String, Option) semantics).
                return (Model) ((HasAttributes) ((HasAttributes) model2.withValue("coefficient_matrix", Value$.MODULE$.tensor(new DenseTensor(coefficientMatrix.toArray(), Seq$.MODULE$.apply(Predef$.MODULE$.wrapIntArray(new int[]{coefficientMatrix.numRows(), coefficientMatrix.numCols()})), ClassTag$.MODULE$.Double())))).withValue("intercept_vector", Value$.MODULE$.vector(multinomialModel.interceptVector().toArray(), ClassTag$.MODULE$.Double()))).withValue("thresholds", multinomialModel.thresholds().map(dArr -> {
                    return new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(dArr)).toSeq();
                }).map(seq -> {
                    return Value$.MODULE$.doubleList(seq);
                }));
            }

            /**
             * Rebuilds a {@link LogisticRegressionModel} from bundle attributes.
             *
             * <p>When {@code num_classes > 2}, a multinomial model is built from
             * {@code coefficient_matrix}, {@code intercept_vector} and optional
             * {@code thresholds}. Otherwise a binary model is built from {@code coefficients},
             * {@code intercept} and {@code threshold}, which defaults to {@code 0.5} when absent.
             *
             * @param model the bundle model to read attributes from
             * @param bundleContext bundle context (unused in this body)
             * @return the reconstructed model
             */
            public LogisticRegressionModel load(Model model, BundleContext<MleapContext> bundleContext) {
                // NOTE(review): declared type looks like a decompiler artifact — a
                // BinaryLogisticRegressionModel is also assigned to it below. Confirm the real
                // common supertype in the Scala source.
                ProbabilisticLogisticsRegressionModel binaryLogisticRegressionModel;
                if (model.value("num_classes").getLong() > 2) {
                    Tensor tensor = model.value("coefficient_matrix").getTensor();
                    // dims[0] → rows and dims[1] → cols, matching the order store() writes.
                    binaryLogisticRegressionModel = new ProbabilisticLogisticsRegressionModel(Matrices$.MODULE$.dense(BoxesRunTime.unboxToInt(tensor.dimensions().head()), BoxesRunTime.unboxToInt(tensor.dimensions().apply(1)), (double[]) tensor.toArray()), Vectors$.MODULE$.dense((double[]) model.value("intercept_vector").getTensor().toArray()), model.getValue("thresholds").map(value -> {
                        return (double[]) value.getDoubleList().toArray(ClassTag$.MODULE$.Double());
                    }));
                } else {
                    // Binary branch: a missing "threshold" falls back to 0.5.
                    binaryLogisticRegressionModel = new BinaryLogisticRegressionModel(Vectors$.MODULE$.dense((double[]) model.value("coefficients").getTensor().toArray()), model.value("intercept").getDouble(), BoxesRunTime.unboxToDouble(model.getValue("threshold").map(value2 -> {
                        return BoxesRunTime.boxToDouble(value2.getDouble());
                    }).getOrElse(() -> {
                        return 0.5d;
                    })));
                }
                return new LogisticRegressionModel(binaryLogisticRegressionModel);
            }

            // Compiler-generated bridge for the erased load signature; casts and delegates.
            /* renamed from: load, reason: collision with other method in class */
            public /* bridge */ /* synthetic */ Object m19load(Model model, BundleContext bundleContext) {
                return load(model, (BundleContext<MleapContext>) bundleContext);
            }

            // Compiler-generated bridge for the erased store signature; casts and delegates.
            public /* bridge */ /* synthetic */ Model store(Model model, Object obj, BundleContext bundleContext) {
                return store(model, (LogisticRegressionModel) obj, (BundleContext<MleapContext>) bundleContext);
            }

            // Instance initializer: runs the OpModel initializer, then records the model class.
            {
                OpModel.$init$(this);
                this.klazz = LogisticRegressionModel.class;
            }
        };
    }
}
