package org.apache.sysds.runtime.transform.encode;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.federated.monitoring.repositories.Constants;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.util.CollectionUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.TransformStatistics;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONObject;

/* loaded from: input_file:org/apache/sysds/runtime/transform/encode/EncoderFactory.class */
public interface EncoderFactory {
    /** Commons-logging logger registered under the fully qualified name of this interface. */
    public static final Log LOG = LogFactory.getLog(EncoderFactory.class.getName());

    /**
     * Creates an encoder from a specification and a column count only.
     *
     * <p>Forwards to the seven-argument overload with no column names, no meta data frame,
     * no embeddings matrix, and {@code -1} for both trailing range arguments.
     *
     * @param str JSON transform specification
     * @param i column count forwarded to the full overload
     * @return the encoder built by the full overload
     */
    static MultiColumnEncoder createEncoder(String str, int i) {
        final String[] noColumnNames = null;
        final FrameBlock noMeta = null;
        final MatrixBlock noEmbeddings = null;
        return createEncoder(str, noColumnNames, i, noMeta, noEmbeddings, -1, -1);
    }

    /**
     * Creates an encoder from a specification, column names, a column count and an optional
     * meta data frame.
     *
     * <p>Forwards to the seven-argument overload with no embeddings matrix and {@code -1}
     * for both trailing range arguments.
     *
     * @param str JSON transform specification
     * @param strArr column names forwarded to the full overload
     * @param i column count forwarded to the full overload
     * @param frameBlock meta data frame forwarded to the full overload
     * @return the encoder built by the full overload
     */
    static MultiColumnEncoder createEncoder(String str, String[] strArr, int i, FrameBlock frameBlock) {
        final MatrixBlock noEmbeddings = null;
        final int defaultRangeArg = -1;
        return createEncoder(str, strArr, i, frameBlock, noEmbeddings, defaultRangeArg, defaultRangeArg);
    }

    /**
     * Creates an encoder with explicit trailing range arguments and no embeddings matrix.
     *
     * @param str JSON transform specification
     * @param strArr column names forwarded to the full overload
     * @param i column count forwarded to the full overload
     * @param frameBlock meta data frame forwarded to the full overload
     * @param i2 first range argument, forwarded unchanged
     * @param i3 second range argument, forwarded unchanged
     * @return the encoder built by the full overload
     */
    static MultiColumnEncoder createEncoder(String str, String[] strArr, int i, FrameBlock frameBlock, int i2, int i3) {
        final MatrixBlock noEmbeddings = null;
        return createEncoder(str, strArr, i, frameBlock, noEmbeddings, i2, i3);
    }

    /**
     * Creates an encoder from a specification, column names, an optional schema, a fallback
     * column count and an optional meta data frame.
     *
     * <p>Previously the schema argument was dropped and the column count {@code i} was always
     * used. This overload now forwards to the schema-aware overload that additionally takes an
     * embeddings matrix (passing {@code null} for it), so both overloads resolve the column
     * count the same way: from the schema when one is supplied, otherwise from {@code i}.
     *
     * @param str JSON transform specification
     * @param strArr column names forwarded to the delegate
     * @param valueTypeArr schema; when {@code null} the delegate falls back to {@code i}
     * @param i column count used when no schema is given
     * @param frameBlock meta data frame forwarded to the delegate
     * @return the encoder built by the delegate
     */
    static MultiColumnEncoder createEncoder(String str, String[] strArr, Types.ValueType[] valueTypeArr, int i, FrameBlock frameBlock) {
        // Explicitly typed null selects the (…, ValueType[], int, FrameBlock, MatrixBlock) overload.
        return createEncoder(str, strArr, valueTypeArr, i, frameBlock, (MatrixBlock) null);
    }

    /**
     * Creates an encoder from a specification, an optional schema with a fallback column
     * count, an optional meta data frame and an embeddings matrix.
     *
     * <p>When no schema is given, one made of {@code i} copies of
     * {@code Types.ValueType.STRING} is substituted before delegating.
     *
     * @param str JSON transform specification
     * @param strArr column names forwarded to the delegate
     * @param valueTypeArr schema, or {@code null} to use the all-STRING default of length {@code i}
     * @param i column count used only when {@code valueTypeArr} is {@code null}
     * @param frameBlock meta data frame forwarded to the delegate
     * @param matrixBlock embeddings matrix forwarded to the delegate
     * @return the encoder built by the delegate
     */
    static MultiColumnEncoder createEncoder(String str, String[] strArr, Types.ValueType[] valueTypeArr, int i, FrameBlock frameBlock, MatrixBlock matrixBlock) {
        Types.ValueType[] effectiveSchema = valueTypeArr;
        if (effectiveSchema == null) {
            effectiveSchema = UtilFunctions.nCopies(i, Types.ValueType.STRING);
        }
        return createEncoder(str, strArr, effectiveSchema, frameBlock, matrixBlock);
    }

    /**
     * Creates an encoder from a specification, column names, a schema and an optional meta
     * data frame, passing {@code -1} for both trailing range arguments of the delegate.
     *
     * @param str JSON transform specification
     * @param strArr column names forwarded to the delegate
     * @param valueTypeArr schema forwarded to the delegate
     * @param frameBlock meta data frame forwarded to the delegate
     * @return the encoder built by the delegate
     */
    static MultiColumnEncoder createEncoder(String str, String[] strArr, Types.ValueType[] valueTypeArr, FrameBlock frameBlock) {
        final int defaultRangeArg = -1;
        return createEncoder(str, strArr, valueTypeArr, frameBlock, defaultRangeArg, defaultRangeArg);
    }

    /**
     * Creates an encoder from a schema and explicit trailing range arguments, without an
     * embeddings matrix.
     *
     * <p>Only the length of the schema is used: it becomes the column count of the full
     * overload. A {@code null} schema therefore fails with a {@link NullPointerException}.
     *
     * @param str JSON transform specification
     * @param strArr column names forwarded to the full overload
     * @param valueTypeArr schema whose length is used as the column count; must not be {@code null}
     * @param frameBlock meta data frame forwarded to the full overload
     * @param i first range argument, forwarded unchanged
     * @param i2 second range argument, forwarded unchanged
     * @return the encoder built by the full overload
     */
    static MultiColumnEncoder createEncoder(String str, String[] strArr, Types.ValueType[] valueTypeArr, FrameBlock frameBlock, int i, int i2) {
        final int columnCount = valueTypeArr.length;
        final MatrixBlock noEmbeddings = null;
        return createEncoder(str, strArr, columnCount, frameBlock, noEmbeddings, i, i2);
    }

    /**
     * Creates an encoder from a column count, an optional meta data frame and an embeddings
     * matrix.
     *
     * <p>A schema of {@code i} copies of {@code Types.ValueType.STRING} is built and handed to
     * the schema-based overload.
     *
     * @param str JSON transform specification
     * @param strArr column names forwarded to the delegate
     * @param i number of STRING entries in the generated schema
     * @param frameBlock meta data frame forwarded to the delegate
     * @param matrixBlock embeddings matrix forwarded to the delegate
     * @return the encoder built by the delegate
     */
    static MultiColumnEncoder createEncoder(String str, String[] strArr, int i, FrameBlock frameBlock, MatrixBlock matrixBlock) {
        final Types.ValueType[] allStringSchema = UtilFunctions.nCopies(i, Types.ValueType.STRING);
        return createEncoder(str, strArr, allStringSchema, frameBlock, matrixBlock);
    }

    /**
     * Creates an encoder from a schema, an optional meta data frame and an embeddings matrix.
     *
     * <p>Only the length of the schema is used: it becomes the column count of the full
     * overload, which is called with {@code -1} for both trailing range arguments. A
     * {@code null} schema fails with a {@link NullPointerException}.
     *
     * @param str JSON transform specification
     * @param strArr column names forwarded to the full overload
     * @param valueTypeArr schema whose length is used as the column count; must not be {@code null}
     * @param frameBlock meta data frame forwarded to the full overload
     * @param matrixBlock embeddings matrix forwarded to the full overload
     * @return the encoder built by the full overload
     */
    static MultiColumnEncoder createEncoder(String str, String[] strArr, Types.ValueType[] valueTypeArr, FrameBlock frameBlock, MatrixBlock matrixBlock) {
        final int columnCount = valueTypeArr.length;
        final int defaultRangeArg = -1;
        return createEncoder(str, strArr, columnCount, frameBlock, matrixBlock, defaultRangeArg, defaultRangeArg);
    }

    /**
     * Builds a {@link MultiColumnEncoder} from a JSON transform specification.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Parse the spec and read the column id lists for recode, hash, dummycode, bin and
     *       word-embedding.</li>
     *   <li>Extend the recode list with every dummycode column that is neither binned nor
     *       hashed, then reject specs where recode, bin, hash and word-embedding overlap.</li>
     *   <li>Give a pass-through encoder to every id of {@code getSeqList(1, i, 1)} that has
     *       no recode, hash, bin or word-embedding encoder.</li>
     *   <li>Group all column encoders by column id, wrap each group in a
     *       {@link ColumnEncoderComposite}, and attach omit / impute as legacy encoders.</li>
     *   <li>If a meta data frame is given, initialise the encoder from it. When the spec is
     *       not an id spec and the frame's column names differ from {@code strArr}, a
     *       temporary frame with columns re-ordered to match {@code strArr} is used.</li>
     *   <li>If word-embedding columns exist, initialise the embeddings.</li>
     * </ol>
     *
     * @param str JSON transform specification
     * @param strArr column names handed to the spec parsers; may be {@code null}
     * @param i column count (upper end of the id sequence used for pass-through columns)
     * @param frameBlock meta data frame, or {@code null} to skip meta data initialisation
     * @param matrixBlock embeddings matrix; must be non-null when the spec lists
     *        word-embedding columns
     * @param i2 forwarded to the spec parsers (presumably the lower column bound - confirm
     *        against {@code TfMetaUtils})
     * @param i3 forwarded to the spec parsers (presumably the upper column bound - confirm
     *        against {@code TfMetaUtils})
     * @return the fully configured encoder
     * @throws DMLRuntimeException wrapping any exception raised while parsing the spec or
     *         building the encoders, including the {@code DMLRuntimeException}s thrown here
     */
    static MultiColumnEncoder createEncoder(String str, String[] strArr, int i, FrameBlock frameBlock, MatrixBlock matrixBlock, int i2, int i3) {
        try {
            final JSONObject spec = new JSONObject(str);
            final List<ColumnEncoderComposite> composites = new ArrayList<>();
            final HashMap<Integer, List<ColumnEncoder>> encodersByColumn = new HashMap<>();
            final boolean idSpec = spec.containsKey("ids") && spec.getBoolean("ids");
            TfMetaUtils.checkValidEncoders(spec);

            // First-level column id lists, parsed in the same order as before.
            final List<Integer> explicitRecodeIds = Arrays.asList(
                ArrayUtils.toObject(TfMetaUtils.parseJsonIDList(spec, strArr, TfUtils.TfMethod.RECODE.toString(), i2, i3)));
            final List<Integer> hashIds = Arrays.asList(
                ArrayUtils.toObject(TfMetaUtils.parseJsonIDList(spec, strArr, TfUtils.TfMethod.HASH.toString(), i2, i3)));
            final List<Integer> dummycodeIds = Arrays.asList(
                ArrayUtils.toObject(TfMetaUtils.parseJsonIDList(spec, strArr, TfUtils.TfMethod.DUMMYCODE.toString(), i2, i3)));
            final List<Integer> binIds = TfMetaUtils.parseBinningColIDs(spec, strArr, i2, i3);
            final List<Integer> embeddingIds = Arrays.asList(
                ArrayUtils.toObject(TfMetaUtils.parseJsonIDList(spec, strArr, TfUtils.TfMethod.WORD_EMBEDDING.toString(), i2, i3)));
            if (matrixBlock == null && !embeddingIds.isEmpty()) {
                throw new DMLRuntimeException("Missing argument Embeddings Matrix for transform [" + TfUtils.TfMethod.WORD_EMBEDDING + "]");
            }

            // Dummycode columns that are neither binned nor hashed are added to the recode list.
            final List<Integer> dummycodeNeedingRecode = CollectionUtils.except(CollectionUtils.except(dummycodeIds, binIds), hashIds);
            final List<Integer> recodeIds = CollectionUtils.unionDistinct(explicitRecodeIds, dummycodeNeedingRecode);
            if (CollectionUtils.intersect(recodeIds, binIds, hashIds, embeddingIds)) {
                throw new DMLRuntimeException("More than one encoders (recode, binning, hashing, word_embedding) on one column is not allowed");
            }

            // Every remaining id of the 1..i sequence is passed through.
            final List<Integer> allIds = UtilFunctions.getSeqList(1, i, 1);
            final List<Integer> recodeOrHashIds = CollectionUtils.unionDistinct(recodeIds, hashIds);
            final List<Integer> passThroughIds = CollectionUtils.except(
                CollectionUtils.except(CollectionUtils.except(allIds, recodeOrHashIds), binIds), embeddingIds);

            final List<Integer> omitIds = Arrays.asList(
                ArrayUtils.toObject(TfMetaUtils.parseJsonIDList(spec, strArr, TfUtils.TfMethod.OMIT.toString(), i2, i3)));
            final List<Integer> imputeIds = Arrays.asList(
                ArrayUtils.toObject(TfMetaUtils.parseJsonObjectIDList(spec, strArr, TfUtils.TfMethod.IMPUTE.toString(), i2, i3)));
            final List<Integer> udfIds = TfMetaUtils.parseUDFColIDs(spec, strArr, i2, i3);

            // Register the per-column encoders; iterating an empty list is a no-op.
            for (Integer colId : recodeIds) {
                addEncoderToMap(new ColumnEncoderRecode(colId.intValue()), encodersByColumn);
            }
            for (Integer colId : hashIds) {
                addEncoderToMap(new ColumnEncoderFeatureHash(colId.intValue(), TfMetaUtils.getK(spec)), encodersByColumn);
            }
            for (Integer colId : passThroughIds) {
                addEncoderToMap(new ColumnEncoderPassThrough(colId.intValue()), encodersByColumn);
            }
            for (Integer colId : embeddingIds) {
                addEncoderToMap(new ColumnEncoderWordEmbedding(colId.intValue()), encodersByColumn);
            }

            // The guard is required here: the bin entry is only looked up when bin columns exist.
            if (!binIds.isEmpty()) {
                for (Object rawBinSpec : (JSONArray) spec.get(TfUtils.TfMethod.BIN.toString())) {
                    final JSONObject binSpec = (JSONObject) rawBinSpec;
                    final int numBins = binSpec.containsKey(TfUtils.JSON_NBINS) ? binSpec.getInt(TfUtils.JSON_NBINS) : 1;
                    final int binColId = TfMetaUtils.parseJsonObjectID(binSpec, strArr, i2, i3, idSpec);
                    if (binColId <= 0) {
                        continue;
                    }
                    final String methodName = binSpec.get("method").toString().toUpperCase();
                    final ColumnEncoderBin.BinMethod binMethod;
                    switch (methodName) {
                        case "EQUI-WIDTH":
                            binMethod = ColumnEncoderBin.BinMethod.EQUI_WIDTH;
                            break;
                        case "EQUI-HEIGHT":
                            binMethod = ColumnEncoderBin.BinMethod.EQUI_HEIGHT;
                            break;
                        case "EQUI-HEIGHT-APPROX":
                            binMethod = ColumnEncoderBin.BinMethod.EQUI_HEIGHT_APPROX;
                            break;
                        default:
                            throw new DMLRuntimeException("Unsupported binning method: " + methodName);
                    }
                    addEncoderToMap(new ColumnEncoderBin(binColId, numBins, binMethod), encodersByColumn);
                }
            }

            for (Integer colId : dummycodeIds) {
                addEncoderToMap(new ColumnEncoderDummycode(colId.intValue()), encodersByColumn);
            }

            // The guard is required here: the "udf" object is only read when UDF columns exist.
            if (!udfIds.isEmpty()) {
                final String udfName = spec.getJSONObject("udf").getString(Constants.ENTITY_NAME_COL);
                for (Integer colId : udfIds) {
                    addEncoderToMap(new ColumnEncoderUDF(colId.intValue(), udfName), encodersByColumn);
                }
            }

            // One composite per column id, in the map's iteration order.
            for (List<ColumnEncoder> columnEncoders : encodersByColumn.values()) {
                if (DMLScript.STATISTICS) {
                    TransformStatistics.incEncoderCount(columnEncoders.size());
                }
                composites.add(new ColumnEncoderComposite(columnEncoders));
            }
            final MultiColumnEncoder result = new MultiColumnEncoder(composites);

            if (!omitIds.isEmpty()) {
                result.addReplaceLegacyEncoder(new EncoderOmit(spec, strArr, i, i2, i3));
                if (DMLScript.STATISTICS) {
                    TransformStatistics.incEncoderCount(1L);
                }
            }
            if (!imputeIds.isEmpty()) {
                final EncoderMVImpute impute = new EncoderMVImpute(spec, strArr, i, i2, i3);
                impute.initRecodeIDList(recodeIds);
                result.addReplaceLegacyEncoder(impute);
                if (DMLScript.STATISTICS) {
                    TransformStatistics.incEncoderCount(1L);
                }
            }

            if (frameBlock != null) {
                FrameBlock effectiveMeta = frameBlock;
                final String[] metaNames = frameBlock.getColumnNames();
                final boolean namesDiffer = !TfMetaUtils.isIDSpec(spec)
                    && strArr != null
                    && metaNames != null
                    && !Objects.deepEquals(strArr, metaNames);
                if (namesDiffer) {
                    // Re-order the meta columns so that position k holds the column named strArr[k].
                    final HashMap<String, Integer> metaPositions = getColumnPositions(metaNames);
                    final FrameBlock aligned = new FrameBlock(frameBlock.getSchema(), metaNames);
                    for (int target = 0; target < strArr.length; target++) {
                        final String wanted = strArr[target];
                        if (!metaPositions.containsKey(wanted)) {
                            throw new DMLRuntimeException("Column name not found in meta data: " + wanted + " (meta: " + Arrays.toString(metaNames) + ")");
                        }
                        final int source = metaPositions.get(wanted).intValue();
                        aligned.setColumn(target, frameBlock.getColumn(source));
                        aligned.setColumnMetadata(target, frameBlock.getColumnMetadata(source));
                    }
                    effectiveMeta = aligned;
                }
                result.initMetaData(effectiveMeta);
            }

            if (!embeddingIds.isEmpty()) {
                result.initEmbeddings(matrixBlock);
            }
            return result;
        } catch (Exception e) {
            throw new DMLRuntimeException(e);
        }
    }

    /**
     * Appends a column encoder to the list registered for its column id, creating the list
     * on first use.
     *
     * <p>Uses {@code computeIfAbsent} so the key is boxed and looked up once, and the bucket
     * is created with a properly parameterised list instead of a raw {@code ArrayList}.
     *
     * @param columnEncoder encoder to register; its {@code _colID} field is the map key
     * @param hashMap column id to encoder list mapping that is updated in place
     */
    private static void addEncoderToMap(ColumnEncoder columnEncoder, HashMap<Integer, List<ColumnEncoder>> hashMap) {
        hashMap.computeIfAbsent(Integer.valueOf(columnEncoder._colID), unusedKey -> new ArrayList<>()).add(columnEncoder);
    }

    /**
     * Maps a column encoder instance to the ordinal of its {@code ColumnEncoder.EncoderType}.
     *
     * <p>The checks run in the order Bin, Dummycode, FeatureHash, PassThrough, Recode,
     * WordEmbedding; the first matching {@code instanceof} wins.
     *
     * @param columnEncoder encoder to classify
     * @return ordinal of the matching encoder type
     * @throws DMLRuntimeException if the encoder is none of the six handled classes
     */
    static int getEncoderType(ColumnEncoder columnEncoder) {
        final ColumnEncoder.EncoderType resolved;
        if (columnEncoder instanceof ColumnEncoderBin) {
            resolved = ColumnEncoder.EncoderType.Bin;
        } else if (columnEncoder instanceof ColumnEncoderDummycode) {
            resolved = ColumnEncoder.EncoderType.Dummycode;
        } else if (columnEncoder instanceof ColumnEncoderFeatureHash) {
            resolved = ColumnEncoder.EncoderType.FeatureHash;
        } else if (columnEncoder instanceof ColumnEncoderPassThrough) {
            resolved = ColumnEncoder.EncoderType.PassThrough;
        } else if (columnEncoder instanceof ColumnEncoderRecode) {
            resolved = ColumnEncoder.EncoderType.Recode;
        } else if (columnEncoder instanceof ColumnEncoderWordEmbedding) {
            resolved = ColumnEncoder.EncoderType.WordEmbedding;
        } else {
            throw new DMLRuntimeException("Unsupported encoder type: " + columnEncoder.getClass().getCanonicalName());
        }
        return resolved.ordinal();
    }

    /**
     * Creates an empty column encoder (no-argument constructor) for the
     * {@code ColumnEncoder.EncoderType} whose ordinal is {@code i}.
     *
     * <p>An ordinal outside the range of {@code EncoderType.values()} is now reported as a
     * {@link DMLRuntimeException}, like an unsupported type, instead of surfacing as a bare
     * {@link ArrayIndexOutOfBoundsException}.
     *
     * @param i ordinal of the encoder type, as produced by {@code getEncoderType}
     * @return a new encoder of the requested type
     * @throws DMLRuntimeException if {@code i} is not a valid ordinal or the type has no
     *         instantiable encoder here
     */
    static ColumnEncoder createInstance(int i) {
        final ColumnEncoder.EncoderType[] knownTypes = ColumnEncoder.EncoderType.values();
        if (i < 0 || i >= knownTypes.length) {
            throw new DMLRuntimeException("Unsupported encoder type: " + i);
        }
        ColumnEncoder.EncoderType encoderType = knownTypes[i];
        switch (encoderType) {
            case Bin:
                return new ColumnEncoderBin();
            case Dummycode:
                return new ColumnEncoderDummycode();
            case FeatureHash:
                return new ColumnEncoderFeatureHash();
            case PassThrough:
                return new ColumnEncoderPassThrough();
            case Recode:
                return new ColumnEncoderRecode();
            case WordEmbedding:
                return new ColumnEncoderWordEmbedding();
            default:
                throw new DMLRuntimeException("Unsupported encoder type: " + encoderType);
        }
    }

    /**
     * Builds a lookup from column name to its zero-based index in the given array.
     *
     * <p>If a name occurs more than once, the index of its last occurrence is kept.
     *
     * @param strArr column names; must not be {@code null}
     * @return a new map from each name to its position
     */
    private static HashMap<String, Integer> getColumnPositions(String[] strArr) {
        final HashMap<String, Integer> positions = new HashMap<>();
        int index = 0;
        for (String columnName : strArr) {
            positions.put(columnName, Integer.valueOf(index));
            index++;
        }
        return positions;
    }
}
