package org.kramerlab.autoencoder.neuralnet.rbm;

import org.kramerlab.autoencoder.math.matrix.Mat;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple3;
import scala.collection.immutable.List;
import scala.collection.mutable.StringBuilder;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.IntRef;
import scala.runtime.ObjectRef;

/* compiled from: Rbm.scala */
/* loaded from: input_file:org/kramerlab/autoencoder/neuralnet/rbm/Rbm$$anonfun$train$2.class */
/**
 * Per-minibatch training step of {@code Rbm.train}, compiled from a closure in Rbm.scala.
 *
 * <p>Each call to {@link #apply(Mat)} runs {@code contrastiveDivergence} on one minibatch and then
 * applies a momentum update to the connection weights and to the visible and hidden layer
 * parameters. The velocity and counter fields are {@code scala.runtime.ObjectRef}/{@code IntRef}
 * cells captured from the enclosing method; their {@code elem} values are read and overwritten here.
 */
public class Rbm$$anonfun$train$2 extends AbstractFunction1<Mat, BoxedUnit> implements Serializable {
    public static final long serialVersionUID = 0;
    /** The RBM whose parameters are updated. */
    private final /* synthetic */ Rbm $outer;
    /** Training set captured from the enclosing method; not read in this class (presumably kept for nested closures). */
    public final Mat trainingSet$1;
    /** Supplies Gibbs steps, sampling mode, weight penalty and learning rate. */
    private final RbmTrainingConfiguration configuration$1;
    /** Observers notified (via a nested closure) after every minibatch update. */
    private final List trainingObservers$1;
    /** Velocity cell for the visible layer parameters (holds a {@code Mat}). */
    private final ObjectRef visibleVelocity$1;
    /** Velocity cell for the hidden layer parameters (holds a {@code Mat}). */
    private final ObjectRef hiddenVelocity$1;
    /** Velocity cell for the connection weights (holds a {@code Mat}). */
    private final ObjectRef weightVelocity$1;
    /** Current epoch number; read only here. */
    private final IntRef epoch$1;
    /** Momentum coefficient applied to the previous velocity. */
    private final double currentMomentum$1;
    /** Index of the current minibatch; incremented at the end of every {@link #apply(Mat)}. */
    private final IntRef minibatchIndex$1;

    /**
     * Performs one training step on a single minibatch.
     *
     * <p>For each parameter set {@code p} with velocity {@code v} and update {@code g} from
     * contrastive divergence: {@code v = v * momentum + (g - weightPenalty(p)) * learningRate},
     * then {@code p += v}.
     *
     * @param mat the minibatch passed to {@code contrastiveDivergence}
     * @throws IllegalStateException if the L2 norm of the connection weights is NaN before the step
     * @throws MatchError if {@code contrastiveDivergence} returns {@code null}
     */
    public final void apply(Mat mat) {
        Rbm rbm = this.$outer;
        // Fail fast: once the weights contain NaN every further update would just propagate it.
        if (Double.isNaN(rbm.connection().parameters().l2Norm())) {
            throw new IllegalStateException("Weights exploded in epoch = " + this.epoch$1.elem
                    + " batch = " + this.minibatchIndex$1.elem);
        }

        Tuple3<Mat, Mat, Mat> updates = rbm.contrastiveDivergence(
                mat,
                this.configuration$1.gibbsSamplingSteps(this.epoch$1.elem),
                this.configuration$1.sampleVisibleUnitsDeterministically());
        if (updates == null) {
            // Mirrors the Scala pattern-match failure of the original closure.
            throw new MatchError(updates);
        }
        // Tuple order as consumed below: (visible, connection, hidden).
        Mat visibleUpdate = updates._1();
        Mat weightUpdate = updates._2();
        Mat hiddenUpdate = updates._3();

        // All velocities are computed from the pre-step parameters before any parameter is changed.
        this.weightVelocity$1.elem = nextVelocity(
                (Mat) this.weightVelocity$1.elem, weightUpdate, rbm.connection().parameters());
        this.visibleVelocity$1.elem = nextVelocity(
                (Mat) this.visibleVelocity$1.elem, visibleUpdate, rbm.visible().parameters());
        this.hiddenVelocity$1.elem = nextVelocity(
                (Mat) this.hiddenVelocity$1.elem, hiddenUpdate, rbm.hidden().parameters());

        // In-place parameter updates.
        rbm.connection().parameters().$plus$eq((Mat) this.weightVelocity$1.elem);
        rbm.visible().parameters().$plus$eq((Mat) this.visibleVelocity$1.elem);
        rbm.hidden().parameters().$plus$eq((Mat) this.hiddenVelocity$1.elem);

        this.trainingObservers$1.foreach(new Rbm$$anonfun$train$2$$anonfun$apply$1(this));
        this.minibatchIndex$1.elem++;
    }

    /**
     * Computes {@code velocity * momentum + (update - weightPenalty(parameters)) * learningRate}.
     *
     * @param velocity   previous velocity
     * @param update     update matrix returned by contrastive divergence for this parameter set
     * @param parameters current (pre-step) parameters, used for the weight penalty
     * @return the new velocity
     */
    private Mat nextVelocity(Mat velocity, Mat update, Mat parameters) {
        Mat decayed = velocity.$times2(this.currentMomentum$1);
        Mat step = update.$minus(this.configuration$1.weightPenalty(parameters))
                .$times2(this.configuration$1.learningRate());
        return decayed.$plus(step);
    }

    /** Accessor used by nested closures to reach the enclosing {@code Rbm}. */
    public /* synthetic */ Rbm org$kramerlab$autoencoder$neuralnet$rbm$Rbm$$anonfun$$$outer() {
        return this.$outer;
    }

    /** Erased bridge for {@code Function1.apply}; delegates to {@link #apply(Mat)}. */
    public final /* bridge */ /* synthetic */ Object apply(Object obj) {
        apply((Mat) obj);
        return BoxedUnit.UNIT;
    }

    /**
     * Captures the enclosing RBM and the training-loop state.
     *
     * @throws NullPointerException if {@code rbm} is {@code null}
     */
    public Rbm$$anonfun$train$2(Rbm rbm, Mat mat, RbmTrainingConfiguration rbmTrainingConfiguration, List list, ObjectRef objectRef, ObjectRef objectRef2, ObjectRef objectRef3, IntRef intRef, double d, IntRef intRef2) {
        if (rbm == null) {
            throw new NullPointerException();
        }
        this.$outer = rbm;
        this.trainingSet$1 = mat;
        this.configuration$1 = rbmTrainingConfiguration;
        this.trainingObservers$1 = list;
        this.visibleVelocity$1 = objectRef;
        this.hiddenVelocity$1 = objectRef2;
        this.weightVelocity$1 = objectRef3;
        this.epoch$1 = intRef;
        this.currentMomentum$1 = d;
        this.minibatchIndex$1 = intRef2;
    }
}
