package org.apache.sysds.runtime.transform.tokenize;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.transform.tokenize.Tokenizer;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/**
 * Token post-processor that produces the "count" representation.
 *
 * <p>For every document, one row is appended per distinct token, laid out as
 * {@code [key_1, ..., key_k, token, count]}. Here {@code count} is the number of times the token
 * occurs in that document. At most {@code maxTokens} rows are emitted per document.
 *
 * <p>Only the long format is supported. When the wide format was requested,
 * {@link #getOutSchema()} and {@link #getNumCols()} throw {@link IllegalArgumentException}.
 */
public class TokenizerPostCount implements TokenizerPost {
    private static final long serialVersionUID = 6382000606237705019L;

    private final Params params;
    private final int numIdCols;
    private final int maxTokens;
    private final boolean wideFormat;

    /** Options for the count representation, read from an optional JSON object. */
    static class Params implements Serializable {
        private static final long serialVersionUID = 5121697674346781880L;

        /**
         * Controls the order in which distinct tokens are emitted.
         *
         * <p>If {@code true}, tokens are emitted in natural {@link String} order. Otherwise they
         * are emitted in order of first occurrence within the document. Defaults to {@code false}.
         */
        public boolean sort_alpha = false;

        /**
         * Reads the options from a JSON object.
         *
         * @param json parameter object; may be {@code null}, in which case defaults are kept
         * @throws JSONException propagated from {@link JSONObject#getBoolean(String)}
         */
        public Params(JSONObject json) throws JSONException {
            if (json != null && json.has("sort_alpha")) {
                this.sort_alpha = json.getBoolean("sort_alpha");
            }
        }
    }

    /**
     * Creates a count post-processor.
     *
     * @param paramsJson optional JSON parameters (see {@link Params}); may be {@code null}
     * @param numIdCols  number of id (key) columns in the output schema; presumably equal to the
     *                   size of each document's {@code keys} list (TODO confirm with callers)
     * @param maxTokens  maximum number of rows emitted per document; values {@code <= 0} emit none
     * @param wideFormat whether wide output was requested (not supported by this representation)
     * @throws JSONException if the parameters cannot be read
     */
    public TokenizerPostCount(JSONObject paramsJson, int numIdCols, int maxTokens, boolean wideFormat)
            throws JSONException {
        this.params = new Params(paramsJson);
        this.numIdCols = numIdCols;
        this.maxTokens = maxTokens;
        this.wideFormat = wideFormat;
    }

    /**
     * Appends the count representation of every document to {@code out}.
     *
     * @param docs documents, each with its key values and token list
     * @param out  frame block that rows are appended to
     * @return {@code out}
     */
    @Override
    public FrameBlock tokenizePost(List<Tokenizer.DocumentToTokens> docs, FrameBlock out) {
        for (Tokenizer.DocumentToTokens doc : docs) {
            // Count all tokens in one pass. The LinkedHashMap keeps keys in first-occurrence
            // order, which is the emission order when sort_alpha is off.
            Map<String, Long> counts = doc.tokens.stream()
                .collect(Collectors.groupingBy(
                    token -> token.textToken, LinkedHashMap::new, Collectors.counting()));

            List<String> distinctTokens = new ArrayList<>(counts.keySet());
            if (params.sort_alpha) {
                distinctTokens.sort(null); // null comparator = natural String ordering
            }

            int emitted = 0;
            for (String token : distinctTokens) {
                // Check before appending, so that a non-positive maxTokens emits nothing.
                if (emitted >= maxTokens) {
                    break;
                }
                appendTokenRow(out, doc.keys, token, counts.get(token));
                emitted++;
            }
        }
        return out;
    }

    /** Appends the single row {@code [keys..., token, count]} to {@code out}. */
    private static void appendTokenRow(FrameBlock out, List<Object> keys, String token, long count) {
        List<Object> row = new ArrayList<>(keys.size() + 2);
        row.addAll(keys);
        row.add(token);
        row.add(count); // boxed to Long, matching the INT64 column in getOutSchema()
        out.appendRow(row.toArray());
    }

    /**
     * Returns the output schema: {@code numIdCols} STRING id columns, a STRING token column and an
     * INT64 count column.
     *
     * @throws IllegalArgumentException if the wide format was requested
     */
    @Override
    public Types.ValueType[] getOutSchema() {
        if (this.wideFormat) {
            throw new IllegalArgumentException("Wide Format is not supported for Count Representation.");
        }
        Types.ValueType[] schema = UtilFunctions.nCopies(this.numIdCols + 2, Types.ValueType.STRING);
        schema[this.numIdCols + 1] = Types.ValueType.INT64;
        return schema;
    }

    /**
     * Returns the number of output rows for the given input size.
     *
     * <p>In long format this is {@code inRows * maxTokens}, an upper bound, because a document
     * with fewer distinct tokens emits fewer rows. {@code inRows} is presumably the number of
     * input documents. NOTE(review): the wide-format branch returns {@code inRows} even though
     * the wide format is rejected by {@link #getOutSchema()}.
     */
    @Override
    public long getNumRows(long inRows) {
        return this.wideFormat ? inRows : inRows * this.maxTokens;
    }

    /**
     * Returns the number of output columns, which is the length of {@link #getOutSchema()}.
     *
     * @throws IllegalArgumentException if the wide format was requested
     */
    @Override
    public long getNumCols() {
        return getOutSchema().length;
    }
}
