package org.apache.sysds.runtime.transform.tokenize.applier;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.transform.tokenize.DocumentRepresentation;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/**
 * Tokenizer applier that maps every token of a document to a hash bucket in
 * {@code [0, num_features)} and emits how many tokens of the document fell into each bucket.
 *
 * <p>Output layouts:
 * <ul>
 *   <li>wide format: one row per document, the id columns followed by {@code maxTokens} INT64
 *       columns, where column {@code c} holds the count of bucket {@code c} (0 if absent);</li>
 *   <li>long format: one row per distinct bucket of a document (at most {@code maxTokens},
 *       in ascending bucket order), holding the id columns, the 1-based bucket index and the
 *       bucket count, optionally followed by padding rows.</li>
 * </ul>
 */
public class TokenizerApplierHash extends TokenizerApplier {
    private static final long serialVersionUID = 4763889041868044668L;

    /** Number of hash buckets used when the parameters do not specify {@code num_features}. */
    private static final int DEFAULT_NUM_FEATURES = 1048576;

    /** Number of hash buckets; every token is mapped into {@code [0, num_features)}. */
    public int num_features;

    /** Per input document: bucket index -> number of tokens hashed into it (sorted by bucket). */
    private List<Map<Integer, Long>> hashes;

    /**
     * Creates a hash tokenizer applier.
     *
     * @param numIdCols    number of id columns, forwarded to the superclass
     * @param maxTokens    maximum number of tokens, forwarded to the superclass
     * @param wideFormat   whether to emit the wide (one row per document) layout
     * @param applyPadding whether to pad the long layout; wide format is always padded
     * @param params       optional parameters; may be {@code null}. Only {@code num_features} is read.
     * @throws JSONException            if {@code num_features} is present but not an integer
     * @throws IllegalArgumentException if {@code num_features} is not positive
     */
    public TokenizerApplierHash(int numIdCols, int maxTokens, boolean wideFormat, boolean applyPadding,
            JSONObject params) throws JSONException {
        super(numIdCols, maxTokens, wideFormat, applyPadding);
        this.num_features = DEFAULT_NUM_FEATURES;
        if (!applyPadding && wideFormat) {
            LOG.warn("ApplyPadding was set to 'false', Hash Tokenizer with wide format always has padding applied");
        }
        if (params != null && params.has("num_features")) {
            this.num_features = params.getInt("num_features");
        }
        // A zero bucket count would make build() fail with a division by zero, and a negative one
        // would yield negative bucket indices; reject both up front.
        if (this.num_features <= 0) {
            throw new IllegalArgumentException(
                "num_features must be a positive number of hash buckets, but was " + this.num_features);
        }
    }

    /**
     * Returns the number of output rows for the given documents.
     * Wide format emits one row per document; padded long format emits {@code maxTokens} rows per
     * document; unpadded long format emits one row per distinct bucket, capped at {@code maxTokens}.
     * The unpadded long case requires {@link #build} to have been run for every document.
     */
    @Override
    public int getNumRows(DocumentRepresentation[] documentRepresentationArr) {
        if (this.wideFormat) {
            return documentRepresentationArr.length;
        }
        if (this.applyPadding) {
            return this.maxTokens * documentRepresentationArr.length;
        }
        int rows = 0;
        for (Map<Integer, Long> counts : this.hashes) {
            rows += Math.min(counts.size(), this.maxTokens);
        }
        return rows;
    }

    /** Allocates one (initially {@code null}) bucket-count slot per input document. */
    @Override
    public void allocateInternalMeta(int numDocuments) {
        this.hashes = new ArrayList<>(Collections.nCopies(numDocuments, null));
    }

    /**
     * Hashes the tokens of documents {@code [start, endIndex)} and stores, per document,
     * a bucket-index-sorted map of bucket counts.
     */
    @Override
    public void build(DocumentRepresentation[] documentRepresentationArr, int start, int blk) {
        int endIndex = UtilFunctions.getEndIndex(documentRepresentationArr.length, start, blk);
        for (int doc = start; doc < endIndex; doc++) {
            Map<Integer, Long> counts = documentRepresentationArr[doc].tokens.stream()
                .map(token -> hashIndex(token.hashCode()))
                .collect(Collectors.groupingBy(Function.identity(), TreeMap::new, Collectors.counting()));
            this.hashes.set(doc, counts);
        }
    }

    /** Maps a hash code to a bucket in {@code [0, num_features)}, also for negative hash codes. */
    private int hashIndex(int hashCode) {
        return Math.floorMod(hashCode, this.num_features);
    }

    /**
     * Writes the bucket counts of documents {@code [start, endIndex)} into {@code frameBlock}
     * and returns the next free output row.
     */
    @Override
    public int applyInternalRepresentation(DocumentRepresentation[] documentRepresentationArr, FrameBlock frameBlock,
            int start, int blk) {
        int endIndex = UtilFunctions.getEndIndex(documentRepresentationArr.length, start, blk);
        int outputRow = getOutputRow(start, this.hashes);
        for (int doc = start; doc < endIndex; doc++) {
            List<Object> keys = documentRepresentationArr[doc].keys;
            Map<Integer, Long> counts = this.hashes.get(doc);
            outputRow = this.wideFormat
                ? setTokensWide(outputRow, keys, counts, frameBlock)
                : setTokensLong(outputRow, keys, counts, frameBlock);
        }
        return outputRow;
    }

    /**
     * Long layout: one row per bucket (at most {@code maxTokens}), holding the id columns, the
     * 1-based bucket index and its count; then optional padding. Returns the next free row.
     */
    private int setTokensLong(int row, List<Object> keys, Map<Integer, Long> counts, FrameBlock out) {
        int numTokens = 0;
        for (Map.Entry<Integer, Long> entry : counts.entrySet()) {
            if (numTokens >= this.maxTokens) {
                break;
            }
            int numKeys = setKeys(row, keys, out);
            // bucket indices are emitted 1-based
            out.set(row, numKeys, Long.valueOf(entry.getKey() + 1L));
            out.set(row, numKeys + 1, entry.getValue());
            numTokens++;
            row++;
        }
        if (this.applyPadding) {
            row = applyPaddingLong(row, numTokens, keys, out, "", 0L);
        }
        return row;
    }

    /**
     * Wide layout: a single row with the id columns followed by the counts of buckets
     * {@code 0 .. maxTokens-1} (0 for absent buckets). Returns the next free row.
     *
     * <p>NOTE(review): only buckets below {@code maxTokens} are emitted, so when
     * {@code num_features > maxTokens} counts of higher buckets are dropped — confirm intended.
     */
    private int setTokensWide(int row, List<Object> keys, Map<Integer, Long> counts, FrameBlock out) {
        int numKeys = setKeys(row, keys, out);
        for (int bucket = 0; bucket < this.maxTokens; bucket++) {
            out.set(row, numKeys + bucket, counts.getOrDefault(bucket, 0L));
        }
        return row + 1;
    }

    /** Returns the output schema: id columns as STRING followed by INT64 value columns. */
    @Override
    public Types.ValueType[] getOutSchema() {
        return this.wideFormat ? getOutSchemaWide(this.numIdCols, this.maxTokens) : getOutSchemaLong(this.numIdCols);
    }

    /** {@code numIdCols} STRING columns followed by {@code numCountCols} INT64 count columns. */
    private static Types.ValueType[] getOutSchemaWide(int numIdCols, int numCountCols) {
        Types.ValueType[] schema = new Types.ValueType[numIdCols + numCountCols];
        for (int col = 0; col < schema.length; col++) {
            schema[col] = col < numIdCols ? Types.ValueType.STRING : Types.ValueType.INT64;
        }
        return schema;
    }

    /** {@code numIdCols} STRING columns followed by the INT64 bucket-index and count columns. */
    private static Types.ValueType[] getOutSchemaLong(int numIdCols) {
        Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols + 2, Types.ValueType.STRING);
        schema[numIdCols] = Types.ValueType.INT64;
        schema[numIdCols + 1] = Types.ValueType.INT64;
        return schema;
    }
}
