package org.apache.sysds.runtime.transform.tokenize;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/* loaded from: input_file:org/apache/sysds/runtime/transform/tokenize/TokenizerPostHash.class */
/**
 * Post-processing step of the tokenizer that applies the hashing trick: every
 * token's text is hashed into one of {@code num_features} buckets and the
 * number of tokens per bucket is counted for each document.
 *
 * <p>The counts are appended to the output frame either in long format (one
 * row per non-empty bucket) or in wide format (one row per document with one
 * count column per bucket index below {@code maxTokens}).
 */
public class TokenizerPostHash implements TokenizerPost {
    private static final long serialVersionUID = 4763889041868044668L;
    private final Params params;
    private final int numIdCols;
    private final int maxTokens;
    private final boolean wideFormat;

    /** Configuration of the hashing step, read from the JSON specification. */
    static class Params implements Serializable {
        private static final long serialVersionUID = -256069061414241795L;

        /** Number of hash buckets; stays at 1048576 unless the spec overrides it. */
        public int num_features = 1048576;

        /**
         * @param jSONObject specification object, may be {@code null}; only the
         *                   optional key {@code "num_features"} is read
         * @throws JSONException if {@code "num_features"} cannot be read as an int
         */
        public Params(JSONObject jSONObject) throws JSONException {
            if (jSONObject != null && jSONObject.has("num_features")) {
                this.num_features = jSONObject.getInt("num_features");
            }
        }
    }

    /**
     * @param jSONObject specification used to build {@link Params}, may be {@code null}
     * @param numIdCols  number of leading STRING columns in the output schema
     * @param maxTokens  maximum number of rows emitted per document in long format,
     *                   and number of count columns in wide format
     * @param wideFormat {@code true} for wide output, {@code false} for long output
     * @throws JSONException if the specification cannot be parsed
     */
    public TokenizerPostHash(JSONObject jSONObject, int numIdCols, int maxTokens, boolean wideFormat)
            throws JSONException {
        this.params = new Params(jSONObject);
        this.numIdCols = numIdCols;
        this.maxTokens = maxTokens;
        this.wideFormat = wideFormat;
    }

    /**
     * Hashes the tokens of every document into buckets, counts the tokens per
     * bucket and appends the counts to {@code frameBlock}.
     *
     * @param list       documents with their keys and tokens
     * @param frameBlock output frame that receives the appended rows
     * @return the same {@code frameBlock} instance that was passed in
     */
    @Override
    public FrameBlock tokenizePost(List<Tokenizer.DocumentToTokens> list, FrameBlock frameBlock) {
        for (Tokenizer.DocumentToTokens document : list) {
            // Sorted map so that rows are emitted in ascending bucket order.
            Map<Integer, Long> bucketCounts = document.tokens.stream()
                    .collect(Collectors.groupingBy(
                            token -> hashBucket(token.textToken),
                            TreeMap::new,
                            Collectors.counting()));
            if (this.wideFormat) {
                appendTokensWide(document.keys, bucketCounts, frameBlock);
            } else {
                appendTokensLong(document.keys, bucketCounts, frameBlock);
            }
        }
        return frameBlock;
    }

    /**
     * Maps a token text to its bucket index.
     *
     * <p>{@code String.hashCode()} may be negative and the {@code %} operator keeps
     * the sign of the dividend, so {@code Math.floorMod} is used to keep the bucket
     * inside {@code [0, num_features)} for a positive {@code num_features}.
     */
    private int hashBucket(String text) {
        return Math.floorMod(text.hashCode(), this.params.num_features);
    }

    /**
     * Appends one row per bucket: the document keys, the 1-based bucket id and the
     * token count. At most {@code maxTokens} rows are written per document.
     */
    private void appendTokensLong(List<Object> keys, Map<Integer, Long> bucketCounts, FrameBlock frameBlock) {
        int emitted = 0;
        for (Map.Entry<Integer, Long> entry : bucketCounts.entrySet()) {
            if (emitted >= this.maxTokens) {
                return;
            }
            List<Object> row = new ArrayList<>(keys);
            // Bucket ids are written 1-based.
            row.add(Long.valueOf(entry.getKey() + 1L));
            row.add(entry.getValue());
            frameBlock.appendRow(row.toArray());
            emitted++;
        }
    }

    /**
     * Appends a single row: the document keys followed by the counts of the bucket
     * indices {@code 0 .. maxTokens-1}; buckets without tokens are written as 0.
     */
    private void appendTokensWide(List<Object> keys, Map<Integer, Long> bucketCounts, FrameBlock frameBlock) {
        List<Object> row = new ArrayList<>(keys);
        for (int bucket = 0; bucket < this.maxTokens; bucket++) {
            row.add(bucketCounts.getOrDefault(bucket, 0L));
        }
        frameBlock.appendRow(row.toArray());
    }

    /** Returns the output schema matching the configured format. */
    @Override
    public Types.ValueType[] getOutSchema() {
        return this.wideFormat ? getOutSchemaWide(this.numIdCols, this.maxTokens) : getOutSchemaLong(this.numIdCols);
    }

    /** Wide schema: {@code numIdCols} STRING columns followed by {@code maxTokens} INT64 columns. */
    private static Types.ValueType[] getOutSchemaWide(int numIdCols, int maxTokens) {
        Types.ValueType[] schema = new Types.ValueType[numIdCols + maxTokens];
        for (int col = 0; col < numIdCols; col++) {
            schema[col] = Types.ValueType.STRING;
        }
        for (int col = numIdCols; col < schema.length; col++) {
            schema[col] = Types.ValueType.INT64;
        }
        return schema;
    }

    /** Long schema: {@code numIdCols} STRING columns followed by two INT64 columns (bucket id, count). */
    private static Types.ValueType[] getOutSchemaLong(int numIdCols) {
        Types.ValueType[] schema = UtilFunctions.nCopies(numIdCols + 2, Types.ValueType.STRING);
        schema[numIdCols] = Types.ValueType.INT64;
        schema[numIdCols + 1] = Types.ValueType.INT64;
        return schema;
    }

    /**
     * Returns the number of output rows for {@code j} input rows: {@code j} in wide
     * format, {@code j * maxTokens} in long format.
     */
    @Override
    public long getNumRows(long j) {
        return this.wideFormat ? j : j * this.maxTokens;
    }

    /** Returns the number of output columns, i.e. the length of the output schema. */
    @Override
    public long getNumCols() {
        return getOutSchema().length;
    }
}
