/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.advmath;

import hex.tfidf.DocumentFrequencyTask;
import hex.tfidf.InverseDocumentFrequencyTask;
import hex.tfidf.TermFrequencyTask;
import hex.tfidf.TfIdfPreprocessorTask;
import org.apache.log4j.Logger;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.CategoricalWrappedVec;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Merge;
import water.rapids.Rapids;
import water.rapids.Val;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.ast.prims.string.AstToLower;
import water.rapids.vals.ValFrame;
import water.util.ArrayUtils;

/**
 * Rapids primitive {@code tf-idf}.
 *
 * <p>Expected call shape: {@code (tf-idf frame doc_id_idx text_idx preprocess case_sensitive)}.
 * The primitive takes a frame holding a numeric document-ID column and a string column,
 * runs the term-frequency, document-frequency and inverse-document-frequency tasks from
 * {@code hex.tfidf} on it, merges their outputs and appends a {@code "TF-IDF"} column that
 * is the per-row product of the merged frame's last two columns.
 */
public class AstTfIdf
extends AstPrimitive<AstTfIdf> {
    /** Name of the column holding inverse document frequency values. */
    private static final String IDF_COL_NAME = "IDF";
    /** Name of the column holding the final TF-IDF values. */
    private static final String TF_IDF_COL_NAME = "TF-IDF";
    /** Column names of the frame produced by the pre-processing task. */
    private static final String[] PREPROCESSED_FRAME_COL_NAMES = new String[]{"DocID", "Words"};
    private static final Logger log = Logger.getLogger(AstTfIdf.class);

    /**
     * @return number of AST slots used by this primitive; {@link #apply} reads
     *         {@code asts[1]} through {@code asts[5]}.
     */
    @Override
    public int nargs() {
        return 6;
    }

    /**
     * @return names of the five user-facing arguments, in call order.
     */
    @Override
    public String[] args() {
        return new String[]{"frame", "doc_id_idx", "text_idx", "preprocess", "case_sensitive"};
    }

    /**
     * Executes the TF-IDF computation.
     *
     * @param env  Rapids execution environment used to evaluate the argument ASTs
     * @param stk  stack helper used to track the input frame
     * @param asts argument ASTs: {@code [1]} input frame, {@code [2]} document-ID column index,
     *             {@code [3]} text column index, {@code [4]} whether to pre-process the text
     *             column, {@code [5]} whether the computation is case sensitive
     * @return the merged frame with the {@code "TF-IDF"} column appended, wrapped in a {@link ValFrame}
     * @throws IllegalArgumentException if the input frame is empty, a column index is out of
     *                                  bounds, or the selected columns are not (numeric, string)
     */
    @Override
    public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame inputFrame = stk.track(asts[1].exec(env).getFrame());
        int docIdIdx = (int)asts[2].exec(env).getNum();
        int contentIdx = (int)asts[3].exec(env).getNum();
        boolean preprocess = asts[4].exec(env).getBool();
        boolean caseSensitive = asts[5].exec(env).getBool();

        // Guard against a frame without any vec as well as a frame without rows.
        Vec firstVec = inputFrame.anyVec();
        if (firstVec == null || firstVec.length() <= 0L) {
            throw new IllegalArgumentException("Empty input frame provided.");
        }

        Scope.enter();
        Frame tfIdfFrame = null;
        try {
            int inputFrameColsCnt = inputFrame.numCols();
            // Negative indices are rejected here as well so that the caller gets a
            // descriptive error instead of a failure inside Frame.vec().
            if (docIdIdx < 0 || contentIdx < 0
                    || docIdIdx >= inputFrameColsCnt || contentIdx >= inputFrameColsCnt) {
                throw new IllegalArgumentException("Provided column index is out of bounds. Number of columns in the input frame: " + inputFrameColsCnt);
            }

            Vec docIdVec = inputFrame.vec(docIdIdx);
            Vec contentVec = inputFrame.vec(contentIdx);
            if (!docIdVec.isNumeric() || !contentVec.isString()) {
                // Always report the actual column types, regardless of the preprocess flag.
                throw new IllegalArgumentException("Incorrect format of input frame. "
                        + "Following row format is expected: (numeric) documentID, (string) "
                        + (preprocess ? "documentContent" : "words")
                        + ". Got " + docIdVec.get_type_str() + " and " + contentVec.get_type_str() + " instead.");
            }

            if (!caseSensitive) {
                // NOTE(review): this swaps the lower-cased column into the input frame object
                // itself and tracks the replaced vec for removal on scope exit - confirm that
                // callers do not rely on the original column staying intact.
                Scope.track(inputFrame.replace(contentIdx, AstToLower.toLowerStringCol(inputFrame.vec(contentIdx))));
            }

            Frame wordFrame;
            long documentsCnt;
            if (preprocess) {
                byte[] outputTypes = new byte[]{Vec.T_NUM, Vec.T_STR};
                wordFrame = ((TfIdfPreprocessorTask)new TfIdfPreprocessorTask(docIdIdx, contentIdx).doAll(outputTypes, inputFrame)).outputFrame(PREPROCESSED_FRAME_COL_NAMES, null);
                // With pre-processing the row count of the input is used as the document count.
                documentsCnt = inputFrame.numRows();
            } else {
                String[] columnsNames = ArrayUtils.select(inputFrame.names(), new int[]{docIdIdx, contentIdx});
                wordFrame = inputFrame.subframe(columnsNames);
                // Without pre-processing the document count is the number of unique document IDs.
                String countDocumentsRapid = "(unique (cols " + asts[1].toString() + " [" + docIdIdx + "]))";
                Frame uniqueDocIds = Rapids.exec(countDocumentsRapid).getFrame();
                // The helper frame is only needed for its length; track it so it is removed on scope exit.
                Scope.track(uniqueDocIds);
                documentsCnt = uniqueDocIds.anyVec().length();
            }
            Scope.track(wordFrame);

            Frame tfOutFrame = TermFrequencyTask.compute(wordFrame);
            Scope.track(tfOutFrame);
            Frame dfOutFrame = DocumentFrequencyTask.compute(tfOutFrame);
            Scope.track(dfOutFrame);

            // Replace the last column of the DF frame with the IDF values computed from it.
            InverseDocumentFrequencyTask idf = new InverseDocumentFrequencyTask(documentsCnt);
            Vec idfValues = ((InverseDocumentFrequencyTask)idf.doAll(new byte[]{Vec.T_NUM}, dfOutFrame.lastVec())).outputFrame().anyVec();
            Scope.track(idfValues);
            Vec removedCol = dfOutFrame.remove(dfOutFrame.numCols() - 1);
            Scope.track(removedCol);
            dfOutFrame.add(IDF_COL_NAME, idfValues);

            // The join columns are converted to categoricals before the merge and the merged
            // column is converted back to strings afterwards.
            Scope.track(tfOutFrame.replace(1, tfOutFrame.vecs()[1].toCategoricalVec()));
            Scope.track(dfOutFrame.replace(0, dfOutFrame.vecs()[0].toCategoricalVec()));
            int[][] levelMaps = new int[][]{CategoricalWrappedVec.computeMap(tfOutFrame.vec(1).domain(), dfOutFrame.vec(0).domain())};
            Frame tfIdfIntermediate = Merge.merge(tfOutFrame, dfOutFrame, new int[]{1}, new int[]{0}, false, levelMaps);
            Scope.track(tfIdfIntermediate.replace(1, tfIdfIntermediate.vecs()[1].toStringVec()));

            // The last two columns of the merged frame are passed to the task as TF and IDF.
            int mergedColCnt = tfIdfIntermediate.numCols();
            TfIdfTask tfIdfTask = new TfIdfTask(mergedColCnt - 2, mergedColCnt - 1);
            Vec tfIdfValues = ((TfIdfTask)tfIdfTask.doAll(new byte[]{Vec.T_NUM}, tfIdfIntermediate)).outputFrame().anyVec();
            Scope.track(tfIdfValues);
            tfIdfIntermediate.add(TF_IDF_COL_NAME, tfIdfValues);
            tfIdfIntermediate._key = Key.make();

            if (log.isDebugEnabled()) {
                log.debug((Object)tfIdfIntermediate.toTwoDimTable().toString());
            }
            // Assigned last: only a fully built result has its keys kept on scope exit.
            tfIdfFrame = tfIdfIntermediate;
        } finally {
            // Keep only the keys of the result frame; everything else tracked in this scope is released.
            Key[] keysToKeep = tfIdfFrame != null ? tfIdfFrame.keys() : new Key[]{};
            Scope.exit(keysToKeep);
        }
        return new ValFrame(tfIdfFrame);
    }

    /**
     * @return the Rapids name of this primitive.
     */
    @Override
    public String str() {
        return "tf-idf";
    }

    /**
     * Map task producing one output value per row: the product of the value in the
     * TF column (read as a long) and the value in the IDF column (read as a double).
     */
    private static class TfIdfTask
    extends MRTask<TfIdfTask> {
        /** Index of the term-frequency column within the processed chunks. */
        private final int _tfColIndex;
        /** Index of the inverse-document-frequency column within the processed chunks. */
        private final int _idfColIndex;

        private TfIdfTask(int tfColIndex, int idfColIndex) {
            this._tfColIndex = tfColIndex;
            this._idfColIndex = idfColIndex;
        }

        /**
         * Appends {@code tf * idf} to {@code nc} for every row of the chunk.
         *
         * @param cs input chunks, indexed by column
         * @param nc output chunk receiving the products
         */
        @Override
        public void map(Chunk[] cs, NewChunk nc) {
            Chunk tfValues = cs[this._tfColIndex];
            Chunk idfValues = cs[this._idfColIndex];
            for (int row = 0; row < tfValues._len; ++row) {
                nc.addNum((double)tfValues.at8(row) * idfValues.atd(row));
            }
        }
    }
}

