/*
 * Decompiled with CFR 0.152.
 */
package com.stripe.rainier.sampler;

import com.stripe.rainier.sampler.DenseMassMatrix;
import com.stripe.rainier.sampler.DenseMassMatrix$;
import com.stripe.rainier.sampler.DensityFunction;
import com.stripe.rainier.sampler.DiagonalMassMatrix;
import com.stripe.rainier.sampler.IdentityMassMatrix$;
import com.stripe.rainier.sampler.MassMatrix;
import com.stripe.rainier.sampler.RNG;
import com.stripe.rainier.sampler.Stats;
import java.util.Arrays;
import scala.MatchError;
import scala.Predef$;
import scala.collection.ArrayOps$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.RichDouble$;

@ScalaSignature(bytes="\u0006\u0005\u0005Mh\u0001\u0002\u001a4\u0005qB\u0001b\u0011\u0001\u0003\u0002\u0003\u0006I\u0001\u0012\u0005\t\u0011\u0002\u0011\t\u0011)A\u0005\u0013\")A\n\u0001C\u0001\u001b\"9\u0011\u000b\u0001a\u0001\n\u0003\u0011\u0006b\u0002,\u0001\u0001\u0004%\ta\u0016\u0005\u0007;\u0002\u0001\u000b\u0015B*\t\u000by\u0003A\u0011A0\t\u000b\u0001\u0004A\u0011A1\t\u000bE\u0004A\u0011\u0001:\t\u000b]\u0004A\u0011\u0001=\t\u0013u\u0004\u0001\u0019!a\u0001\n\u0003q\bbCA\u0003\u0001\u0001\u0007\t\u0019!C\u0001\u0003\u000fA!\"a\u0003\u0001\u0001\u0004\u0005\t\u0015)\u0003\u0000\u0011)\ti\u0001\u0001a\u0001\u0002\u0004%\tA \u0005\f\u0003\u001f\u0001\u0001\u0019!a\u0001\n\u0003\t\t\u0002\u0003\u0006\u0002\u0016\u0001\u0001\r\u0011!Q!\n}D1\"a\u0006\u0001\u0001\u0004\u0005\r\u0011\"\u0001\u0002\u001a!Y\u00111\u0004\u0001A\u0002\u0003\u0007I\u0011AA\u000f\u0011)\t\t\u0003\u0001a\u0001\u0002\u0003\u0006KA\u0019\u0005\b\u0003G\u0001A\u0011AA\u0013\u0011\u001d\t9\u0004\u0001C\u0001\u0003sAq!a\u0011\u0001\t\u0003\t)\u0005C\u0004\u0002L\u0001!\t!!\u0014\t\u000f\u0005M\u0003\u0001\"\u0001\u0002V!9\u00111\f\u0001\u0005\u0002\u0005u\u0003\"CA3\u0001\t\u0007I\u0011AA4\u0011\u001d\tI\u0007\u0001Q\u0001\n%C\u0011\"a\u001b\u0001\u0005\u0004%\t!a\u001a\t\u000f\u00055\u0004\u0001)A\u0005\u0013\"I\u0011q\u000e\u0001C\u0002\u0013\u0005\u0011q\r\u0005\b\u0003c\u0002\u0001\u0015!\u0003J\u0011%\t\u0019\b\u0001b\u0001\n\u0013\t)\bC\u0004\u0002x\u0001\u0001\u000b\u0011B4\t\u0013\u0005e\u0004A1A\u0005\n\u0005U\u0004bBA>\u0001\u0001\u0006Ia\u001a\u0005\b\u0003{\u0002A\u0011BA@\u0011\u001d\t)\t\u0001C\u0005\u0003\u000fCq!!$\u0001\t\u0013\ty\tC\u0004\u0002\u0016\u0002!I!a&\t\u000f\u0005u\u0005\u0001\"\u0003\u0002 \"9\u0011Q\u0015\u0001\u0005\n\u0005\u001d\u0006bBAV\u0001\u0011%\u0011Q\u0016\u0005\b\u0003g\u0003A\u0011BA[\u0011\u001d\tY\f\u0001C\u0005\u0003{Cq!!1\u0001\t\u0013\t\u0019\rC\u0004\u0002N\u0002!I!a4\t\u000f\u0005E\u0007\u0001\"\u0003\u0002T\"9\u00111\u001c\u0001\u0005\n\u0005u\u0007bBAt\u0001\u0011%\u0011\u0011\u001e\u0002\t\u0019\u0016\f\u0007O\u0012:pO*\u0011A'N\u0001\bg\u0006l\u0007\u000f\\3s\u0015\t1t'A\u0004sC&t\u0017.\u001a:\u000b\u0005aJ\u0014AB:ue&\u0004XMC\u0001;\u0003\r\u0019w.\\\u0002\u0001'\t\u0001Q\b\u0005\u0002?\u00036\tqHC\u0001A\u0003\u0015\u00198-\u00197b\u0013\t\u0011uH\u0001\u0004B]f\u0014VMZ\u0001\bI\u0016t7/\u001b;z!\t)e)D\u00014\u0013\t95GA\bEK:\u001c\u0018\u000e^=Gk:\u001cG/[8o\u0003-\u0019H/\u0019;t/&tGm\\<\u0011\u0005yR\u0015BA&@\u0005\rIe\u000e^\u0001\u0007y%t\u0017\u000e\u001e \u0015\u00079{\u0005\u000b\u0005\u0002F\u0001!)1i\u0001a\u0001\t\")\u0001j\u0001a\u0001\u0013\u0006)1\u000f^1ugV\t1\u000b\u0005\u0002F)&\u0011Qk\r\u0002\u0006'R\fGo]\u0001\ngR\fGo]0%KF$\"\u0001W.\u0011\u0005yJ\u0016B\u0001.@\u0005\u0011)f.\u001b;\t\u000fq+\u0011\u0011!a\u0001'\u0006\u0019\u0001\u0010J\u0019\u0002\rM$\u0018\r^:!\u0003)\u0011Xm]3u'R\fGo\u001d\u000b\u0002'\u0006YAO]=Ti\u0016\u0004\b/\u001b8h)\u0011\u0011WM\u001b7\u0011\u0005y\u001a\u0017B\u00013@\u0005\u0019!u.\u001e2mK\")a\r\u0003a\u0001O\u00061\u0001/\u0019:b[N\u00042A\u00105c\u0013\tIwHA\u0003BeJ\f\u0017\u0010C\u0003l\u0011\u0001\u0007!-\u0001\u0005ti\u0016\u00048+\u001b>f\u0011\u0015i\u0007\u00021\u0001o\u0003\u0011i\u0017m]:\u0011\u0005\u0015{\u0017B\u000194\u0005)i\u0015m]:NCR\u0014\u0018\u000e_\u0001\ni\u0006\\Wm\u0015;faN$B\u0001W:vm\")A/\u0003a\u0001\u0013\u0006\tA\u000eC\u0003l\u0013\u0001\u0007!\rC\u0003n\u0013\u0001\u0007a.A\u0004jgV#VO\u001d8\u0015\u0005ed\bC\u0001 {\u0013\tYxHA\u0004C_>dW-\u00198\t\u000b\u0019T\u0001\u0019A4\u0002%%$XM]1uS>t7\u000b^1siRKW.Z\u000b\u0002\u007fB\u0019a(!\u0001\n\u0007\u0005\rqH\u0001\u0003M_:<\u0017AF5uKJ\fG/[8o'R\f'\u000f\u001e+j[\u0016|F%Z9\u0015\u0007a\u000bI\u0001C\u0004]\u0019\u0005\u0005\t\u0019A@\u0002'%$XM]1uS>t7\u000b^1siRKW.\u001a\u0011\u0002'%$XM]1uS>t7\u000b^1si\u001e\u0013\u0018\rZ:\u0002/%$XM]1uS>t7\u000b^1si\u001e\u0013\u0018\rZ:`I\u0015\fHc\u0001-\u0002\u0014!9AlDA\u0001\u0002\u0004y\u0018\u0001F5uKJ\fG/[8o'R\f'\u000f^$sC\u0012\u001c\b%A\u0003qe\u00164\b*F\u0001c\u0003%\u0001(/\u001a<I?\u0012*\u0017\u000fF\u0002Y\u0003?Aq\u0001\u0018\n\u0002\u0002\u0003\u0007!-\u0001\u0004qe\u00164\b\nI\u0001\u000fgR\f'\u000f^%uKJ\fG/[8o)\u0019\t9#a\r\u00026Q\u0019\u0001,!\u000b\t\u000f\u0005-B\u0003q\u0001\u0002.\u0005\u0019!O\\4\u0011\u0007\u0015\u000by#C\u0002\u00022M\u00121A\u0015(H\u0011\u00151G\u00031\u0001h\u0011\u0015iG\u00031\u0001o\u0003=1\u0017N\\5tQ&#XM]1uS>tGCBA\u001e\u0003\u007f\t\t\u0005F\u0002c\u0003{Aq!a\u000b\u0016\u0001\b\ti\u0003C\u0003g+\u0001\u0007q\rC\u0003n+\u0001\u0007a.\u0001\u0005t]\u0006\u00048\u000f[8u)\rA\u0016q\t\u0005\u0007\u0003\u00132\u0002\u0019A4\u0002\u0007=,H/A\u0004sKN$xN]3\u0015\u0007a\u000by\u0005\u0003\u0004\u0002R]\u0001\raZ\u0001\u0003S:\f\u0011B^1sS\u0006\u0014G.Z:\u0015\u000ba\u000b9&!\u0017\t\u000b\u0019D\u0002\u0019A4\t\r\u0005%\u0003\u00041\u0001h\u0003)Ig.\u001b;jC2L'0\u001a\u000b\u0005\u0003?\n\u0019\u0007F\u0002h\u0003CBq!a\u000b\u001a\u0001\b\ti\u0003C\u0003n3\u0001\u0007a.A\u0003o-\u0006\u00148/F\u0001J\u0003\u0019qg+\u0019:tA\u0005q\u0001o\u001c;f]RL\u0017\r\\%oI\u0016D\u0018a\u00049pi\u0016tG/[1m\u0013:$W\r\u001f\u0011\u0002\u001f%t\u0007/\u001e;PkR\u0004X\u000f^*ju\u0016\f\u0001#\u001b8qkR|U\u000f\u001e9viNK'0\u001a\u0011\u0002\u000bA\f()\u001e4\u0016\u0003\u001d\fa\u0001]9Ck\u001a\u0004\u0013a\u00012vM\u0006!!-\u001e4!\u0003\u0019)g.\u001a:hsR)!-!!\u0002\u0004\")a\r\na\u0001O\")Q\u000e\na\u0001]\u0006\tBn\\4BG\u000e,\u0007\u000f^1oG\u0016\u0004&o\u001c2\u0015\u0007\t\fI\t\u0003\u0004\u0002\f\u0016\u0002\rAY\u0001\u0007I\u0016dG/\u0019%\u0002\u000b9,w/U:\u0015\u000ba\u000b\t*a%\t\u000b-4\u0003\u0019\u00012\t\u000b54\u0003\u0019\u00018\u0002\u0017!\fGN\u001a)t\u001d\u0016<\u0018k\u001d\u000b\u00061\u0006e\u00151\u0014\u0005\u0006W\u001e\u0002\rA\u0019\u0005\u0006[\u001e\u0002\rA\\\u0001\u0018S:LG/[1m\u0011\u0006dg\r\u00165f]\u001a+H\u000e\\*uKB$R\u0001WAQ\u0003GCQa\u001b\u0015A\u0002\tDQ!\u001c\u0015A\u00029\faAZ;mYB\u001bHc\u0001-\u0002*\")1.\u000ba\u0001E\u0006Ya-\u001e7m!NtUm^)t)\u0015A\u0016qVAY\u0011\u0015Y'\u00061\u0001c\u0011\u0015i'\u00061\u0001o\u00031!xo\u001c$vY2\u001cF/\u001a9t)\u0015A\u0016qWA]\u0011\u0015Y7\u00061\u0001c\u0011\u0015i7\u00061\u0001o\u000351\u0017N\\1m\u0011\u0006dgm\u0015;faR\u0019\u0001,a0\t\u000b-d\u0003\u0019\u00012\u0002\t\r|\u0007/\u001f\u000b\u00061\u0006\u0015\u0017\u0011\u001a\u0005\u0007\u0003\u000fl\u0003\u0019A4\u0002\u0017M|WO]2f\u0003J\u0014\u0018-\u001f\u0005\u0007\u0003\u0017l\u0003\u0019A4\u0002\u0017Q\f'oZ3u\u0003J\u0014\u0018-_\u0001\u0017G>\u0004\u00180U:B]\u0012,\u0006\u000fZ1uK\u0012+gn]5usR\t\u0001,\u0001\u0005wK2|7-\u001b;z)\u001dA\u0016Q[Al\u00033Da!!\u00150\u0001\u00049\u0007BBA%_\u0001\u0007q\rC\u0003n_\u0001\u0007a.A\u0002e_R$RAYAp\u0003GDa!!91\u0001\u00049\u0017!\u0001=\t\r\u0005\u0015\b\u00071\u0001h\u0003\u0005I\u0018\u0001D5oSRL\u0017\r\\5{KB\u001bHCBAv\u0003_\f\t\u0010F\u0002Y\u0003[Dq!a\u000b2\u0001\b\ti\u0003C\u0003gc\u0001\u0007q\rC\u0003nc\u0001\u0007a\u000e")
public final class LeapFrog {
    private final DensityFunction density;
    private final int statsWindow;
    private Stats stats;
    private long iterationStartTime;
    private long iterationStartGrads;
    private double prevH;
    private final int nVars;
    private final int potentialIndex;
    private final int inputOutputSize;
    private final double[] pqBuf;
    private final double[] buf;

    public Stats stats() {
        return this.stats;
    }

    public void stats_$eq(Stats x$1) {
        this.stats = x$1;
    }

    /*
     * WARNING - void declaration
     */
    public Stats resetStats() {
        void var1_1;
        Stats oldStats = this.stats();
        this.stats_$eq(new Stats(this.statsWindow));
        return var1_1;
    }

    public double tryStepping(double[] params, double stepSize, MassMatrix mass) {
        this.copy(params, this.pqBuf());
        this.initialHalfThenFullStep(stepSize, mass);
        this.finalHalfStep(stepSize);
        double deltaH = this.energy(this.pqBuf(), mass) - this.energy(params, mass);
        return this.logAcceptanceProb(deltaH);
    }

    public void takeSteps(int l, double stepSize, MassMatrix mass) {
        this.stats().stepSizes().add(stepSize);
        this.initialHalfThenFullStep(stepSize, mass);
        for (int i = 1; i < l; ++i) {
            this.twoFullSteps(stepSize, mass);
        }
        this.finalHalfStep(stepSize);
    }

    public boolean isUTurn(double[] params) {
        double out = 0.0;
        for (int i = 0; i < this.nVars(); ++i) {
            out += (this.pqBuf()[i + this.nVars()] - params[i + this.nVars()]) * this.pqBuf()[i];
        }
        return Double.isNaN(out) ? true : out < 0.0;
    }

    public long iterationStartTime() {
        return this.iterationStartTime;
    }

    public void iterationStartTime_$eq(long x$1) {
        this.iterationStartTime = x$1;
    }

    public long iterationStartGrads() {
        return this.iterationStartGrads;
    }

    public void iterationStartGrads_$eq(long x$1) {
        this.iterationStartGrads = x$1;
    }

    public double prevH() {
        return this.prevH;
    }

    public void prevH_$eq(double x$1) {
        this.prevH = x$1;
    }

    public void startIteration(double[] params, MassMatrix mass, RNG rng) {
        this.prevH_$eq(this.energy(params, mass));
        this.initializePs(params, mass, rng);
        this.copy(params, this.pqBuf());
        this.iterationStartTime_$eq(System.nanoTime());
        this.iterationStartGrads_$eq(this.stats().gradientEvaluations());
    }

    public double finishIteration(double[] params, MassMatrix mass, RNG rng) {
        double startH = this.energy(params, mass);
        double endH = this.energy(this.pqBuf(), mass);
        double deltaH = endH - startH;
        double a = this.logAcceptanceProb(deltaH);
        if (a > Math.log(rng.standardUniform())) {
            this.copy(this.pqBuf(), params);
            this.stats().energyVariance().update(endH);
            Stats stats = this.stats();
            stats.energyTransitions2_$eq(stats.energyTransitions2() + Math.pow(endH - this.prevH(), 2.0));
        } else {
            this.stats().energyVariance().update(startH);
            Stats stats = this.stats();
            stats.energyTransitions2_$eq(stats.energyTransitions2() + Math.pow(startH - this.prevH(), 2.0));
        }
        Stats stats = this.stats();
        stats.iterations_$eq(stats.iterations() + 1);
        this.stats().iterationTimes().add(System.nanoTime() - this.iterationStartTime());
        this.stats().acceptanceRates().add(Math.exp(a));
        this.stats().gradsPerIteration().add(this.stats().gradientEvaluations() - this.iterationStartGrads());
        return a;
    }

    public void snapshot(double[] out) {
        this.copy(this.pqBuf(), out);
    }

    public void restore(double[] in) {
        this.copy(in, this.pqBuf());
    }

    public void variables(double[] params, double[] out) {
        for (int i = 0; i < this.nVars(); ++i) {
            out[i] = params[i + this.nVars()];
        }
    }

    /*
     * WARNING - void declaration
     */
    public double[] initialize(MassMatrix mass, RNG rng) {
        void var3_3;
        double[] params = new double[this.inputOutputSize()];
        Arrays.fill(this.pqBuf(), 0.0);
        int j = this.nVars() * 2;
        for (int i = this.nVars(); i < j; ++i) {
            this.pqBuf()[i] = rng.standardNormal();
        }
        this.copyQsAndUpdateDensity();
        this.pqBuf()[this.potentialIndex()] = this.density.density() * (double)-1;
        this.copy(this.pqBuf(), params);
        this.initializePs(params, mass, rng);
        return var3_3;
    }

    public int nVars() {
        return this.nVars;
    }

    public int potentialIndex() {
        return this.potentialIndex;
    }

    public int inputOutputSize() {
        return this.inputOutputSize;
    }

    private double[] pqBuf() {
        return this.pqBuf;
    }

    private double[] buf() {
        return this.buf;
    }

    private double energy(double[] params, MassMatrix mass) {
        double potential = params[this.potentialIndex()];
        this.velocity(params, this.buf(), mass);
        double kinetic = this.dot(this.buf(), params) / 2.0;
        return potential + kinetic;
    }

    private double logAcceptanceProb(double deltaH) {
        return Double.isNaN(deltaH) ? Math.log(0.0) : RichDouble$.MODULE$.min$extension(Predef$.MODULE$.doubleWrapper(-deltaH), 0.0);
    }

    private void newQs(double stepSize, MassMatrix mass) {
        this.velocity(this.pqBuf(), this.buf(), mass);
        for (int i = 0; i < this.nVars(); ++i) {
            int n = i + this.nVars();
            this.pqBuf()[n] = this.pqBuf()[n] + stepSize * this.buf()[i];
        }
    }

    private void halfPsNewQs(double stepSize, MassMatrix mass) {
        this.fullPs(stepSize / 2.0);
        this.newQs(stepSize, mass);
    }

    private void initialHalfThenFullStep(double stepSize, MassMatrix mass) {
        this.halfPsNewQs(stepSize, mass);
        this.copyQsAndUpdateDensity();
        this.pqBuf()[this.potentialIndex()] = this.density.density() * (double)-1;
    }

    private void fullPs(double stepSize) {
        this.copyQsAndUpdateDensity();
        int j = this.nVars();
        for (int i = 0; i < j; ++i) {
            int n = i;
            this.pqBuf()[n] = this.pqBuf()[n] + stepSize * this.density.gradient(i);
        }
    }

    private void fullPsNewQs(double stepSize, MassMatrix mass) {
        this.fullPs(stepSize);
        this.newQs(stepSize, mass);
    }

    private void twoFullSteps(double stepSize, MassMatrix mass) {
        this.fullPsNewQs(stepSize, mass);
        this.copyQsAndUpdateDensity();
        this.pqBuf()[this.potentialIndex()] = this.density.density() * (double)-1;
    }

    private void finalHalfStep(double stepSize) {
        this.fullPs(stepSize / 2.0);
    }

    private void copy(double[] sourceArray, double[] targetArray) {
        System.arraycopy(sourceArray, 0, targetArray, 0, this.inputOutputSize());
    }

    private void copyQsAndUpdateDensity() {
        System.arraycopy(this.pqBuf(), this.nVars(), this.buf(), 0, this.nVars());
        long t = System.nanoTime();
        this.density.update(this.buf());
        this.stats().gradientTimes().add(System.nanoTime() - t);
        Stats stats = this.stats();
        stats.gradientEvaluations_$eq(stats.gradientEvaluations() + 1L);
    }

    private void velocity(double[] in, double[] out, MassMatrix mass) {
        MassMatrix massMatrix = mass;
        if (IdentityMassMatrix$.MODULE$.equals(massMatrix)) {
            System.arraycopy(in, 0, out, 0, ArrayOps$.MODULE$.size$extension(Predef$.MODULE$.doubleArrayOps(out)));
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else if (massMatrix instanceof DiagonalMassMatrix) {
            DiagonalMassMatrix diagonalMassMatrix = (DiagonalMassMatrix)massMatrix;
            double[] elements = diagonalMassMatrix.elements();
            for (int i = 0; i < ArrayOps$.MODULE$.size$extension(Predef$.MODULE$.doubleArrayOps(out)); ++i) {
                out[i] = in[i] * elements[i];
            }
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else if (massMatrix instanceof DenseMassMatrix) {
            DenseMassMatrix denseMassMatrix = (DenseMassMatrix)massMatrix;
            double[] elements = denseMassMatrix.elements();
            DenseMassMatrix$.MODULE$.squareMultiply(elements, in, out);
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else {
            throw new MatchError((Object)massMatrix);
        }
    }

    /*
     * WARNING - void declaration
     */
    private double dot(double[] x, double[] y) {
        void var3_3;
        double k = 0.0;
        int n = ArrayOps$.MODULE$.size$extension(Predef$.MODULE$.doubleArrayOps(x));
        for (int i = 0; i < n; ++i) {
            k += x[i] * y[i];
        }
        return (double)var3_3;
    }

    private void initializePs(double[] params, MassMatrix mass, RNG rng) {
        int i;
        for (i = 0; i < this.nVars(); ++i) {
            this.buf()[i] = rng.standardNormal();
        }
        MassMatrix massMatrix = mass;
        if (IdentityMassMatrix$.MODULE$.equals(massMatrix)) {
            System.arraycopy(this.buf(), 0, params, 0, this.nVars());
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else if (massMatrix instanceof DiagonalMassMatrix) {
            DiagonalMassMatrix diagonalMassMatrix = (DiagonalMassMatrix)massMatrix;
            double[] stdDevs = diagonalMassMatrix.stdDevs();
            for (i = 0; i < this.nVars(); ++i) {
                params[i] = this.buf()[i] / stdDevs[i];
            }
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else if (massMatrix instanceof DenseMassMatrix) {
            DenseMassMatrix denseMassMatrix = (DenseMassMatrix)massMatrix;
            double[] u = denseMassMatrix.choleskyUpperTriangular();
            DenseMassMatrix$.MODULE$.upperTriangularSolve(u, this.buf(), params);
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        } else {
            throw new MatchError((Object)massMatrix);
        }
    }

    public LeapFrog(DensityFunction density, int statsWindow) {
        this.density = density;
        this.statsWindow = statsWindow;
        this.stats = new Stats(statsWindow);
        this.nVars = density.nVars();
        this.potentialIndex = this.nVars() * 2;
        this.inputOutputSize = this.potentialIndex() + 1;
        this.pqBuf = new double[this.inputOutputSize()];
        this.buf = new double[this.nVars()];
    }
}

