package org.nd4j.linalg.dataset.api.preprocessor.classimbalance;

import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;

/* loaded from: input_file:org/nd4j/linalg/dataset/api/preprocessor/classimbalance/UnderSamplingByMaskingMultiDataSetPreProcessor.class */
/**
 * {@link MultiDataSetPreProcessor} that, for each configured label-array index of a {@link MultiDataSet},
 * replaces that label array's mask with the mask returned by {@code adjustMasks(...)} (inherited from
 * {@link BaseUnderSamplingPreProcessor}).
 *
 * <p>Each configured index carries a target minority-class distribution in {@code (0, 0.5]}. By default the
 * minority label value passed to {@code adjustMasks} is {@code 1}; {@link #overrideMinorityDefault(int)}
 * switches it to {@code 0} for a given index. NOTE(review): the exact masking/under-sampling semantics live in
 * {@code BaseUnderSamplingPreProcessor#adjustMasks}, which is not visible here.
 */
public class UnderSamplingByMaskingMultiDataSetPreProcessor extends BaseUnderSamplingPreProcessor implements MultiDataSetPreProcessor {
    /** Target minority-class distribution, keyed by label-array index. Private copy of the constructor argument. */
    private final Map<Integer, Double> targetMinorityDistMap;
    /** Minority label value passed to adjustMasks, keyed by label-array index: 1 by default, 0 once overridden. */
    private final Map<Integer, Integer> minorityLabelMap = new HashMap<>();

    /**
     * Creates the pre-processor.
     *
     * @param targetDist map from label-array index to the target minority distribution for that label array;
     *                   every value must be in {@code (0, 0.5]}. The map is copied, so later changes by the
     *                   caller do not affect this pre-processor.
     * @param windowSize value stored into the inherited {@code tbpttWindowSize} field
     * @throws IllegalArgumentException if any target is null, NaN, {@code <= 0} or {@code > 0.5}
     * @throws NullPointerException     if {@code targetDist} is null
     */
    public UnderSamplingByMaskingMultiDataSetPreProcessor(Map<Integer, Double> targetDist, int windowSize) {
        for (Map.Entry<Integer, Double> entry : targetDist.entrySet()) {
            Double target = entry.getValue();
            // Written as a negated in-range test so that null and NaN are rejected as well
            // (NaN compares false against every bound, so "v > 0.5 || v <= 0" would let it through).
            if (target == null || !(target > 0.0d && target <= 0.5d)) {
                throw new IllegalArgumentException("Target distribution for the minority label class has to be greater than 0 and no greater than 0.5. Target distribution of " + target + " given for label at index " + entry.getKey());
            }
            this.minorityLabelMap.put(entry.getKey(), 1);
        }
        // Defensive copy: keys added to the caller's map later would otherwise have no minorityLabelMap entry.
        this.targetMinorityDistMap = new HashMap<>(targetDist);
        this.tbpttWindowSize = windowSize;
    }

    /**
     * Switches the minority label value used for the given label-array index from the default {@code 1}
     * to {@code 0}.
     *
     * @param labelIndex label-array index; must be a key of the map given to the constructor
     * @throws IllegalArgumentException if {@code labelIndex} was not configured in the constructor
     */
    public void overrideMinorityDefault(int labelIndex) {
        if (!this.targetMinorityDistMap.containsKey(labelIndex)) {
            throw new IllegalArgumentException("Index specified is not contained in the target minority distribution map specified with the preprocessor. Map contains " + ArrayUtils.toString(this.targetMinorityDistMap.keySet().toArray()));
        }
        this.minorityLabelMap.put(labelIndex, 0);
    }

    /**
     * For every configured label-array index, computes a new labels mask via {@code adjustMasks} from the
     * current labels, the current labels mask, the minority label value and the target distribution. It then
     * sets the result as the labels mask for that index. Mutates {@code multiDataSet} in place.
     *
     * @param multiDataSet the data set whose label masks are replaced
     */
    @Override // org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor
    public void preProcess(MultiDataSet multiDataSet) {
        for (Map.Entry<Integer, Double> entry : this.targetMinorityDistMap.entrySet()) {
            int labelIndex = entry.getKey();
            INDArray labels = multiDataSet.getLabels(labelIndex);
            INDArray labelsMask = multiDataSet.getLabelsMaskArray(labelIndex);
            double targetDist = entry.getValue();
            int minorityLabel = this.minorityLabelMap.get(entry.getKey());
            multiDataSet.setLabelsMaskArray(labelIndex, adjustMasks(labels, labelsMask, minorityLabel, targetDist));
        }
    }
}
