package org.jpmml.converter.support_vector_machine;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.regression.CategoricalPredictor;
import org.dmg.pmml.support_vector_machine.Coefficient;
import org.dmg.pmml.support_vector_machine.Coefficients;
import org.dmg.pmml.support_vector_machine.SupportVector;
import org.dmg.pmml.support_vector_machine.SupportVectorMachine;
import org.dmg.pmml.support_vector_machine.SupportVectorMachineModel;
import org.dmg.pmml.support_vector_machine.SupportVectors;
import org.dmg.pmml.support_vector_machine.VectorDictionary;
import org.dmg.pmml.support_vector_machine.VectorFields;
import org.dmg.pmml.support_vector_machine.VectorInstance;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.CMatrixUtil;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Matrix;
import org.jpmml.converter.ModelUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;

/* loaded from: input_file:org/jpmml/converter/support_vector_machine/LibSVMUtil.class */
/**
 * Factory methods that build PMML {@link SupportVectorMachineModel}s from support vector data
 * (the class name suggests LibSVM-formatted inputs).
 */
public class LibSVMUtil {

    private LibSVMUtil() {
    }

    /**
     * Creates a regression model consisting of a single {@link SupportVectorMachine} that spans all
     * support vectors.
     *
     * @param sv support vector matrix; one row per support vector, one column per schema feature
     * @param ids support vector identifiers, in row order
     * @param intercept value stored as the absolute value of the machine's {@link Coefficients}
     * @param coefficients one coefficient per support vector, in row order
     * @param schema schema whose label must be a {@link ContinuousLabel}
     * @return the regression model
     * @throws IllegalArgumentException if the matrix, identifiers, coefficients and schema disagree in size
     */
    public static SupportVectorMachineModel createRegression(Matrix<Double> sv, List<String> ids, Double intercept, List<Double> coefficients, Schema schema) {
        ContinuousLabel continuousLabel = (ContinuousLabel) schema.getLabel();

        VectorDictionary vectorDictionary = createVectorDictionary(sv, ids, schema);
        List<VectorInstance> vectorInstances = vectorDictionary.getVectorInstances();

        List<SupportVectorMachine> supportVectorMachines = new ArrayList<>();
        supportVectorMachines.add(createSupportVectorMachine(vectorInstances, intercept, coefficients));

        return new SupportVectorMachineModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(continuousLabel), vectorDictionary, supportVectorMachines);
    }

    /**
     * Creates a one-against-one classification model.
     *
     * <p>One {@link SupportVectorMachine} is created per class pair {@code (first, second)} with
     * {@code first < second}, visited in the order (0,1), (0,2), ..., (1,2), .... Each machine uses the
     * support vectors of both classes and the intercept at the pair's position in that order.
     *
     * @param sv support vector matrix; rows are grouped by class, class {@code i} owning
     *     {@code numberOfSv.get(i)} consecutive rows
     * @param numberOfSv number of support vectors per class, one entry per label value
     * @param ids support vector identifiers, in row order
     * @param intercepts one intercept per class pair, in pair order
     * @param coefficients coefficient values, read via {@link CMatrixUtil#getRow} as a
     *     {@code (numClasses - 1) x rows} matrix
     * @param schema schema whose label must be a {@link CategoricalLabel}
     * @return the classification model
     * @throws IllegalArgumentException if the inputs disagree in size
     */
    public static SupportVectorMachineModel createClassification(Matrix<Double> sv, List<Integer> numberOfSv, List<String> ids, List<Double> intercepts, List<Double> coefficients, Schema schema) {
        CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

        int rows = sv.getRows();
        int size = categoricalLabel.size();

        if (numberOfSv.size() != size) {
            throw new IllegalArgumentException("Expected " + size + " support vector counts, got " + numberOfSv.size());
        }

        VectorDictionary vectorDictionary = createVectorDictionary(sv, ids, schema);
        List<VectorInstance> vectorInstances = vectorDictionary.getVectorInstances();

        // offsets[i] (inclusive) .. offsets[i + 1] (exclusive) is the row range owned by class i
        int[] offsets = new int[size + 1];
        for (int i = 0; i < size; i++) {
            offsets[i + 1] = offsets[i] + numberOfSv.get(i);
        }

        if (offsets[size] != rows) {
            throw new IllegalArgumentException("Expected " + rows + " support vectors in total, got " + offsets[size]);
        }

        int pairs = size * (size - 1) / 2;
        if (intercepts.size() != pairs) {
            throw new IllegalArgumentException("Expected " + pairs + " intercepts, got " + intercepts.size());
        }

        List<SupportVectorMachine> supportVectorMachines = new ArrayList<>();

        int pair = 0;
        for (int first = 0; first < size; first++) {
            for (int second = first + 1; second < size; second++) {
                List<VectorInstance> svmVectorInstances = new ArrayList<>();
                svmVectorInstances.addAll(slice(vectorInstances, offsets, first));
                svmVectorInstances.addAll(slice(vectorInstances, offsets, second));

                // Coefficients of first's support vectors come from row (second - 1), those of
                // second's support vectors from row first (this appears to mirror LibSVM's sv_coef layout)
                List<Double> svmCoefficients = new ArrayList<>();
                svmCoefficients.addAll(slice(CMatrixUtil.getRow(coefficients, size - 1, rows, second - 1), offsets, first));
                svmCoefficients.addAll(slice(CMatrixUtil.getRow(coefficients, size - 1, rows, first), offsets, second));

                SupportVectorMachine supportVectorMachine = createSupportVectorMachine(svmVectorInstances, intercepts.get(pair), svmCoefficients)
                    .setTargetCategory(categoricalLabel.getValue(first))
                    .setAlternateTargetCategory(categoricalLabel.getValue(second));

                supportVectorMachines.add(supportVectorMachine);

                pair++;
            }
        }

        return new SupportVectorMachineModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), vectorDictionary, supportVectorMachines)
            .setClassificationMethod(SupportVectorMachineModel.ClassificationMethod.ONE_AGAINST_ONE);
    }

    /**
     * Creates a vector dictionary holding one {@link VectorInstance} per support vector.
     *
     * <p>Columns whose value is 0.0 in every row are left out of both the {@link VectorFields} and
     * the instance arrays (assuming {@code ValueUtil.getIndices} reports the positions equal to the
     * given value). A {@link BinaryFeature} column becomes a {@link CategoricalPredictor} with
     * coefficient 1.0. Any other column becomes a reference to its continuous feature. Each instance
     * is stored as a sparse array when {@code ValueUtil.isSparse(values, 0.0, 0.75)} holds, and as a
     * dense array otherwise.
     *
     * @param sv support vector matrix; one row per support vector, one column per schema feature
     * @param ids support vector identifiers, in row order
     * @param schema schema providing one feature per matrix column
     * @return the vector dictionary
     * @throws IllegalArgumentException if the column count differs from the feature count, or the
     *     identifier count differs from the row count
     */
    public static VectorDictionary createVectorDictionary(Matrix<Double> sv, List<String> ids, Schema schema) {
        int rows = sv.getRows();
        int columns = sv.getColumns();

        int featureCount = schema.getFeatures().size();
        if (columns != featureCount) {
            throw new IllegalArgumentException("Expected " + featureCount + " columns, got " + columns);
        }

        if (ids.size() != rows) {
            throw new IllegalArgumentException("Expected " + rows + " support vector identifiers, got " + ids.size());
        }

        Double defaultValue = Double.valueOf(0.0d);

        // Columns that hold a non-zero value in at least one support vector
        BitSet usedColumns = new BitSet(columns);
        for (int row = 0; row < rows; row++) {
            BitSet nonZeroColumns = ValueUtil.getIndices(sv.getRowValues(row), defaultValue);
            nonZeroColumns.flip(0, columns);

            usedColumns.or(nonZeroColumns);
        }

        VectorFields vectorFields = new VectorFields();
        for (int column = 0; column < columns; column++) {
            if (!usedColumns.get(column)) {
                continue;
            }

            Feature feature = schema.getFeature(column);
            if (feature instanceof BinaryFeature) {
                BinaryFeature binaryFeature = (BinaryFeature) feature;

                vectorFields.addContent(new CategoricalPredictor(binaryFeature.getName(), binaryFeature.getValue(), 1.0d));
            } else {
                vectorFields.addContent(feature.toContinuousFeature().ref());
            }
        }

        boolean filter = usedColumns.cardinality() < columns;

        VectorDictionary vectorDictionary = new VectorDictionary(vectorFields);
        for (int row = 0; row < rows; row++) {
            List<Double> values = sv.getRowValues(row);

            // Keep the instance arrays aligned with the reduced VectorFields
            if (filter) {
                values = ValueUtil.filterByIndices(values, usedColumns);
            }

            VectorInstance vectorInstance = new VectorInstance(ids.get(row));
            if (ValueUtil.isSparse(values, defaultValue, 0.75d)) {
                vectorInstance.setRealSparseArray(PMMLUtil.createRealSparseArray(values, defaultValue));
            } else {
                vectorInstance.setArray(PMMLUtil.createRealArray(values));
            }

            vectorDictionary.addVectorInstances(vectorInstance);
        }

        return vectorDictionary;
    }

    /**
     * Creates a support vector machine that pairs each vector instance with its coefficient.
     *
     * @param vectorInstances support vectors of this machine, in coefficient order
     * @param intercept value stored as the absolute value of the machine's {@link Coefficients}
     * @param coefficients one coefficient per vector instance
     * @return the support vector machine
     * @throws IllegalArgumentException if the two lists differ in size
     */
    public static SupportVectorMachine createSupportVectorMachine(List<VectorInstance> vectorInstances, Double intercept, List<Double> coefficients) {
        if (vectorInstances.size() != coefficients.size()) {
            throw new IllegalArgumentException("Expected " + vectorInstances.size() + " coefficients, got " + coefficients.size());
        }

        Coefficients svmCoefficients = new Coefficients().setAbsoluteValue(intercept);
        SupportVectors supportVectors = new SupportVectors();

        for (int i = 0; i < vectorInstances.size(); i++) {
            VectorInstance vectorInstance = vectorInstances.get(i);

            svmCoefficients.addCoefficients(new Coefficient().setValue(coefficients.get(i)));
            supportVectors.addSupportVectors(new SupportVector(vectorInstance.getId()));
        }

        return new SupportVectorMachine(svmCoefficients).setSupportVectors(supportVectors);
    }

    /**
     * Returns the sub-list view of {@code list} from {@code offsets[index]} (inclusive) to
     * {@code offsets[index + 1]} (exclusive).
     */
    private static <E> List<E> slice(List<E> list, int[] offsets, int index) {
        return list.subList(offsets[index], offsets[index + 1]);
    }
}
