/*
 * Decompiled with CFR 0.152.
 */
package org.kie.pmml.compiler.commons.utils;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Optional;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.xml.bind.JAXBException;
import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MathContext;
import org.dmg.pmml.MiningField;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMML;
import org.dmg.pmml.ResultFeature;
import org.dmg.pmml.Targets;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.jpmml.model.PMMLUtil;
import org.kie.pmml.api.exceptions.KiePMMLException;
import org.xml.sax.SAXException;

public class KiePMMLUtil {
    public static final String SEGMENTID_TEMPLATE = "%sSegment%s";
    static final String MODELNAME_TEMPLATE = "%s%s%s";
    static final String SEGMENTMODELNAME_TEMPLATE = "Segment%s%s";
    static final String TARGETFIELD_TEMPLATE = "target%s";

    private KiePMMLUtil() {
    }

    public static PMML load(String source) throws SAXException, JAXBException {
        return KiePMMLUtil.load(new ByteArrayInputStream(source.getBytes()), "");
    }

    public static PMML load(InputStream is, String fileName) throws SAXException, JAXBException {
        PMML toReturn = PMMLUtil.unmarshal((InputStream)is);
        String cleanedFileName = fileName.contains(".") ? fileName.substring(0, fileName.indexOf(46)) : fileName;
        List dataFields = toReturn.getDataDictionary().getDataFields();
        List models = toReturn.getModels();
        for (int i = 0; i < models.size(); ++i) {
            Model model = (Model)models.get(i);
            KiePMMLUtil.populateMissingModelName(model, cleanedFileName, i);
            KiePMMLUtil.populateMissingOutputFieldDataType(model, dataFields);
            KiePMMLUtil.populateMissingMiningTargetField(model, dataFields);
            KiePMMLUtil.populateMissingPredictedOutputFieldTarget(model);
            if (!(model instanceof MiningModel)) continue;
            KiePMMLUtil.populateCorrectMiningModel((MiningModel)model);
        }
        return toReturn;
    }

    static void populateMissingModelName(Model model, String fileName, int i) {
        if (model.getModelName() == null || model.getModelName().isEmpty()) {
            String modelName = String.format(MODELNAME_TEMPLATE, fileName, model.getClass().getSimpleName(), i);
            model.setModelName(modelName);
        }
    }

    static void populateMissingMiningTargetField(Model model, List<DataField> dataFields) {
        List<MiningField> miningTargetFields = KiePMMLUtil.getMiningTargetFields(model.getMiningSchema().getMiningFields());
        if (miningTargetFields.isEmpty()) {
            Optional<DataField> targetDataField = KiePMMLUtil.getTargetDataField(model);
            targetDataField.ifPresent(dataField -> {
                dataFields.add((DataField)dataField);
                MiningField targetMiningField = KiePMMLUtil.getTargetMiningField(dataField);
                model.getMiningSchema().addMiningFields(new MiningField[]{targetMiningField});
                KiePMMLUtil.correctTargetFields(targetMiningField, model.getTargets());
            });
        }
    }

    static void populateMissingPredictedOutputFieldTarget(Model model) {
        if (model.getOutput() != null && model.getMiningSchema() != null) {
            Optional<OutputField> predictedOutputField = model.getOutput().getOutputFields().stream().filter(outputField -> (outputField.getResultFeature() == null || outputField.getResultFeature().equals((Object)ResultFeature.PREDICTED_VALUE)) && outputField.getTargetField() == null).findFirst();
            predictedOutputField.ifPresent(outputField -> {
                List<MiningField> targetFields = KiePMMLUtil.getMiningTargetFields(model.getMiningSchema().getMiningFields());
                if (!targetFields.isEmpty()) {
                    outputField.setTargetField(targetFields.get(0).getName());
                }
            });
        }
    }

    static Optional<DataField> getTargetDataField(Model model) {
        DataType targetDataType = KiePMMLUtil.getTargetDataType(model.getMiningFunction(), model.getMathContext());
        OpType targetOpType = KiePMMLUtil.getTargetOpType(model.getMiningFunction());
        if (targetDataType == null || targetOpType == null) {
            return Optional.empty();
        }
        String cleanedName = model.getModelName().replaceAll("[^A-Za-z0-9]", "");
        String fieldName = String.format(TARGETFIELD_TEMPLATE, cleanedName);
        DataField toReturn = new DataField();
        toReturn.setName(FieldName.create((String)fieldName));
        toReturn.setOpType(targetOpType);
        toReturn.setDataType(targetDataType);
        return Optional.of(toReturn);
    }

    static DataType getTargetDataType(MiningFunction miningFunction, MathContext mathContext) {
        switch (miningFunction) {
            case REGRESSION: {
                return DataType.fromValue((String)mathContext.value());
            }
            case CLASSIFICATION: 
            case CLUSTERING: {
                return DataType.STRING;
            }
        }
        return null;
    }

    static OpType getTargetOpType(MiningFunction miningFunction) {
        switch (miningFunction) {
            case REGRESSION: {
                return OpType.CONTINUOUS;
            }
            case CLASSIFICATION: 
            case CLUSTERING: {
                return OpType.CATEGORICAL;
            }
        }
        return null;
    }

    static MiningField getTargetMiningField(DataField dataField) {
        MiningField toReturn = new MiningField();
        toReturn.setName(dataField.getName());
        toReturn.setUsageType(MiningField.UsageType.TARGET);
        return toReturn;
    }

    static void correctTargetFields(MiningField targetMiningField, Targets targets) {
        if (targets != null && !targets.getTargets().isEmpty()) {
            List targetsFields = targets.getTargets();
            targetsFields.stream().filter(targetField -> targetField.getField() == null).forEach(targetField -> targetField.setField(targetMiningField.getName()));
        }
    }

    static void populateCorrectMiningModel(MiningModel miningModel) {
        List segments = miningModel.getSegmentation().getSegments();
        for (int i = 0; i < segments.size(); ++i) {
            Segment segment = (Segment)segments.get(i);
            KiePMMLUtil.populateCorrectSegmentId(segment, miningModel.getModelName(), i);
            Model model = segment.getModel();
            KiePMMLUtil.populateMissingSegmentModelName(model, segment.getId());
            KiePMMLUtil.populateMissingTargetFieldInSegment(miningModel.getMiningSchema(), model);
            KiePMMLUtil.populateMissingPredictedOutputFieldTarget(model);
            if (!(model instanceof MiningModel)) continue;
            KiePMMLUtil.populateCorrectMiningModel((MiningModel)segment.getModel());
        }
    }

    static void populateCorrectSegmentId(Segment segment, String modelName, int i) {
        String toSet = segment.getId() == null || segment.getId().isEmpty() ? String.format(SEGMENTID_TEMPLATE, modelName, i) : KiePMMLUtil.getSanitizedId(segment.getId(), modelName);
        segment.setId(toSet);
    }

    static void populateMissingSegmentModelName(Model model, String segmentId) {
        String modelName = String.format(SEGMENTMODELNAME_TEMPLATE, segmentId, model.getClass().getSimpleName());
        model.setModelName(modelName);
    }

    static void populateMissingTargetFieldInSegment(MiningSchema parentMiningSchema, Model childrenModel) {
        List<MiningField> parentTargetFields = KiePMMLUtil.getMiningTargetFields(parentMiningSchema.getMiningFields());
        List<MiningField> childrenTargetFields = KiePMMLUtil.getMiningTargetFields(childrenModel.getMiningSchema().getMiningFields());
        if (childrenTargetFields.isEmpty()) {
            childrenModel.getMiningSchema().addMiningFields(parentTargetFields.toArray(new MiningField[parentTargetFields.size()]));
        }
    }

    static void populateMissingOutputFieldDataType(Model model, List<DataField> dataFields) {
        if (model.getOutput() != null && model.getOutput().getOutputFields() != null) {
            KiePMMLUtil.populateMissingOutputFieldDataType(model.getOutput().getOutputFields(), model.getMiningSchema().getMiningFields(), dataFields);
        }
    }

    static void populateMissingOutputFieldDataType(List<OutputField> toPopulate, List<MiningField> miningFields, List<DataField> dataFields) {
        List<MiningField> targetFields = KiePMMLUtil.getMiningTargetFields(miningFields);
        toPopulate.stream().filter(outputField -> outputField.getDataType() == null).forEach(outputField -> {
            MiningField referencedField = null;
            if (outputField.getTargetField() != null) {
                referencedField = targetFields.stream().filter(targetField -> outputField.getTargetField().equals((Object)targetField.getName())).findFirst().orElseThrow(() -> new KiePMMLException("Failed to find a target field for OutputField " + outputField.getName().getValue()));
            }
            if (referencedField == null && (outputField.getResultFeature() == null || outputField.getResultFeature().equals((Object)ResultFeature.PREDICTED_VALUE))) {
                referencedField = targetFields.stream().findFirst().orElse(null);
            }
            if (referencedField == null && ResultFeature.PROBABILITY.equals((Object)outputField.getResultFeature())) {
                outputField.setDataType(DataType.DOUBLE);
                return;
            }
            if (referencedField != null) {
                FieldName targetFieldName = referencedField.getName();
                DataField dataField = dataFields.stream().filter(df -> df.getName().equals((Object)targetFieldName)).findFirst().orElseThrow(() -> new KiePMMLException("Failed to find a DataField field for MiningField " + targetFieldName.toString()));
                outputField.setDataType(dataField.getDataType());
            }
        });
    }

    static String getSanitizedId(String id, String modelName) {
        String toReturn = id.replace(".", "").replace(",", "");
        try {
            Integer.parseInt(toReturn);
            toReturn = String.format(SEGMENTID_TEMPLATE, modelName, id);
        }
        catch (NumberFormatException numberFormatException) {
            // empty catch block
        }
        return toReturn;
    }

    static List<MiningField> getMiningTargetFields(MiningSchema miningSchema) {
        return KiePMMLUtil.getMiningTargetFields(miningSchema.getMiningFields());
    }

    static List<MiningField> getMiningTargetFields(List<MiningField> miningFields) {
        return miningFields.stream().filter(miningField -> MiningField.UsageType.PREDICTED.equals((Object)miningField.getUsageType()) || MiningField.UsageType.TARGET.equals((Object)miningField.getUsageType())).collect(Collectors.toList());
    }
}

