package water.rapids.ast.prims.advmath;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Random;
import water.DKV;
import water.Iced;
import water.Key;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.RandomUtils;
import water.util.VecUtils;

/* loaded from: input_file:water/rapids/ast/prims/advmath/AstStratifiedSplit.class */
/**
 * Rapids primitive {@code (h2o.random_stratified_split ary test_frac seed)}.
 *
 * <p>Takes a frame with exactly one integer or categorical column and returns a single-column
 * frame named {@value #OUTPUT_COLUMN_NAME}. The output column is categorical with domain
 * {@link #OUTPUT_COLUMN_DOMAIN}. For every distinct value ("class") of the input column,
 * {@code round(classSize * test_frac)} rows (at least one, at most all of them) are set to the
 * second level ({@code "test"}). The remaining rows keep the first level ({@code "train"}).
 * A {@code seed} of {@code -1} requests a non-reproducible random seed.
 */
public class AstStratifiedSplit extends AstPrimitive {
    public static final String OUTPUT_COLUMN_NAME = "test_train_split";
    public static final String[] OUTPUT_COLUMN_DOMAIN = {"train", "test"};

    /**
     * Sets every row whose global row index is contained in {@code _idx} to {@code 1}, i.e. the
     * second level of the output domain. All other rows are left untouched.
     */
    public static class ClassAssignMRTask extends MRTask<ClassAssignMRTask> {
        HashSet<Long> _idx;

        ClassAssignMRTask(HashSet<Long> hashSet) {
            this._idx = hashSet;
        }

        @Override
        public void map(Chunk chunk) {
            final long start = chunk.start();
            for (int row = 0; row < chunk.len(); row++) {
                if (_idx.contains(start + row)) {
                    chunk.set(row, 1.0d);
                }
            }
            // The index set is only needed inside map(); presumably dropped so it is not
            // carried along with the task afterwards.
            _idx = null;
        }
    }

    /**
     * Collects, per class, the global row indices at which that class occurs.
     *
     * <p>After {@code doAll}, {@code _indexes[c]} holds the row indices whose value equals
     * {@code classes[c]}. Rows whose value is missing or not listed in {@code classes} are
     * ignored.
     */
    public static class ClassIdxTask extends MRTask<ClassIdxTask> {
        LongAry[] _indexes;
        private final int _nclasses;
        private long[] _classes;
        private transient HashMap<Long, Integer> _classMap;

        /**
         * @param nclasses number of classes (expected to equal {@code classes.length})
         * @param classes  the distinct integer values of the column; position defines the class id
         */
        public ClassIdxTask(int nclasses, long[] classes) {
            this._nclasses = nclasses;
            this._classes = classes;
        }

        /** Builds the value -> class-id lookup once per node before any map() call. */
        @Override
        public void setupLocal() {
            _classMap = new HashMap<>(2 * _classes.length);
            for (int c = 0; c < _classes.length; c++) {
                _classMap.put(_classes[c], c);
            }
        }

        @Override
        public void map(Chunk[] chunkArr) {
            final Chunk chunk = chunkArr[0];
            _indexes = new LongAry[_nclasses];
            for (int c = 0; c < _nclasses; c++) {
                _indexes[c] = new LongAry();
            }
            final long start = chunk.start();
            for (int row = 0; row < chunk.len(); row++) {
                // Missing values have no integer class to look up, so they are skipped.
                // Such rows are never selected and keep the default "train" level.
                if (chunk.isNA(row)) {
                    continue;
                }
                Integer classIdx = _classMap.get(chunk.at8(row));
                if (classIdx != null) {
                    _indexes[classIdx].add(start + row);
                }
            }
            _classes = null;
        }

        /** Appends the other task's per-class row indices to this task's lists. */
        @Override
        public void reduce(ClassIdxTask other) {
            for (int c = 0; c < other._indexes.length; c++) {
                _indexes[c].addAll(other._indexes[c]);
            }
        }
    }

    /** Minimal growable array of primitive {@code long}s that can be shipped between nodes. */
    public static class LongAry extends Iced<LongAry> {
        long[] _ary;
        int _sz;

        /**
         * Creates a list initially holding {@code jArr}. The array is adopted, not copied.
         * Because it is full from the start, the first {@link #add} reallocates, so appends never
         * write into the caller's array.
         */
        public LongAry(long... jArr) {
            this._ary = jArr;
            this._sz = jArr.length;
        }

        /** Grows the backing array (at least doubling, minimum 4) to hold {@code minCapacity} items. */
        private void ensureCapacity(int minCapacity) {
            if (minCapacity > _ary.length) {
                // Compute in long arithmetic so doubling a very large array cannot overflow
                // to a negative size (which previously truncated the array).
                long grown = Math.max(Math.max(4L, 2L * _ary.length), minCapacity);
                _ary = Arrays.copyOf(_ary, (int) Math.min(grown, Integer.MAX_VALUE));
            }
        }

        /** Appends {@code j} to the end of the list. */
        public void add(long j) {
            ensureCapacity(_sz + 1);
            _ary[_sz++] = j;
        }

        /** Appends every element of {@code other}, preserving its order. */
        public void addAll(LongAry other) {
            if (other._sz == 0) {
                return;
            }
            ensureCapacity(_sz + other._sz);
            System.arraycopy(other._ary, 0, _ary, _sz, other._sz);
            _sz += other._sz;
        }

        /**
         * @throws ArrayIndexOutOfBoundsException if {@code i >= size()}
         */
        public long get(int i) {
            if (i >= _sz) {
                throw new ArrayIndexOutOfBoundsException(i);
            }
            return _ary[i];
        }

        public int size() {
            return _sz;
        }

        /** Returns a trimmed copy of the stored elements. */
        public long[] toArray() {
            return Arrays.copyOf(_ary, _sz);
        }

        /** Logically empties the list; the backing array is retained. */
        public void clear() {
            _sz = 0;
        }
    }

    @Override
    public String[] args() {
        return new String[]{"ary", "test_frac", "seed"};
    }

    /** The primitive name plus its three arguments. */
    @Override
    public int nargs() {
        return 4;
    }

    @Override
    public String str() {
        return "h2o.random_stratified_split";
    }

    /**
     * Evaluates {@code (h2o.random_stratified_split ary test_frac seed)}.
     *
     * @throws IllegalArgumentException if {@code ary} does not have exactly one column, or the
     *         column cannot be stratified by (see {@link #split})
     */
    @Override
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame frame = stk.track(asts[1].exec(env)).getFrame();
        double testFrac = asts[2].exec(env).getNum();
        long seed = (long) asts[3].exec(env).getNum();
        if (frame.numCols() != 1) {
            throw new IllegalArgumentException("Must give a single column to stratify against. Got: " + frame.numCols() + " columns.");
        }
        Vec assignment = split(frame.anyVec(), testFrac, seed, OUTPUT_COLUMN_DOMAIN);
        return new ValFrame(new Frame(Key.make(), new String[]{OUTPUT_COLUMN_NAME}, new Vec[]{assignment}));
    }

    /**
     * Builds a stratified train/test assignment for {@code vec}.
     *
     * <p>For each class, {@code round(classSize * d)} distinct rows are chosen, clamped to
     * {@code [1, classSize]}. The chosen rows are set to {@code 1} (the second level of
     * {@code strArr}); every other row is {@code 0}. The returned vec is already put into the DKV.
     *
     * @param vec    integer or categorical column to stratify by
     * @param d      fraction of each class to put into the second level (the test set)
     * @param j      random seed; {@code -1} means "pick a random seed"
     * @param strArr domain of the returned categorical vec
     * @return a new categorical vec with the same layout as {@code vec}
     * @throws IllegalArgumentException if {@code vec} cannot be stratified by
     */
    public static Vec split(Vec vec, double d, long j, String[] strArr) {
        checkIfCanStratifyBy(vec);
        final long seed = j == -1 ? new Random().nextLong() : j;
        final long[] classes = new VecUtils.CollectIntegerDomain().doAll(vec).domain();
        final int nclasses = classes.length;
        final ClassIdxTask classIdx = new ClassIdxTask(nclasses, classes).doAll(vec);

        final HashSet<Long> testRows = new HashSet<>();
        for (int c = 0; c < nclasses; c++) {
            sampleClass(classIdx._indexes[c], d, seed, testRows);
        }

        // The output vec is created only once sampling has succeeded, so a failure above does
        // not leave an orphaned vec in the DKV.
        Vec result = vec.makeCon(0.0d, Vec.T_CAT);
        result.setDomain(strArr);
        DKV.put(result);
        new ClassAssignMRTask(testRows).doAll(result);
        return result;
    }

    /**
     * Draws {@code round(size * frac)} distinct row indices from one class and adds them to
     * {@code sink}. The count is clamped to {@code [1, size]}.
     *
     * <p>Each attempt {@code a} uses a fresh RNG seeded with {@code a + seed}. Attempts that hit
     * an already-chosen row are simply retried, so a given seed always yields the same sample.
     */
    private static void sampleClass(LongAry rows, double frac, long seed, HashSet<Long> sink) {
        final int size = rows.size();
        final long wanted = Math.min(size, Math.max(Math.round(size * frac), 1L));
        if (wanted >= size) {
            // Every row is selected. Adding them directly avoids rejection sampling until all
            // indices have been drawn: that is slow for frac == 1 and never ends for frac > 1.
            for (int k = 0; k < size; k++) {
                sink.add(rows.get(k));
            }
            return;
        }
        final HashSet<Long> picked = new HashSet<>();
        for (int attempt = 0; picked.size() < wanted; attempt++) {
            int pos = (int) (RandomUtils.getRNG(attempt + seed).nextDouble() * size);
            picked.add(rows.get(pos));
        }
        sink.addAll(picked);
    }

    /**
     * Ensures {@code vec} is categorical or integer-valued numeric and has at most
     * {@link Integer#MAX_VALUE} rows. Row indices are kept in int-indexed {@link LongAry} lists.
     *
     * @throws IllegalArgumentException if either condition is violated
     */
    static void checkIfCanStratifyBy(Vec vec) {
        if (!vec.isCategorical() && (!vec.isNumeric() || !vec.isInt())) {
            throw new IllegalArgumentException("Stratification only applies to integer and categorical columns. Got: " + vec.get_type_str());
        }
        if (vec.length() > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Cannot stratified the frame because it is too long: nrows=" + vec.length());
        }
    }
}
