package nak.classify;

import breeze.linalg.Counter;
import breeze.linalg.DenseVector;
import breeze.linalg.DenseVector$;
import breeze.linalg.Tensor$;
import breeze.linalg.package$;
import breeze.optimize.FirstOrderMinimizer;
import breeze.util.Encoder;
import breeze.util.Encoder$;
import breeze.util.Index;
import breeze.util.Index$;
import breeze.util.MutableIndex;
import nak.classify.Classifier;
import nak.data.Example;
import nak.nnet.NNObjective;
import nak.nnet.NeuralNetwork;
import scala.Array$;
import scala.Function1;
import scala.Predef$;
import scala.Predef$ArrowAssoc$;
import scala.Tuple2;
import scala.collection.Iterable;
import scala.collection.parallel.mutable.ParArray;
import scala.collection.parallel.mutable.ParArray$;
import scala.reflect.ClassTag$;
import scala.reflect.NoManifest$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: NNetClassifier.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005Ud\u0001B\u0001\u0003\u0001\u001d\u0011aB\u0014(fi\u000ec\u0017m]:jM&,'O\u0003\u0002\u0004\t\u0005A1\r\\1tg&4\u0017PC\u0001\u0006\u0003\rq\u0017m[\u0002\u0001+\rAQcH\n\u0004\u0001%y\u0001C\u0001\u0006\u000e\u001b\u0005Y!\"\u0001\u0007\u0002\u000bM\u001c\u0017\r\\1\n\u00059Y!AB!osJ+g\r\u0005\u0003\u0011#MqR\"\u0001\u0002\n\u0005I\u0011!AC\"mCN\u001c\u0018NZ5feB\u0011A#\u0006\u0007\u0001\t\u00151\u0002A1\u0001\u0018\u0005\u0005a\u0015C\u0001\r\u001c!\tQ\u0011$\u0003\u0002\u001b\u0017\t9aj\u001c;iS:<\u0007C\u0001\u0006\u001d\u0013\ti2BA\u0002B]f\u0004\"\u0001F\u0010\u0005\u000b\u0001\u0002!\u0019A\f\u0003\u0003QC\u0001B\t\u0001\u0003\u0002\u0003\u0006IaI\u0001\u0005]:,G\u000f\u0005\u0002%M5\tQE\u0003\u0002#\t%\u0011q%\n\u0002\u000e\u001d\u0016,(/\u00197OKR<xN]6\t\u0011%\u0002!\u0011!Q\u0001\n)\nA\"\u001b8qkR,enY8eKJ\u0004BAC\u0016\u001f[%\u0011Af\u0003\u0002\n\rVt7\r^5p]F\u00022AL\u001a6\u001b\u0005y#B\u0001\u00192\u0003\u0019a\u0017N\\1mO*\t!'\u0001\u0004ce\u0016,'0Z\u0005\u0003i=\u00121\u0002R3og\u00164Vm\u0019;peB\u0011!BN\u0005\u0003o-\u0011a\u0001R8vE2,\u0007\u0002C\u001d\u0001\u0005\u0003\u0005\u000b\u0011\u0002\u001e\u0002\u00151\f'-\u001a7J]\u0012,\u0007\u0010E\u0002<}Mi\u0011\u0001\u0010\u0006\u0003{E\nA!\u001e;jY&\u0011q\b\u0010\u0002\u0006\u0013:$W\r\u001f\u0005\u0006\u0003\u0002!\tAQ\u0001\u0007y%t\u0017\u000e\u001e 
\u0015\t\r#UI\u0012\t\u0005!\u0001\u0019b\u0004C\u0003#\u0001\u0002\u00071\u0005C\u0003*\u0001\u0002\u0007!\u0006C\u0003:\u0001\u0002\u0007!\bC\u0003I\u0001\u0011\u0005\u0011*\u0001\u0004tG>\u0014Xm\u001d\u000b\u0003\u00156\u0003BAL&\u0014k%\u0011Aj\f\u0002\b\u0007>,h\u000e^3s\u0011\u0015qu\t1\u0001\u001f\u0003\u0005yw!\u0002)\u0003\u0011\u0003\t\u0016A\u0004(OKR\u001cE.Y:tS\u001aLWM\u001d\t\u0003!I3Q!\u0001\u0002\t\u0002M\u001b\"AU\u0005\t\u000b\u0005\u0013F\u0011A+\u0015\u0003E3Aa\u0016*\u00011\nq1i\\;oi\u0016\u0014HK]1j]\u0016\u0014XcA-cKN\u0019a+\u0003.\u0011\tms\u0016m\u0019\b\u0003!qK!!\u0018\u0002\u0002\u0015\rc\u0017m]:jM&,'/\u0003\u0002`A\n9AK]1j]\u0016\u0014(BA/\u0003!\t!\"\rB\u0003\u0017-\n\u0007q\u0003\u0005\u0003/\u0017\u0012,\u0004C\u0001\u000bf\t\u0015\u0001cK1\u0001\u0018\u0011!9gK!A!\u0002\u0013A\u0017aA8qiB\u0011\u0011.\u001e\b\u0003UJt!a\u001b9\u000f\u00051|W\"A7\u000b\u000594\u0011A\u0002\u001fs_>$h(C\u00013\u0013\t\t\u0018'\u0001\u0005paRLW.\u001b>f\u0013\t\u0019H/A\nGSJ\u001cHo\u0014:eKJl\u0015N\\5nSj,'O\u0003\u0002rc%\u0011ao\u001e\u0002\n\u001fB$\b+\u0019:b[NT!a\u001d;\t\u0011e4&\u0011!Q\u0001\ni\f\u0001\u0002\\1zKJ\u001c\u0018J\u001c\t\u0004\u0015ml\u0018B\u0001?\f\u0005\u0015\t%O]1z!\tQa0\u0003\u0002��\u0017\t\u0019\u0011J\u001c;\t\r\u00053F\u0011AA\u0002)\u0019\t)!!\u0003\u0002\fA)\u0011q\u0001,bI6\t!\u000b\u0003\u0005h\u0003\u0003\u0001\n\u00111\u0001i\u0011!I\u0018\u0011\u0001I\u0001\u0002\u0004QXABA\b-\u0002\t\tB\u0001\u0007Ns\u000ec\u0017m]:jM&,'\u000f\u0005\u0003\u0011\u0001\u0005\u001c\u0007bBA\u000b-\u0012\u0005\u0011qC\u0001\u0006iJ\f\u0017N\u001c\u000b\u0005\u0003#\tI\u0002\u0003\u0005\u0002\u001c\u0005M\u0001\u0019AA\u000f\u0003\u0011!\u0017\r^1\u0011\r\u0005}\u0011\u0011FA\u0018\u001d\u0011\t\t#!\n\u000f\u00071\f\u0019#C\u0001\r\u0013\r\t9cC\u0001\ba\u0006\u001c7.Y4f\u0013\u0011\tY#!\f\u0003\u0011%#XM]1cY\u0016T1!a\n\f!\u0019\t\t$!\u000ebG6\u0011\u00111\u0007\u0006\u0004\u00037!\u0011\u0002BA\u001c\u0003g\u0011q!\u0012=b[BdWmB\u
0005\u0002<I\u000b\t\u0011#\u0001\u0002>\u0005q1i\\;oi\u0016\u0014HK]1j]\u0016\u0014\b\u0003BA\u0004\u0003\u007f1\u0001b\u0016*\u0002\u0002#\u0005\u0011\u0011I\n\u0004\u0003\u007fI\u0001bB!\u0002@\u0011\u0005\u0011Q\t\u000b\u0003\u0003{A!\"!\u0013\u0002@E\u0005I\u0011AA&\u0003m!C.Z:tS:LG\u000fJ4sK\u0006$XM\u001d\u0013eK\u001a\fW\u000f\u001c;%cU1\u0011QJA2\u0003K*\"!a\u0014+\u0007!\f\tf\u000b\u0002\u0002TA!\u0011QKA0\u001b\t\t9F\u0003\u0003\u0002Z\u0005m\u0013!C;oG\",7m[3e\u0015\r\tifC\u0001\u000bC:tw\u000e^1uS>t\u0017\u0002BA1\u0003/\u0012\u0011#\u001e8dQ\u0016\u001c7.\u001a3WCJL\u0017M\\2f\t\u00191\u0012q\tb\u0001/\u00111\u0001%a\u0012C\u0002]A!\"!\u001b\u0002@E\u0005I\u0011AA6\u0003m!C.Z:tS:LG\u000fJ4sK\u0006$XM\u001d\u0013eK\u001a\fW\u000f\u001c;%eU1\u0011QNA9\u0003g*\"!a\u001c+\u0007i\f\t\u0006\u0002\u0004\u0017\u0003O\u0012\ra\u0006\u0003\u0007A\u0005\u001d$\u0019A\f")
/* loaded from: input_file:nak/classify/NNetClassifier.class */
// Classifier that scores an input by encoding it as a dense vector, running
// that vector through a NeuralNetwork, and decoding the network output into a
// Counter keyed by label.
//
// NOTE(review): this file is decompiler output for a Scala class (see the
// "compiled from: NNetClassifier.scala" markers). Calls of the form
// Function1.class.xxx(this, ...) and Classifier.Cclass.xxx(this, ...) look like
// the decompiler's rendering of Scala trait implementation classes; they are
// not valid hand-written Java. Confirm against the Scala source before editing.
//
// L is the label type and T is the input (observation) type.
public class NNetClassifier<L, T> implements Classifier<L, T> {
    /** Network applied to the encoded input in {@code scores}. */
    private final NeuralNetwork nnet;
    /** Turns an input of type T into the dense vector fed to {@code nnet}. */
    private final Function1<T, DenseVector<Object>> inputEncoder;
    /** Index of labels; {@code scores} builds its output decoder from it. */
    private final Index<L> labelIndex;

    /* compiled from: NNetClassifier.scala */
    /* loaded from: input_file:nak/classify/NNetClassifier$CounterTrainer.class */
    /**
     * Trainer that builds an {@link NNetClassifier} whose inputs are
     * {@code Counter<T, Object>} instances.
     *
     * It holds the optimizer parameters and an array of layer sizes that is
     * placed between the input size and the label count when the network
     * layout is assembled in {@code train}.
     */
    public static class CounterTrainer<L, T> implements Classifier.Trainer<L, Counter<T, Object>> {
        /** Optimizer configuration; {@code train} calls {@code minimize} on it. */
        private final FirstOrderMinimizer.OptParams opt;
        /** Layer sizes inserted between the input layer and the output layer. */
        private final int[] layersIn;

        /**
         * Trains a classifier from the given examples.
         *
         * @param iterable the training examples; it is traversed twice with
         *                 {@code foreach} and then copied with {@code toArray}
         * @return a classifier wrapping the network extracted from the
         *         minimized weight vector
         */
        @Override // nak.classify.Classifier.Trainer
        public NNetClassifier<L, Counter<T, Object>> train(Iterable<Example<L, Counter<T, Object>>> iterable) {
            // First index: later passed to the classifier as its label index and used
            // for the size of the last layer. The closure body is not visible here;
            // presumably it indexes each example's label -- TODO confirm.
            MutableIndex apply = Index$.MODULE$.apply(NoManifest$.MODULE$);
            iterable.foreach(new NNetClassifier$CounterTrainer$$anonfun$train$1(this, apply));
            // Second index: used for the size of the first layer and for the encoder.
            // Presumably filled with the feature keys of each example -- TODO confirm.
            MutableIndex apply2 = Index$.MODULE$.apply(NoManifest$.MODULE$);
            iterable.foreach(new NNetClassifier$CounterTrainer$$anonfun$train$2(this, apply2));
            Encoder fromIndex = Encoder$.MODULE$.fromIndex(apply2);
            // Objective arguments, in order:
            //   1. the examples, mapped through a parallel array by a closure that
            //      receives the first index and the encoder, then made an indexed seq;
            //   2. a closure over this trainer (presumably wrapping the errorFun$1
            //      method below -- verify against the Scala source);
            //   3. the layer sizes: [apply2.size()] ++ layersIn ++ [apply.size()].
            NNObjective nNObjective = new NNObjective(((ParArray) Predef$.MODULE$.refArrayOps((Object[]) iterable.toArray(ClassTag$.MODULE$.apply(Example.class))).par().map(new NNetClassifier$CounterTrainer$$anonfun$1(this, apply, fromIndex), ParArray$.MODULE$.canBuildFrom())).toIndexedSeq(), new NNetClassifier$CounterTrainer$$anonfun$2(this), (int[]) Predef$.MODULE$.intArrayOps((int[]) Predef$.MODULE$.intArrayOps(new int[]{apply2.size()}).$plus$plus(Predef$.MODULE$.intArrayOps(this.layersIn), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()))).$plus$plus(Predef$.MODULE$.intArrayOps(new int[]{apply.size()}), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int())));
            // Minimize starting from the objective's initial weight vector, turn the
            // resulting weights back into a network, and wrap it together with an
            // input-encoding closure (built on the encoder) and the first index.
            return new NNetClassifier<>(nNObjective.extract((DenseVector) this.opt.minimize(nNObjective, nNObjective.initialWeightVector(), DenseVector$.MODULE$.space_d())), new NNetClassifier$CounterTrainer$$anonfun$train$3(this, fromIndex), apply);
        }

        /**
         * Computes a (value, vector) pair from a vector of outputs and an index.
         *
         * The value is {@code softmax(denseVector) - denseVector(i)}. The vector
         * is {@code exp(denseVector - softmax(denseVector))} with 1 subtracted
         * from its entry at position {@code i}.
         *
         * NOTE(review): if Breeze's {@code softmax} is the log of the sum of
         * exponentials, this is the negative log-likelihood of class {@code i}
         * and its gradient with respect to the outputs -- confirm against the
         * Breeze documentation.
         *
         * @param denseVector the output scores
         * @param i           position whose score is subtracted and whose
         *                    vector entry is decremented; presumably the gold
         *                    label index -- verify against caller
         * @return a Tuple2 of the boxed double value and the vector
         */
        public final Tuple2 nak$classify$NNetClassifier$CounterTrainer$$errorFun$1(DenseVector denseVector, int i) {
            double unboxToDouble = BoxesRunTime.unboxToDouble(package$.MODULE$.softmax().apply$mcD$sp(denseVector, Tensor$.MODULE$.canUReduce(Predef$.MODULE$.conforms())));
            double apply$mcD$sp = unboxToDouble - denseVector.apply$mcD$sp(i);
            // Element-wise exp of (scores - softmax value).
            DenseVector denseVector2 = (DenseVector) breeze.numerics.package$.MODULE$.exp().apply(denseVector.$minus(BoxesRunTime.boxToDouble(unboxToDouble), DenseVector$.MODULE$.dv_s_Op_Double_OpSub()), DenseVector$.MODULE$.canMapValues(ClassTag$.MODULE$.Double()));
            denseVector2.update$mcD$sp(i, denseVector2.apply$mcD$sp(i) - 1);
            return Predef$ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.any2ArrowAssoc(BoxesRunTime.boxToDouble(apply$mcD$sp)), denseVector2);
        }

        /**
         * @param optParams optimizer configuration used by {@code train}
         * @param iArr      layer sizes placed between the input and output layers
         */
        public CounterTrainer(FirstOrderMinimizer.OptParams optParams, int[] iArr) {
            this.opt = optParams;
            this.layersIn = iArr;
        }
    }

    // ---------------------------------------------------------------------
    // Classifier members: each forwards to the static helper on
    // Classifier.Cclass, passing this instance as the receiver.
    // ---------------------------------------------------------------------

    /** Forwards to {@code Classifier.Cclass.apply(this, t)}. */
    @Override // nak.classify.Classifier
    public L apply(T t) {
        return (L) Classifier.Cclass.apply(this, t);
    }

    /** Forwards to {@code Classifier.Cclass.classify(this, t)}. */
    @Override // nak.classify.Classifier
    public L classify(T t) {
        return (L) Classifier.Cclass.classify(this, t);
    }

    /** Forwards to {@code Classifier.Cclass.map(this, function1)}. */
    @Override // nak.classify.Classifier
    public <M> Classifier<M, T> map(Function1<L, M> function1) {
        return Classifier.Cclass.map(this, function1);
    }

    // ---------------------------------------------------------------------
    // Function1 bridge methods. Everything from here down to toString()
    // forwards to the same-named helper reached through Function1.class,
    // passing this instance. The $mcXY$sp names follow scalac's pattern for
    // primitive-specialized variants; the parameter and return types in each
    // signature show which primitives are involved.
    // NOTE(review): these look compiler-generated -- do not edit by hand.
    // ---------------------------------------------------------------------

    // apply variants taking a double argument.
    public boolean apply$mcZD$sp(double d) {
        return Function1.class.apply$mcZD$sp(this, d);
    }

    public double apply$mcDD$sp(double d) {
        return Function1.class.apply$mcDD$sp(this, d);
    }

    public float apply$mcFD$sp(double d) {
        return Function1.class.apply$mcFD$sp(this, d);
    }

    public int apply$mcID$sp(double d) {
        return Function1.class.apply$mcID$sp(this, d);
    }

    public long apply$mcJD$sp(double d) {
        return Function1.class.apply$mcJD$sp(this, d);
    }

    public void apply$mcVD$sp(double d) {
        Function1.class.apply$mcVD$sp(this, d);
    }

    // apply variants taking a float argument.
    public boolean apply$mcZF$sp(float f) {
        return Function1.class.apply$mcZF$sp(this, f);
    }

    public double apply$mcDF$sp(float f) {
        return Function1.class.apply$mcDF$sp(this, f);
    }

    public float apply$mcFF$sp(float f) {
        return Function1.class.apply$mcFF$sp(this, f);
    }

    public int apply$mcIF$sp(float f) {
        return Function1.class.apply$mcIF$sp(this, f);
    }

    public long apply$mcJF$sp(float f) {
        return Function1.class.apply$mcJF$sp(this, f);
    }

    public void apply$mcVF$sp(float f) {
        Function1.class.apply$mcVF$sp(this, f);
    }

    // apply variants taking an int argument.
    public boolean apply$mcZI$sp(int i) {
        return Function1.class.apply$mcZI$sp(this, i);
    }

    public double apply$mcDI$sp(int i) {
        return Function1.class.apply$mcDI$sp(this, i);
    }

    public float apply$mcFI$sp(int i) {
        return Function1.class.apply$mcFI$sp(this, i);
    }

    public int apply$mcII$sp(int i) {
        return Function1.class.apply$mcII$sp(this, i);
    }

    public long apply$mcJI$sp(int i) {
        return Function1.class.apply$mcJI$sp(this, i);
    }

    public void apply$mcVI$sp(int i) {
        Function1.class.apply$mcVI$sp(this, i);
    }

    // apply variants taking a long argument.
    public boolean apply$mcZJ$sp(long j) {
        return Function1.class.apply$mcZJ$sp(this, j);
    }

    public double apply$mcDJ$sp(long j) {
        return Function1.class.apply$mcDJ$sp(this, j);
    }

    public float apply$mcFJ$sp(long j) {
        return Function1.class.apply$mcFJ$sp(this, j);
    }

    public int apply$mcIJ$sp(long j) {
        return Function1.class.apply$mcIJ$sp(this, j);
    }

    public long apply$mcJJ$sp(long j) {
        return Function1.class.apply$mcJJ$sp(this, j);
    }

    public void apply$mcVJ$sp(long j) {
        Function1.class.apply$mcVJ$sp(this, j);
    }

    // compose: the generic form first, then the $mc..$sp variants.
    public <A> Function1<A, L> compose(Function1<A, T> function1) {
        return Function1.class.compose(this, function1);
    }

    public <A> Function1<A, Object> compose$mcZD$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcZD$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcDD$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcDD$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcFD$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcFD$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcID$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcID$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcJD$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcJD$sp(this, function1);
    }

    public <A> Function1<A, BoxedUnit> compose$mcVD$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcVD$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcZF$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcZF$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcDF$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcDF$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcFF$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcFF$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcIF$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcIF$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcJF$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcJF$sp(this, function1);
    }

    public <A> Function1<A, BoxedUnit> compose$mcVF$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcVF$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcZI$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcZI$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcDI$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcDI$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcFI$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcFI$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcII$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcII$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcJI$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcJI$sp(this, function1);
    }

    public <A> Function1<A, BoxedUnit> compose$mcVI$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcVI$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcZJ$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcZJ$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcDJ$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcDJ$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcFJ$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcFJ$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcIJ$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcIJ$sp(this, function1);
    }

    public <A> Function1<A, Object> compose$mcJJ$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcJJ$sp(this, function1);
    }

    public <A> Function1<A, BoxedUnit> compose$mcVJ$sp(Function1<A, Object> function1) {
        return Function1.class.compose$mcVJ$sp(this, function1);
    }

    // andThen: the generic form first, then the $mc..$sp variants.
    public <A> Function1<T, A> andThen(Function1<L, A> function1) {
        return Function1.class.andThen(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcZD$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcZD$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcDD$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcDD$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcFD$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcFD$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcID$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcID$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcJD$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcJD$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcVD$sp(Function1<BoxedUnit, A> function1) {
        return Function1.class.andThen$mcVD$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcZF$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcZF$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcDF$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcDF$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcFF$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcFF$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcIF$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcIF$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcJF$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcJF$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcVF$sp(Function1<BoxedUnit, A> function1) {
        return Function1.class.andThen$mcVF$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcZI$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcZI$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcDI$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcDI$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcFI$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcFI$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcII$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcII$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcJI$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcJI$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcVI$sp(Function1<BoxedUnit, A> function1) {
        return Function1.class.andThen$mcVI$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcZJ$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcZJ$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcDJ$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcDJ$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcFJ$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcFJ$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcIJ$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcIJ$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcJJ$sp(Function1<Object, A> function1) {
        return Function1.class.andThen$mcJJ$sp(this, function1);
    }

    public <A> Function1<Object, A> andThen$mcVJ$sp(Function1<BoxedUnit, A> function1) {
        return Function1.class.andThen$mcVJ$sp(this, function1);
    }

    /** Forwards to the {@code toString} helper reached through Function1. */
    public String toString() {
        return Function1.class.toString(this);
    }

    /**
     * Scores every label for the given input.
     *
     * The input is encoded with {@code inputEncoder}, passed through
     * {@code nnet}, and the network output is decoded into a Counter using an
     * encoder built from {@code labelIndex}.
     *
     * @param t the input to score
     * @return the decoded network output, keyed by label
     */
    @Override // nak.classify.Classifier
    public Counter<L, Object> scores(T t) {
        // A new encoder is built from the label index on every call.
        Encoder fromIndex = Encoder$.MODULE$.fromIndex(this.labelIndex);
        // decode$default$2() supplies the default for decode's second parameter.
        return fromIndex.decode(this.nnet.apply((DenseVector<Object>) this.inputEncoder.apply(t)), fromIndex.decode$default$2());
    }

    /**
     * @param neuralNetwork network applied to encoded inputs
     * @param function1     encoder from an input to the network's input vector
     * @param index         label index used to decode the network output
     */
    public NNetClassifier(NeuralNetwork neuralNetwork, Function1<T, DenseVector<Object>> function1, Index<L> index) {
        this.nnet = neuralNetwork;
        this.inputEncoder = function1;
        this.labelIndex = index;
        // Run the initializers for the Function1 and Classifier helpers after
        // the fields have been assigned.
        Function1.class.$init$(this);
        Classifier.Cclass.$init$(this);
    }
}
