package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang.ArrayUtils;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.Statistics;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/* loaded from: input_file:org/apache/sysds/runtime/transform/encode/EncoderMVImpute.class */
public class EncoderMVImpute extends LegacyEncoder {
    private static final long serialVersionUID = 9057868620144662194L;
    private final Mean _meanFn;
    private MVMethod[] _mvMethodList;
    private KahanObject[] _meanList;
    private long[] _countList;
    private String[] _replacementList;
    private List<Integer> _rcList;
    private HashMap<Integer, HashMap<String, Long>> _hist;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/transform/encode/EncoderMVImpute$ColInfo.class */
    /**
     * Per-column imputation state (method, replacement, running mean, count and
     * value histogram) used to combine the state of two encoders column by column.
     */
    public static class ColInfo {
        /** Imputation method configured for the column. */
        MVMethod _method;
        /** Replacement value as a string; may be null if none has been set. */
        String _replacement;
        /** Kahan accumulator associated with the column. */
        KahanObject _mean;
        /** Count associated with the column. */
        long _count;
        /** Value-to-frequency histogram; may be null. */
        HashMap<String, Long> _hist;

        ColInfo(MVMethod mVMethod, String str, KahanObject kahanObject, long j, HashMap<String, Long> hashMap) {
            this._method = mVMethod;
            this._replacement = str;
            this._mean = kahanObject;
            this._count = j;
            this._hist = hashMap;
        }

        /**
         * Folds the state of {@code colInfo} into this instance.
         *
         * @param colInfo state of the same column coming from another encoder
         * @throws DMLRuntimeException if the two methods differ or the method cannot be merged
         */
        public void merge(ColInfo colInfo) {
            if (this._method != colInfo._method) {
                throw new DMLRuntimeException("Tried to merge two different impute methods: " + this._method.name() + " vs. " + colInfo._method.name());
            }
            switch (this._method) {
                case CONSTANT:
                    // Both sides must carry the same constant; only checked when assertions are enabled.
                    assert this._replacement.equals(colInfo._replacement);
                    return;
                case GLOBAL_MEAN:
                    // Scale both accumulators by their counts and add them up.
                    // NOTE(review): the accumulator is not divided by the combined count here and
                    // _replacement is not refreshed -- confirm that callers normalize the result.
                    this._mean._sum *= this._count;
                    this._mean._correction *= this._count;
                    KahanPlus.getKahanPlusFnObject().execute(this._mean, colInfo._mean._sum * colInfo._count);
                    KahanPlus.getKahanPlusFnObject().execute(this._mean, colInfo._mean._correction * colInfo._count);
                    this._count += colInfo._count;
                    return;
                case GLOBAL_MODE:
                    if (colInfo._hist == null) {
                        // Nothing to contribute from the other side.
                        return;
                    }
                    if (this._hist == null) {
                        this._hist = new HashMap<>(colInfo._hist);
                        return;
                    }
                    // Sum the frequencies over the union of both key sets, so values that only
                    // occur in the other histogram are kept as well.
                    for (Map.Entry<String, Long> entry : colInfo._hist.entrySet()) {
                        this._hist.merge(entry.getKey(), entry.getValue(), Long::sum);
                    }
                    return;
                default:
                    throw new DMLRuntimeException("Method `" + this._method.name() + "` not supported for federated impute");
            }
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/transform/encode/EncoderMVImpute$MVMethod.class */
    /**
     * Missing-value imputation methods known to this encoder.
     *
     * The declaration order matters: writeExternal stores each method by its
     * ordinal and readExternal maps it back through values(), so reordering or
     * inserting constants changes the serialized form.
     */
    public enum MVMethod {
        // Returned by getMethod for columns this encoder does not apply to.
        INVALID,
        // Replacement is the running mean computed in build.
        GLOBAL_MEAN,
        // Replacement is the most frequent value found by build.
        GLOBAL_MODE,
        // Replacement is taken from the "value" field of the spec.
        CONSTANT
    }

    public EncoderMVImpute(JSONObject jSONObject, String[] strArr, int i, int i2, int i3) throws JSONException {
        super(null, i);
        this._meanFn = Mean.getMeanFnObject();
        this._mvMethodList = null;
        this._meanList = null;
        this._countList = null;
        this._replacementList = null;
        this._rcList = null;
        this._hist = null;
        initColList(TfMetaUtils.parseJsonObjectIDList(jSONObject, strArr, TfUtils.TfMethod.IMPUTE.toString(), i2, i3));
        parseMethodsAndReplacements(jSONObject, strArr, i2);
        this._hist = new HashMap<>();
    }

    /**
     * Creates an empty encoder with no columns; all per-column state is left null.
     */
    public EncoderMVImpute() {
        super(new int[0], 0);
        _meanFn = Mean.getMeanFnObject();
        _hist = null;
        _rcList = null;
        _replacementList = null;
        _countList = null;
        _meanList = null;
        _mvMethodList = null;
    }

    /**
     * Creates an encoder directly from already-prepared per-column arrays.
     * The arrays are stored by reference; the histogram map is left null.
     *
     * @param iArr           column IDs
     * @param mVMethodArr    imputation method per column
     * @param strArr         replacement value per column
     * @param kahanObjectArr mean accumulator per column
     * @param jArr           count per column
     * @param list           recode column IDs
     * @param i              column count handed to the super constructor
     */
    public EncoderMVImpute(int[] iArr, MVMethod[] mVMethodArr, String[] strArr, KahanObject[] kahanObjectArr, long[] jArr, List<Integer> list, int i) {
        super(iArr, i);
        _meanFn = Mean.getMeanFnObject();
        _hist = null;
        _rcList = list;
        _countList = jArr;
        _meanList = kahanObjectArr;
        _replacementList = strArr;
        _mvMethodList = mVMethodArr;
    }

    /**
     * Unpacks a column-ID to {@link ColInfo} map into parallel per-column arrays.
     *
     * Columns are emitted in ascending column-ID order rather than in the map's
     * iteration order, so the resulting column list stays sorted like the one
     * produced by parseMethodsAndReplacements. Null histograms are not stored
     * in the target map.
     *
     * @param map            column ID to column state
     * @param iArr           receives the column IDs
     * @param mVMethodArr    receives the methods
     * @param strArr         receives the replacement values
     * @param kahanObjectArr receives the mean accumulators
     * @param jArr           receives the counts
     * @param hashMap        receives the non-null histograms, keyed by column ID
     */
    private static void fillListsFromMap(Map<Integer, ColInfo> map, int[] iArr, MVMethod[] mVMethodArr, String[] strArr, KahanObject[] kahanObjectArr, long[] jArr, HashMap<Integer, HashMap<String, Long>> hashMap) {
        List<Integer> colIds = new ArrayList<>(map.keySet());
        Collections.sort(colIds);
        int pos = 0;
        for (Integer colId : colIds) {
            ColInfo info = map.get(colId);
            iArr[pos] = colId.intValue();
            mVMethodArr[pos] = info._method;
            strArr[pos] = info._replacement;
            kahanObjectArr[pos] = info._mean;
            jArr[pos] = info._count;
            if (info._hist != null) {
                // A null value in the map would later fail in build and writeExternal.
                hashMap.put(colId, info._hist);
            }
            pos++;
        }
    }

    /**
     * @return the internal replacement array (not a copy), one entry per encoded column
     */
    public String[] getReplacements() {
        final String[] replacements = _replacementList;
        return replacements;
    }

    /**
     * @return the internal mean-accumulator array (not a copy), one entry per encoded column
     */
    public KahanObject[] getMeans() {
        final KahanObject[] means = _meanList;
        return means;
    }

    /**
     * Reads the imputation method (and, for CONSTANT, the replacement value) of
     * every encoded column from the specification.
     *
     * The column list is sorted first; each entry of the specification is then
     * stored at the position of its column in that sorted list. This keeps the
     * per-column arrays aligned with the column list even when the entries of
     * the specification are not ordered by column.
     *
     * @param jSONObject parsed transform specification
     * @param strArr     column names used when entries are addressed by name
     * @param i          ID offset; -1 means that IDs are used as they are
     * @throws JSONException if a required field is missing or malformed
     */
    private void parseMethodsAndReplacements(JSONObject jSONObject, String[] strArr, int i) throws JSONException {
        JSONArray imputeSpec = (JSONArray) jSONObject.get(TfUtils.TfMethod.IMPUTE.toString());
        boolean byId = jSONObject.containsKey("ids") && jSONObject.getBoolean("ids");
        int idShift = (i == -1) ? 0 : i - 1;

        // Sorted so that the position of a column can be found by binary search.
        Arrays.sort(this._colList);
        int numCols = this._colList.length;
        this._mvMethodList = new MVMethod[numCols];
        this._replacementList = new String[numCols];
        this._meanList = new KahanObject[numCols];
        this._countList = new long[numCols];

        Iterator<?> it = imputeSpec.iterator();
        while (it.hasNext()) {
            JSONObject colSpec = (JSONObject) it.next();
            int colId = byId ? colSpec.getInt("id") - idShift : ArrayUtils.indexOf(strArr, colSpec.get("name")) + 1;
            int pos = Arrays.binarySearch(this._colList, colId);
            if (pos < 0) {
                // Column is not handled by this encoder.
                continue;
            }
            this._mvMethodList[pos] = MVMethod.valueOf(colSpec.get("method").toString().toUpperCase());
            if (this._mvMethodList[pos] == MVMethod.CONSTANT) {
                this._replacementList[pos] = colSpec.getString("value");
            }
            this._meanList[pos] = new KahanObject(DataExpression.DEFAULT_DELIM_FILL_VALUE, DataExpression.DEFAULT_DELIM_FILL_VALUE);
        }
    }

    /**
     * Looks up the imputation method of a column.
     *
     * @param i column ID
     * @return the configured method, or {@link MVMethod#INVALID} if isApplicable reports -1
     */
    public MVMethod getMethod(int i) {
        final int pos = isApplicable(i);
        if (pos == -1) {
            return MVMethod.INVALID;
        }
        return _mvMethodList[pos];
    }

    /**
     * Looks up the count stored for a column.
     *
     * @param i column ID
     * @return the stored count, or 0 if isApplicable reports -1
     */
    public long getNonMVCount(int i) {
        final int pos = isApplicable(i);
        return (pos == -1) ? 0L : _countList[pos];
    }

    /**
     * Looks up the replacement value of a column.
     *
     * @param i column ID
     * @return the replacement string, or null if isApplicable reports -1
     */
    public String getReplacement(int i) {
        final int pos = isApplicable(i);
        return (pos == -1) ? null : _replacementList[pos];
    }

    /**
     * Builds the imputation state from the input frame and then applies it.
     *
     * @param frameBlock  input frame
     * @param matrixBlock output matrix to patch
     * @return the result of {@code apply(frameBlock, matrixBlock)}
     */
    @Override
    public MatrixBlock encode(FrameBlock frameBlock, MatrixBlock matrixBlock) {
        build(frameBlock);
        final MatrixBlock imputed = apply(frameBlock, matrixBlock);
        return imputed;
    }

    /**
     * Updates the per-column imputation state from the given frame.
     *
     * For GLOBAL_MEAN columns the running mean is updated with every non-null
     * cell, and the stored count is advanced by the number of non-null cells
     * only, so that a later call continues with the correct weight. For
     * GLOBAL_MODE columns the value histogram is updated and the most frequent
     * value becomes the replacement.
     *
     * @param frameBlock input frame
     * @throws RuntimeException wrapping any exception raised while processing a column
     */
    @Override
    public void build(FrameBlock frameBlock) {
        long startTime = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        for (int pos = 0; pos < this._colList.length; pos++) {
            try {
                int colId = this._colList[pos];
                if (this._mvMethodList[pos] == MVMethod.GLOBAL_MEAN) {
                    // offset + row + 1 is the number of non-null values seen so far, current one included.
                    long offset = this._countList[pos];
                    for (int row = 0; row < frameBlock.getNumRows(); row++) {
                        Object cell = frameBlock.get(row, colId - 1);
                        if (cell == null) {
                            offset--;
                            continue;
                        }
                        this._meanFn.execute2(this._meanList[pos], UtilFunctions.objectToDouble(frameBlock.getSchema()[colId - 1], cell), offset + row + 1);
                    }
                    this._replacementList[pos] = String.valueOf(this._meanList[pos]._sum);
                    // offset was decremented once per missing cell, so this adds non-null rows only.
                    this._countList[pos] = offset + frameBlock.getNumRows();
                } else if (this._mvMethodList[pos] == MVMethod.GLOBAL_MODE) {
                    if (this._hist == null) {
                        // Some constructors leave the histogram map unset.
                        this._hist = new HashMap<>();
                    }
                    HashMap<String, Long> colHist = this._hist.get(Integer.valueOf(colId));
                    if (colHist == null) {
                        colHist = new HashMap<>();
                    }
                    for (int row = 0; row < frameBlock.getNumRows(); row++) {
                        String value = String.valueOf(frameBlock.get(row, colId - 1));
                        if (!value.equals(ProgramConverter.EMPTY) && !value.isEmpty()) {
                            colHist.merge(value, Long.valueOf(1L), Long::sum);
                        }
                    }
                    this._hist.put(Integer.valueOf(colId), colHist);
                    // Pick the value with the highest frequency as the replacement.
                    long bestCount = Long.MIN_VALUE;
                    for (Map.Entry<String, Long> entry : colHist.entrySet()) {
                        if (entry.getValue().longValue() > bestCount) {
                            this._replacementList[pos] = entry.getKey();
                            bestCount = entry.getValue().longValue();
                        }
                    }
                }
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        if (DMLScript.STATISTICS) {
            Statistics.incTransformImputeBuildTime(System.nanoTime() - startTime);
        }
    }

    /**
     * Replaces every NaN cell of the encoded columns with the parsed replacement
     * value of that column.
     *
     * @param frameBlock  input frame; only its row count is used
     * @param matrixBlock matrix that is patched in place
     * @return {@code matrixBlock}
     */
    @Override
    public MatrixBlock apply(FrameBlock frameBlock, MatrixBlock matrixBlock) {
        long startTime = 0L;
        if (DMLScript.STATISTICS) {
            startTime = System.nanoTime();
        }
        for (int row = 0; row < frameBlock.getNumRows(); row++) {
            for (int pos = 0; pos < _colList.length; pos++) {
                final int col = _colList[pos] - 1;
                if (!Double.isNaN(matrixBlock.quickGetValue(row, col))) {
                    continue;
                }
                // The replacement is parsed only when a missing cell is actually found.
                matrixBlock.quickSetValue(row, col, Double.parseDouble(_replacementList[pos]));
            }
        }
        if (DMLScript.STATISTICS) {
            Statistics.incTransformImputeApplyTime(System.nanoTime() - startTime);
        }
        return matrixBlock;
    }

    /**
     * Creates an encoder restricted to the columns inside the given range.
     *
     * Histograms are looked up by column ID, which is the key used by build,
     * and are handed to the new encoder in a map of its own, so this encoder's
     * histogram map is not modified. If the recode list is unset it is
     * initialized to an empty list as a side effect.
     *
     * NOTE(review): recode IDs are shifted by the range start while column IDs
     * are passed on unshifted -- confirm that this asymmetry is intended.
     *
     * @param indexRange column range to extract
     * @return the sub-range encoder, or null if no encoded column lies in the range
     */
    @Override
    public LegacyEncoder subRangeEncoder(IndexRange indexRange) {
        Map<Integer, ColInfo> selected = new HashMap<>();
        for (int pos = 0; pos < this._colList.length; pos++) {
            int colId = this._colList[pos];
            if (!indexRange.inColRange(colId)) {
                continue;
            }
            HashMap<String, Long> colHist = (this._hist == null) ? null : this._hist.get(Integer.valueOf(colId));
            selected.put(Integer.valueOf(colId), new ColInfo(this._mvMethodList[pos], this._replacementList[pos], this._meanList[pos], this._countList[pos], colHist));
        }
        if (selected.isEmpty()) {
            return null;
        }

        int numCols = selected.size();
        int[] colList = new int[numCols];
        MVMethod[] methods = new MVMethod[numCols];
        String[] replacements = new String[numCols];
        KahanObject[] means = new KahanObject[numCols];
        long[] counts = new long[numCols];
        HashMap<Integer, HashMap<String, Long>> subHist = new HashMap<>();
        fillListsFromMap(selected, colList, methods, replacements, means, counts, subHist);
        // Columns without a histogram must not leave null values behind.
        subHist.values().removeIf(h -> h == null);

        if (this._rcList == null) {
            this._rcList = new ArrayList<>();
        }
        List<Integer> subRcList = this._rcList.stream()
            .filter(indexRange::inColRange)
            .map(rcCol -> Integer.valueOf((int) (rcCol.intValue() - (indexRange.colStart - 1))))
            .collect(Collectors.toList());

        EncoderMVImpute sub = new EncoderMVImpute(colList, methods, replacements, means, counts, subRcList, (int) indexRange.colSpan());
        sub._hist = subHist;
        return sub;
    }

    /**
     * Merges another encoder into this one.
     *
     * Encoders of a different type are delegated to the super implementation.
     * For an {@code EncoderMVImpute}, columns present in both encoders are
     * combined through {@link ColInfo#merge(ColInfo)} and the remaining columns
     * are taken over unchanged. Histograms are looked up by column ID, which is
     * the key used by build; an unset histogram map or recode list on either
     * side is treated as empty.
     *
     * @param legacyEncoder encoder to merge in
     * @param i             row position; only forwarded to the super implementation
     * @param i2            column position; recode IDs of the other encoder are shifted by {@code i2 - 1}
     */
    @Override
    public void mergeAt(LegacyEncoder legacyEncoder, int i, int i2) {
        if (!(legacyEncoder instanceof EncoderMVImpute)) {
            super.mergeAt(legacyEncoder, i, i2);
            return;
        }
        EncoderMVImpute other = (EncoderMVImpute) legacyEncoder;

        Map<Integer, ColInfo> merged = new HashMap<>();
        for (int pos = 0; pos < this._colList.length; pos++) {
            int colId = this._colList[pos];
            HashMap<String, Long> colHist = (this._hist == null) ? null : this._hist.get(Integer.valueOf(colId));
            merged.put(Integer.valueOf(colId), new ColInfo(this._mvMethodList[pos], this._replacementList[pos], this._meanList[pos], this._countList[pos], colHist));
        }
        for (int pos = 0; pos < legacyEncoder._colList.length; pos++) {
            int colId = legacyEncoder._colList[pos];
            HashMap<String, Long> colHist = (other._hist == null) ? null : other._hist.get(Integer.valueOf(colId));
            ColInfo incoming = new ColInfo(other._mvMethodList[pos], other._replacementList[pos], other._meanList[pos], other._countList[pos], colHist);
            ColInfo existing = merged.get(Integer.valueOf(colId));
            if (existing == null) {
                merged.put(Integer.valueOf(colId), incoming);
            } else {
                existing.merge(incoming);
            }
        }

        int numCols = merged.size();
        this._colList = new int[numCols];
        this._mvMethodList = new MVMethod[numCols];
        this._replacementList = new String[numCols];
        this._meanList = new KahanObject[numCols];
        this._countList = new long[numCols];
        this._hist = new HashMap<>();
        fillListsFromMap(merged, this._colList, this._mvMethodList, this._replacementList, this._meanList, this._countList, this._hist);
        // Columns without a histogram must not leave null values behind.
        this._hist.values().removeIf(h -> h == null);

        // Union of both recode lists, with the other side shifted to its column position.
        HashSet<Integer> rcIds = new HashSet<>();
        if (this._rcList != null) {
            rcIds.addAll(this._rcList);
        }
        if (other._rcList != null) {
            for (Integer rcCol : other._rcList) {
                rcIds.add(Integer.valueOf(rcCol.intValue() + (i2 - 1)));
            }
        }
        this._rcList = new ArrayList<>(rcIds);
    }

    /**
     * Writes the replacement value of every encoded column into the column
     * metadata of the given frame.
     *
     * @param frameBlock metadata frame to update
     * @return {@code frameBlock}
     */
    @Override
    public FrameBlock getMetaData(FrameBlock frameBlock) {
        int pos = 0;
        while (pos < _colList.length) {
            final int col = _colList[pos] - 1;
            frameBlock.getColumnMetadata(col).setMvValue(_replacementList[pos]);
            pos++;
        }
        return frameBlock;
    }

    /**
     * Initializes the replacement values from the column metadata of a frame.
     *
     * For columns in the recode list the unquoted value is translated through
     * the recode map of the column; for all other columns it is used directly.
     * An unset recode list is treated as empty.
     *
     * @param frameBlock metadata frame to read from
     * @throws RuntimeException if a recoded column has no recode entry for its replacement value
     */
    @Override
    public void initMetaData(FrameBlock frameBlock) {
        for (int pos = 0; pos < this._colList.length; pos++) {
            int colId = this._colList[pos];
            String mvValue = UtilFunctions.unquote(frameBlock.getColumnMetadata(colId - 1).getMvValue());
            boolean recoded = this._rcList != null && this._rcList.contains(Integer.valueOf(colId));
            if (!recoded) {
                this._replacementList[pos] = mvValue;
                continue;
            }
            Long code = frameBlock.getRecodeMap(colId - 1).get(mvValue);
            if (code == null) {
                throw new RuntimeException("Missing recode value for impute value '" + mvValue + "' (colID=" + colId + ").");
            }
            this._replacementList[pos] = code.toString();
        }
    }

    /**
     * Sets the list of recode column IDs. The list is stored by reference.
     *
     * @param list recode column IDs
     */
    public void initRecodeIDList(List<Integer> list) {
        _rcList = list;
    }

    /**
     * Returns the value histogram stored under the given key.
     *
     * @param i histogram key; build stores histograms under the column ID
     * @return the histogram, or null if none is stored or the histogram map is unset
     */
    public HashMap<String, Long> getHistogram(int i) {
        if (this._hist == null) {
            // Some constructors leave the histogram map unset.
            return null;
        }
        return this._hist.get(Integer.valueOf(i));
    }

    /**
     * Serializes the encoder state after the state written by the super class.
     *
     * Layout: per column the method ordinal (byte) and count (long); the number
     * of non-null replacements followed by (index, value) pairs; the recode
     * list; and the histograms as (key, size, entries). An unset recode list is
     * written as empty, and null histogram entries are skipped. The mean
     * accumulators are not written.
     *
     * @param objectOutput stream to write to
     * @throws IOException if writing fails
     */
    @Override
    public void writeExternal(ObjectOutput objectOutput) throws IOException {
        super.writeExternal(objectOutput);
        for (int pos = 0; pos < this._colList.length; pos++) {
            objectOutput.writeByte(this._mvMethodList[pos].ordinal());
            objectOutput.writeLong(this._countList[pos]);
        }

        // Replacements are stored sparsely: count first, then (index, value) pairs.
        int numReplacements = 0;
        for (String replacement : this._replacementList) {
            if (replacement != null) {
                numReplacements++;
            }
        }
        objectOutput.writeInt(numReplacements);
        for (int pos = 0; pos < this._replacementList.length; pos++) {
            if (this._replacementList[pos] != null) {
                objectOutput.writeInt(pos);
                objectOutput.writeUTF(this._replacementList[pos]);
            }
        }

        List<Integer> rcList = (this._rcList == null) ? Collections.<Integer>emptyList() : this._rcList;
        objectOutput.writeInt(rcList.size());
        for (Integer rcCol : rcList) {
            objectOutput.writeInt(rcCol.intValue());
        }

        // Only non-null histograms are counted and written so that the count matches the payload.
        int numHists = 0;
        if (this._hist != null) {
            for (HashMap<String, Long> colHist : this._hist.values()) {
                if (colHist != null) {
                    numHists++;
                }
            }
        }
        objectOutput.writeInt(numHists);
        if (numHists > 0) {
            for (Map.Entry<Integer, HashMap<String, Long>> entry : this._hist.entrySet()) {
                if (entry.getValue() == null) {
                    continue;
                }
                objectOutput.writeInt(entry.getKey().intValue());
                objectOutput.writeInt(entry.getValue().size());
                for (Map.Entry<String, Long> valueCount : entry.getValue().entrySet()) {
                    objectOutput.writeUTF(valueCount.getKey());
                    objectOutput.writeLong(valueCount.getValue().longValue());
                }
            }
        }
    }

    /**
     * Restores the encoder state in the order written by writeExternal: per-column
     * method and count, sparse replacements, recode list and histograms. Mean
     * accumulators are not part of the stream and are re-created with the
     * default fill value.
     *
     * @param objectInput stream to read from
     * @throws IOException if reading fails
     */
    @Override
    public void readExternal(ObjectInput objectInput) throws IOException {
        super.readExternal(objectInput);
        final int numCols = _colList.length;
        _mvMethodList = new MVMethod[numCols];
        _countList = new long[numCols];
        _meanList = new KahanObject[numCols];
        _replacementList = new String[numCols];
        for (int pos = 0; pos < _colList.length; pos++) {
            _mvMethodList[pos] = MVMethod.values()[objectInput.readByte()];
            _countList[pos] = objectInput.readLong();
            _meanList[pos] = new KahanObject(DataExpression.DEFAULT_DELIM_FILL_VALUE, DataExpression.DEFAULT_DELIM_FILL_VALUE);
        }

        // Sparse replacements: each pair is (array index, value).
        final int numReplacements = objectInput.readInt();
        for (int k = 0; k < numReplacements; k++) {
            _replacementList[objectInput.readInt()] = objectInput.readUTF();
        }

        final int numRecoded = objectInput.readInt();
        _rcList = new ArrayList();
        for (int k = 0; k < numRecoded; k++) {
            _rcList.add(Integer.valueOf(objectInput.readInt()));
        }

        _hist = new HashMap<>();
        final int numHists = objectInput.readInt();
        for (int k = 0; k < numHists; k++) {
            Integer histKey = Integer.valueOf(objectInput.readInt());
            final int numValues = objectInput.readInt();
            HashMap<String, Long> colHist = new HashMap<>();
            for (int v = 0; v < numValues; v++) {
                colHist.put(objectInput.readUTF(), Long.valueOf(objectInput.readLong()));
            }
            _hist.put(histKey, colHist);
        }
    }
}
