/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.models;

import hex.Model;
import java.util.Arrays;
import java.util.Locale;
import water.Key;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.PermutationVarImp;
import water.rapids.Val;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/**
 * Rapids primitive that computes permutation variable importance of a model on a frame.
 *
 * <p>Arguments, in order: {@code model}, {@code frame}, {@code metric}, {@code n_samples},
 * {@code n_repeats}, {@code features}, {@code seed}. {@link #nargs()} is one larger than
 * {@link #args()}. Presumably {@code asts[0]} is the primitive itself, since {@link #apply}
 * reads its arguments from {@code asts[1]} through {@code asts[7]}.
 */
public class AstPermutationVarImp
extends AstPrimitive {
    @Override
    public int nargs() {
        return 8;
    }

    @Override
    public String[] args() {
        return new String[]{"model", "frame", "metric", "n_samples", "n_repeats", "features", "seed"};
    }

    @Override
    public String str() {
        return "PermutationVarImp";
    }

    /**
     * Evaluates and validates the arguments, then computes permutation variable importance.
     * A single permutation pass is used when {@code n_repeats == 1}; otherwise the repeated variant is used.
     *
     * @param env  evaluation environment
     * @param stk  stack helper used to track the evaluated argument values
     * @param asts argument ASTs; {@code asts[1..7]} correspond to {@link #args()} in order
     * @return a frame with a leading {@code "Variable"} column followed by one numeric column
     *         per column of the variable-importance table
     * @throws IllegalArgumentException if {@code n_samples} or {@code n_repeats} is out of range,
     *         or if a requested feature is absent from the frame or was not used for training
     */
    @Override
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Model model = stk.track(asts[1].exec(env)).getModel();
        Frame fr = stk.track(asts[2].exec(env)).getFrame();
        // Locale.ROOT: default-locale lowercasing (e.g. Turkish dotless i) would corrupt metric names.
        String metric = stk.track(asts[3].exec(env)).getStr().toLowerCase(Locale.ROOT);
        long n_samples = (long) stk.track(asts[4].exec(env)).getNum();
        int n_repeats = (int) stk.track(asts[5].exec(env)).getNum();
        Val featuresVal = stk.track(asts[6].exec(env));
        // An empty value means "no restriction"; null is forwarded to PermutationVarImp as such.
        String[] features = featuresVal.isEmpty() ? null : featuresVal.getStrs();
        long seed = (long) stk.track(asts[7].exec(env)).getNum();

        validateArguments(model, fr, n_samples, n_repeats, features);

        Scope.enter();
        Frame pviFr = null;
        try {
            PermutationVarImp pvi = new PermutationVarImp(model, fr);
            TwoDimTable varImpTable = n_repeats > 1
                    ? pvi.getRepeatedPermutationVarImp(metric, n_samples, n_repeats, features, seed)
                    : pvi.getPermutationVarImp(metric, n_samples, features, seed);
            pviFr = varimpToFrame(varImpTable, Key.make(model._key + "permutationVarImp"));
            Scope.track(pviFr);
        } finally {
            // Keys passed to Scope.exit are kept; only the result frame's keys (if any) are retained.
            Scope.exit(pviFr != null ? pviFr.keys() : new Key[]{});
        }
        return new ValFrame(pviFr);
    }

    /**
     * Validates the scalar arguments and the optional feature list before any computation starts.
     *
     * @param model     model whose training column names are checked against {@code features}
     * @param fr        frame the importance is computed on
     * @param n_samples -1 for the whole frame, otherwise a row count in {@code [2, fr.numRows()]}
     * @param n_repeats number of permutation repeats; must be at least 1
     * @param features  requested features, or {@code null} for no restriction
     * @throws IllegalArgumentException if any argument is invalid
     */
    private static void validateArguments(Model model, Frame fr, long n_samples, int n_repeats, String[] features) {
        if (n_samples < -1L || n_samples == 0L || n_samples == 1L || n_samples > fr.numRows()) {
            throw new IllegalArgumentException("Argument n_samples has to be either -1 to use the whole frame or greater than or equal to 2 and lower than or equal to the number of rows of the provided frame!");
        }
        if (n_repeats < 1) {
            throw new IllegalArgumentException("Argument n_repeats must be greater than 0!");
        }
        if (features == null) {
            return;
        }
        String[] frameNames = fr.names();
        String[] notInFrame = Arrays.stream(features)
                .filter(f -> !ArrayUtils.contains(frameNames, f))
                .toArray(String[]::new);
        if (notInFrame.length > 0) {
            throw new IllegalArgumentException("Features " + String.join(", ", notInFrame) + " are not present in the provided frame!");
        }
        // Prefer the original (pre-expansion) training names when the model recorded them.
        Model.Output output = (Model.Output) model._output;
        String[] modelNames = output._origNames == null ? output._names : output._origNames;
        String[] notUsedInModel = Arrays.stream(features)
                .filter(f -> !ArrayUtils.contains(modelNames, f))
                .toArray(String[]::new);
        if (notUsedInModel.length > 0) {
            throw new IllegalArgumentException("Features " + String.join(", ", notUsedInModel) + " weren't used for training!");
        }
    }

    /**
     * Converts a variable-importance table into a frame.
     *
     * @param twoDimTable table whose row headers become the {@code "Variable"} column; every cell
     *                    is expected to hold a non-null {@code Double}
     * @param frameKey    key assigned to the resulting frame
     * @return frame with a leading {@code "Variable"} column followed by one numeric column per
     *         table column, named after the table's column headers
     */
    private static Frame varimpToFrame(TwoDimTable twoDimTable, Key frameKey) {
        int nRows = twoDimTable.getRowDim();
        int nCols = twoDimTable.getColDim();
        String[] colNames = new String[nCols + 1];
        colNames[0] = "Variable";
        System.arraycopy(twoDimTable.getColHeaders(), 0, colNames, 1, nCols);
        Vec[] vecs = new Vec[nCols + 1];
        vecs[0] = Vec.makeVec(twoDimTable.getRowHeaders(), Vec.newKey());
        for (int j = 0; j < nCols; ++j) {
            // Fresh buffer per column, so correctness does not depend on whether makeVec copies its input.
            double[] column = new double[nRows];
            for (int i = 0; i < nRows; ++i) {
                column[i] = (Double) twoDimTable.get(i, j);
            }
            vecs[j + 1] = Vec.makeVec(column, Vec.newKey());
        }
        return new Frame(frameKey, colNames, vecs);
    }
}

