package com.datumbox.framework.core.machinelearning.common.abstracts.validators;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.FlatDataList;
import com.datumbox.framework.common.interfaces.Trainable;
import com.datumbox.framework.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.modelers.AbstractModeler;
import com.datumbox.framework.core.machinelearning.common.interfaces.ModelParameters;
import com.datumbox.framework.core.machinelearning.common.interfaces.TrainingParameters;
import com.datumbox.framework.core.machinelearning.common.interfaces.ValidationMetrics;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/common/abstracts/validators/AbstractValidator.class */
/**
 * Base class for validators that estimate a modeler's performance with k-fold cross-validation.
 * Subclasses define how the per-fold validation metrics are combined into a single result.
 *
 * @param <MP> model parameters type
 * @param <TP> training parameters type passed to every fold's modeler
 * @param <VM> validation metrics type produced per fold and returned as the average
 */
public abstract class AbstractValidator<MP extends ModelParameters, TP extends TrainingParameters, VM extends ValidationMetrics> {
    protected final Logger logger = LoggerFactory.getLogger(getClass());

    /** Suffix inserted into the DB name of every temporary fold modeler. */
    private static final String DB_INDICATOR = "Kfold";

    /**
     * Runs k-fold cross-validation of the given modeler class on {@code dataframe}.
     *
     * <p>The record ids are shuffled and split into {@code k} contiguous folds whose sizes differ by
     * at most one, so every record is used for validation exactly once. For each fold a fresh
     * modeler is created (DB name: {@code dbName + separator + "Kfold" + (fold + 1)}), fitted on the
     * remaining records, validated on the fold, and then deleted together with the temporary data
     * subsets. When {@code k == 1} the modeler is trained and validated on the whole dataset.
     *
     * @param dataframe the data to cross-validate on
     * @param k the number of folds; must satisfy {@code 0 < k < dataframe.size()}
     * @param dbName base DB name used to derive the per-fold modeler DB names
     * @param configuration configuration used for the DB name separator and for the fold modelers
     * @param modelerClass the modeler implementation to train on each fold
     * @param trainingParameters training parameters passed to every fold's {@code fit}
     * @return the per-fold metrics combined by {@link #calculateAverageValidationMetrics(List)}
     * @throws IllegalArgumentException if {@code k <= 0} or {@code k >= dataframe.size()}
     */
    public VM kFoldCrossValidation(Dataframe dataframe, int k, String dbName, Configuration configuration, Class<? extends AbstractModeler> modelerClass, TP trainingParameters) {
        int n = dataframe.size();
        if (k <= 0 || n <= k) {
            throw new IllegalArgumentException("Invalid number of folds.");
        }

        Integer[] ids = shuffledIds(dataframe, n);
        String foldDbName = dbName + configuration.getDbConfig().getDBnameSeparator() + DB_INDICATOR;

        List<VM> validationMetricsList = new ArrayList<>(k);
        for (int fold = 0; fold < k; fold++) {
            logger.info("Kfold {}", fold);

            // Balanced boundaries: the n % k leftover records are spread over the folds instead of
            // being silently excluded from validation (previously they were never validated).
            int start = foldBoundary(fold, n, k);
            int end = foldBoundary(fold + 1, n, k);

            FlatDataList trainingIds = new FlatDataList(new ArrayList<>(n - (end - start)));
            FlatDataList validationIds = new FlatDataList(new ArrayList<>(end - start));
            for (int pos = 0; pos < n; pos++) {
                if (start <= pos && pos < end) {
                    validationIds.add(ids[pos]);
                } else {
                    trainingIds.add(ids[pos]);
                }
            }

            if (k == 1) {
                // A single fold leaves no training data; train on the validation fold itself.
                trainingIds = validationIds;
            }

            validationMetricsList.add(fitAndValidate(dataframe, trainingIds, validationIds,
                    foldDbName + (fold + 1), configuration, modelerClass, trainingParameters));
        }
        return calculateAverageValidationMetrics(validationMetricsList);
    }

    /**
     * Combines the validation metrics of all folds into a single result.
     *
     * @param validationMetricsList one metrics object per fold, in fold order
     * @return the averaged validation metrics
     */
    protected abstract VM calculateAverageValidationMetrics(List<VM> validationMetricsList);

    /** Collects the record ids of {@code dataframe} into an array and shuffles it in place. */
    private static Integer[] shuffledIds(Dataframe dataframe, int n) {
        Integer[] ids = new Integer[n];
        int j = 0;
        for (Integer id : dataframe.index()) {
            ids[j++] = id;
        }
        PHPMethods.shuffle(ids);
        return ids;
    }

    /** Returns the first shuffled position of {@code fold}; long math avoids int overflow. */
    private static int foldBoundary(int fold, int n, int k) {
        return (int) ((long) fold * n / k);
    }

    /**
     * Creates a modeler, fits it on the training ids, validates it on the validation ids and
     * deletes the modeler and both temporary subsets even if fitting or validation fails.
     */
    @SuppressWarnings({"unchecked", "rawtypes"}) // modelerClass is a raw AbstractModeler subtype
    private VM fitAndValidate(Dataframe dataframe, FlatDataList trainingIds, FlatDataList validationIds,
            String modelerDbName, Configuration configuration, Class<? extends AbstractModeler> modelerClass,
            TP trainingParameters) {
        AbstractModeler modeler = (AbstractModeler) Trainable.newInstance(modelerClass, modelerDbName, configuration);
        try {
            Dataframe trainingData = dataframe.getSubset(trainingIds);
            try {
                modeler.fit(trainingData, trainingParameters);
            } finally {
                trainingData.delete();
            }

            Dataframe validationData = dataframe.getSubset(validationIds);
            try {
                return (VM) modeler.validate(validationData);
            } finally {
                validationData.delete();
            }
        } finally {
            modeler.delete();
        }
    }
}
