package water.util;

import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import water.DTask;
import water.H2O;
import water.H2ONode;
import water.MRTask;
import water.RPC;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;

/* loaded from: input_file:water/util/MRUtils.class */
public class MRUtils {

    /** Maximum number of extra stratified-sampling attempts made when a class ends up with no rows. */
    private static final int MAX_STRATIFIED_RETRIES = 10;

    /** Offset added to the sampling seed to derive the seed of the final per-chunk shuffle. */
    private static final long STRATIFIED_SHUFFLE_SEED_OFFSET = 92339987L;

    /**
     * Counts the non-NA rows of each domain level of a categorical {@link Vec}.
     * Typical use: {@code new ClassDist(vec).doAll(vec).dist()}.
     */
    public static class ClassDist extends ClassDistHelper {

        /**
         * @param vec the categorical column whose levels are counted; its {@code domain()} must be
         *            non-null (its length fixes the number of counters)
         */
        public ClassDist(Vec vec) {
            super(vec.domain().length);
        }

        /** @return the per-level row counts (the task's own array, not a copy) */
        public final long[] dist() {
            return this._ys;
        }

        /**
         * Returns the per-level counts normalized so they sum to 1.
         *
         * <p>The total is accumulated as a {@code long} and each ratio is computed in {@code double}.
         * Summing the counts as {@code float} would lose precision once they exceed 2^24.
         *
         * @return relative frequency of each domain level
         */
        public final float[] rel_dist() {
            long total = 0;
            for (long count : _ys) {
                total += count;
            }
            assert total != 0;
            float[] relative = new float[_ys.length];
            for (int level = 0; level < _ys.length; level++) {
                relative[level] = (float) ((double) _ys[level] / total);
            }
            return relative;
        }
    }

    /**
     * Map/reduce counting task behind {@link ClassDist}. Each chunk tallies the integer level of its
     * non-NA rows, and {@link #reduce} adds the partial tallies together.
     */
    public static class ClassDistHelper extends MRTask<ClassDist> {
        final int _nclass;
        protected long[] _ys;

        private ClassDistHelper(int nclass) {
            this._nclass = nclass;
        }

        @Override
        public void map(Chunk chunk) {
            _ys = new long[_nclass];
            for (int row = 0; row < chunk._len; row++) {
                if (!chunk.isNA0(row)) {
                    _ys[(int) chunk.at80(row)]++;
                }
            }
        }

        @Override
        public void reduce(ClassDist other) {
            // Accumulates the other task's counts into this task's array.
            ArrayUtils.add(_ys, other._ys);
        }
    }

    /**
     * Runs an array of {@link DTask}s across the cloud. {@link #compute2()} forks the first
     * {@code min(_maxP, _tasks.length)} tasks, assigning task {@code i} to node
     * {@code i % cloudSize}.
     */
    public static class ParallelTasks<T extends DTask<T>> extends H2O.H2OCountedCompleter {
        public final transient T[] _tasks;
        public final transient int _maxP;
        private transient AtomicInteger _nextTask;

        /**
         * Callback that, when invoked, forks the next not-yet-started task on node {@code n}.
         * NOTE(review): nothing in this class constructs a Callback, so tasks with index >= _maxP
         * are never forked from here. Confirm whether callers attach it elsewhere.
         */
        class Callback extends H2O.H2OCallback<H2O.H2OCountedCompleter> {
            final int i;
            final H2ONode n;

            public Callback(H2ONode node, int taskIndex) {
                super(ParallelTasks.this);
                this.n = node;
                this.i = taskIndex;
            }

            @Override
            public void callback(H2O.H2OCountedCompleter finished) {
                int next = ParallelTasks.this._nextTask.getAndIncrement();
                if (next < ParallelTasks.this._tasks.length) {
                    ParallelTasks.this.forkDTask(next, this.n);
                }
            }
        }

        /** Creates the runner with at most one concurrently forked task per cloud node. */
        public ParallelTasks(H2O.H2OCountedCompleter completer, T[] tasks) {
            this(completer, tasks, H2O.CLOUD.size());
        }

        /**
         * @param completer   completer of this task
         * @param tasks       tasks to run
         * @param maxParallel number of tasks forked up front by {@link #compute2()}
         */
        public ParallelTasks(H2O.H2OCountedCompleter completer, T[] tasks, int maxParallel) {
            super(completer);
            this._maxP = maxParallel;
            this._tasks = tasks;
            // Presumably one tryComplete per task is expected before this completes. Confirm.
            addToPendingCount(this._tasks.length - 1);
        }

        private void forkDTask(int taskIndex) {
            forkDTask(taskIndex, H2O.CLOUD._memary[taskIndex % H2O.CLOUD.size()]);
        }

        /**
         * Runs task {@code taskIndex} locally when {@code node} is this node. Otherwise it sends the
         * task over RPC with this object as its completer.
         * NOTE(review): the local path does not register this object as the task's completer.
         * Verify that local completions are counted.
         */
        public void forkDTask(int taskIndex, H2ONode node) {
            if (node == H2O.SELF) {
                H2O.submitTask(this._tasks[taskIndex]);
            } else {
                new RPC(node, this._tasks[taskIndex]).addCompleter(this).call();
            }
        }

        @Override
        public void compute2() {
            int initial = Math.min(this._maxP, this._tasks.length);
            this._nextTask = new AtomicInteger(initial);
            for (int t = 0; t < initial; t++) {
                forkDTask(t);
            }
        }
    }

    /**
     * Draws a random row sample of roughly {@code rows} rows from {@code frame}.
     *
     * <p>Each row is kept with probability {@code rows / frame.numRows()}. The RNG is seeded
     * per chunk with {@code seed + chunkIndex}. A chunk that would otherwise keep nothing keeps its
     * last row. If the sample is still empty, it is deleted and sampling is retried with
     * {@code seed + 1}.
     *
     * @param frame source frame; {@code null} yields {@code null}
     * @param rows  desired number of rows; values {@code <= 0} mean "no sampling"
     * @param seed  base RNG seed
     * @return {@code frame} itself when the requested fraction is {@code >= 1}, otherwise a new frame
     */
    public static Frame sampleFrame(Frame frame, long rows, final long seed) {
        if (frame == null) {
            return null;
        }
        final float fraction = rows > 0 ? ((float) rows) / ((float) frame.numRows()) : 1.0f;
        if (fraction >= 1.0f) {
            return frame;
        }
        Frame sampled = new MRTask() {
            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                Random rng = RandomUtils.getDeterRNG(seed + cs[0].cidx());
                int kept = 0;
                int len = cs[0]._len;
                for (int row = 0; row < len; row++) {
                    // nextFloat() is drawn for every row; the second clause force-keeps the last row
                    // of a chunk that has kept nothing so far.
                    if (rng.nextFloat() < fraction || (kept == 0 && row == len - 1)) {
                        kept++;
                        for (int col = 0; col < ncs.length; col++) {
                            ncs[col].addNum(cs[col].at0(row));
                        }
                    }
                }
            }
        }.doAll(frame.numCols(), frame).outputFrame(frame.names(), frame.domains());
        if (sampled.numRows() != 0) {
            return sampled;
        }
        Log.warn("You asked for " + rows + " rows (out of " + frame.numRows() + "), but you got none (seed=" + seed + ").");
        Log.warn("Let's try again. You've gotta ask yourself a question: \"Do I feel lucky?\"");
        // Do not leak the empty result of the failed attempt.
        sampled.delete();
        return sampleFrame(frame, rows, seed + 1);
    }

    /**
     * Returns a new frame whose rows are shuffled within each chunk. Rows never move between
     * chunks. Every chunk uses the same {@code seed} for {@code ArrayUtils.shuffleArray}.
     *
     * @param frame source frame (not modified)
     * @param seed  shuffle seed
     * @return the shuffled copy
     */
    public static Frame shuffleFramePerChunk(Frame frame, final long seed) {
        return new MRTask() {
            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                long[] order = new long[cs[0]._len];
                for (int row = 0; row < order.length; row++) {
                    order[row] = row;
                }
                ArrayUtils.shuffleArray(order, seed);
                for (long row : order) {
                    for (int col = 0; col < ncs.length; col++) {
                        ncs[col].addNum(cs[col].at0((int) row));
                    }
                }
            }
        }.doAll(frame.numCols(), frame).outputFrame(frame.names(), frame.domains());
    }

    /**
     * Stratified sampling of {@code frame} by the categorical column {@code label}, capped at
     * {@code maxRows} rows.
     *
     * <p>If {@code samplingRatios} is {@code null} or all zeros, balancing ratios are computed from
     * the class counts and normalized so the smallest finite one is 1. Ratios are clamped to 1 unless
     * {@code allowOversampling} is set. They are then scaled so the expected total matches the
     * target row count. A non-null {@code samplingRatios} array is modified in place.
     *
     * @param frame             source frame; {@code null} yields {@code null}
     * @param label             categorical column of {@code frame} to stratify by
     * @param samplingRatios    per-class ratios, or {@code null}/all-zero to compute them
     * @param maxRows           upper bound on the expected number of sampled rows
     * @param seed              base RNG seed
     * @param allowOversampling whether ratios above 1 (row duplication) are allowed
     * @param verbose           whether to log per-class details
     * @return the sampled and per-chunk-shuffled frame
     */
    public static Frame sampleFrameStratified(Frame frame, Vec label, float[] samplingRatios, long maxRows,
                                              long seed, boolean allowOversampling, boolean verbose) {
        if (frame == null) {
            return null;
        }
        assert label.isEnum();
        assert maxRows >= label.domain().length;
        long[] dist = new ClassDist(label).doAll(label).dist();
        assert dist.length > 0;
        Log.info("Doing stratified sampling for data set containing " + frame.numRows() + " rows from " + dist.length + " classes. Oversampling: " + (allowOversampling ? "on" : "off"));
        if (verbose) {
            for (int c = 0; c < dist.length; c++) {
                Log.info("Class " + label.factor(c) + ": count: " + dist[c] + " prior: " + (((float) dist[c]) / ((float) frame.numRows())));
            }
        }
        if (samplingRatios == null
                || (ArrayUtils.minValue(samplingRatios) == 0.0f && ArrayUtils.maxValue(samplingRatios) == 0.0f)) {
            if (samplingRatios == null) {
                samplingRatios = new float[dist.length];
            }
            assert samplingRatios.length == dist.length;
            for (int c = 0; c < dist.length; c++) {
                // (rows / numClasses) / count: inverse prior divided by the number of classes.
                // A class with zero rows gets Infinity here.
                samplingRatios[c] = (((float) frame.numRows()) / label.domain().length) / ((float) dist[c]);
            }
            float minRatio = ArrayUtils.minValue(samplingRatios);
            if (!Float.isNaN(minRatio) && !Float.isInfinite(minRatio)) {
                ArrayUtils.div(samplingRatios, minRatio);
            }
        }
        if (!allowOversampling) {
            for (int c = 0; c < samplingRatios.length; c++) {
                samplingRatios[c] = Math.min(1.0f, samplingRatios[c]);
            }
        }
        float expectedRows = 0.0f;
        for (int c = 0; c < samplingRatios.length; c++) {
            // Skip classes with no rows. They contribute nothing, and an Infinity ratio times a zero
            // count is NaN, which would otherwise turn the total and every scaled ratio into NaN.
            if (dist[c] != 0) {
                expectedRows += samplingRatios[c] * ((float) dist[c]);
            }
        }
        // Round in double: Math.round(float) returns an int and saturates at Integer.MAX_VALUE.
        long targetRows = Math.min(maxRows, Math.round((double) expectedRows));
        assert targetRows >= 0;
        Log.info("Stratified sampling to a total of " + String.format("%,d", targetRows) + " rows" + (((float) targetRows) < expectedRows ? " (limited by max_after_balance_size)." : "."));
        if (((float) targetRows) != expectedRows) {
            float scale = ((float) targetRows) / expectedRows;
            ArrayUtils.mult(samplingRatios, scale);
            if (verbose) {
                // Report the row count the ratios were actually scaled to, which may be below maxRows.
                Log.info("Downsampling majority class by " + scale + " to limit number of rows to " + String.format("%,d", targetRows));
            }
        }
        for (int c = 0; c < label.domain().length; c++) {
            Log.info("Class '" + label.domain()[c].toString() + "' sampling ratio: " + samplingRatios[c]);
        }
        return sampleFrameStratified(frame, label, samplingRatios, seed, verbose);
    }

    /**
     * Stratified sampling of {@code frame} with explicit per-class ratios (one per domain level of
     * {@code label}).
     *
     * @param frame          source frame; {@code null} yields {@code null}
     * @param label          categorical column of {@code frame} to stratify by
     * @param samplingRatios expected number of copies of each row, per class
     * @param seed           base RNG seed
     * @param verbose        whether to log per-class results
     * @return the sampled and per-chunk-shuffled frame
     */
    public static Frame sampleFrameStratified(Frame frame, Vec label, float[] samplingRatios, long seed, boolean verbose) {
        return sampleFrameStratified(frame, label, samplingRatios, seed, verbose, 0);
    }

    /**
     * Recursive worker for stratified sampling.
     *
     * <p>Rows with an NA label are dropped. Every other row is emitted {@code floor(r)} times plus
     * once more with probability {@code r - floor(r)}, where {@code r} is its class's ratio.
     * If some class ends up with no rows, sampling is redone with {@code seed + 1}, at most
     * {@link #MAX_STRATIFIED_RETRIES} extra times. The final sample is shuffled per chunk.
     */
    private static Frame sampleFrameStratified(Frame frame, Vec label, final float[] samplingRatios,
                                               final long seed, boolean verbose, int attempt) {
        if (frame == null) {
            return null;
        }
        assert label.isEnum();
        assert samplingRatios != null && samplingRatios.length == label.domain().length;
        final int labelIdx = frame.find(label);
        assert labelIdx >= 0;
        Frame sampled = new MRTask() {
            @Override
            public void map(Chunk[] cs, NewChunk[] ncs) {
                Random rng = RandomUtils.getDeterRNG(seed + cs[0].cidx());
                for (int row = 0; row < cs[0]._len; row++) {
                    if (cs[labelIdx].isNA0(row)) {
                        continue;
                    }
                    int cls = (int) cs[labelIdx].at80(row);
                    assert cls >= 0 && cls < samplingRatios.length;
                    float ratio = samplingRatios[cls];
                    int whole = (int) ratio;
                    int copies = whole + (rng.nextFloat() < ratio - ((float) whole) ? 1 : 0);
                    for (int col = 0; col < ncs.length; col++) {
                        for (int k = 0; k < copies; k++) {
                            ncs[col].addNum(cs[col].at0(row));
                        }
                    }
                }
            }
        }.doAll(frame.numCols(), frame).outputFrame(frame.names(), frame.domains());
        Vec sampledLabel = sampled.vecs()[labelIdx];
        long[] dist = new ClassDist(sampledLabel).doAll(sampledLabel).dist();
        if (dist == null) {
            // No class counts were produced; discard the sample instead of leaking it.
            sampled.delete();
            return frame;
        }
        if (verbose) {
            long total = ArrayUtils.sum(dist);
            Log.info("After stratified sampling: " + total + " rows.");
            for (int c = 0; c < dist.length; c++) {
                Log.info("Class " + sampledLabel.factor(c) + ": count: " + dist[c] + " sampling ratio: " + samplingRatios[c] + " actual relative frequency: " + ((((float) dist[c]) / ((float) total)) * dist.length));
            }
        }
        if (ArrayUtils.minValue(dist) != 0 || attempt >= MAX_STRATIFIED_RETRIES) {
            Frame shuffled = shuffleFramePerChunk(sampled, seed + STRATIFIED_SHUFFLE_SEED_OFFSET);
            sampled.delete();
            return shuffled;
        }
        Log.info("Re-doing stratified sampling because not all classes were represented (unlucky draw).");
        sampled.delete();
        return sampleFrameStratified(frame, label, samplingRatios, seed + 1, verbose, attempt + 1);
    }
}
