package water.rapids.ast.prims.models;

import hex.Model;
import java.util.Arrays;
import java.util.Locale;
import water.Key;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.PermutationVarImp;
import water.rapids.Val;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.vals.ValFrame;
import water.util.ArrayUtils;
import water.util.TwoDimTable;

/* loaded from: input_file:water/rapids/ast/prims/models/AstPermutationVarImp.class */
/**
 * Rapids primitive {@code PermutationVarImp}: computes permutation variable importance of a model
 * on a frame and returns the importance table as a new {@link Frame}.
 *
 * <p>Arguments, in order: {@code model}, {@code frame}, {@code metric}, {@code n_samples},
 * {@code n_repeats}, {@code features}, {@code seed}. The resulting frame has a first column
 * {@code "Variable"} built from the importance table's row headers, followed by one numeric column
 * per column of that table.
 */
public class AstPermutationVarImp extends AstPrimitive {
    @Override // water.rapids.ast.AstPrimitive
    public int nargs() {
        // One more than args().length: apply() reads its arguments from astRootArr[1..7].
        return 8;
    }

    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        return new String[]{"model", "frame", "metric", "n_samples", "n_repeats", "features", "seed"};
    }

    @Override // water.rapids.ast.AstRoot
    public String str() {
        return "PermutationVarImp";
    }

    /**
     * Evaluates the arguments, validates them, and computes permutation variable importance.
     *
     * <p>{@code metric} is lower-cased before use. {@code n_samples} must be {@code -1} (whole frame)
     * or between 2 and the frame's row count. {@code n_repeats} must be at least 1. When it is greater
     * than 1, the repeated variant of the computation is used. An empty {@code features} value means
     * no feature restriction and is passed on as {@code null}.
     *
     * @return a {@link ValFrame} wrapping the importance frame
     * @throws IllegalArgumentException if {@code n_samples} or {@code n_repeats} is out of range, or
     *     if a requested feature is missing from the frame or was not used for training
     */
    @Override // water.rapids.ast.AstPrimitive
    public ValFrame apply(Env env, Env.StackHelp stackHelp, AstRoot[] astRootArr) {
        Model model = stackHelp.track(astRootArr[1].exec(env)).getModel();
        Frame frame = stackHelp.track(astRootArr[2].exec(env)).getFrame();
        // Locale.ROOT: lower-casing must not depend on the JVM default locale (e.g. Turkish dotless i).
        String metric = stackHelp.track(astRootArr[3].exec(env)).getStr().toLowerCase(Locale.ROOT);
        long nSamples = (long) stackHelp.track(astRootArr[4].exec(env)).getNum();
        int nRepeats = (int) stackHelp.track(astRootArr[5].exec(env)).getNum();
        Val featuresVal = stackHelp.track(astRootArr[6].exec(env));
        String[] features = featuresVal.isEmpty() ? null : featuresVal.getStrs();
        long seed = (long) stackHelp.track(astRootArr[7].exec(env)).getNum();

        validateArguments(model, frame, nSamples, nRepeats, features);

        Scope.enter();
        Frame varImpFrame = null;
        try {
            PermutationVarImp permutationVarImp = new PermutationVarImp(model, frame);
            TwoDimTable varImpTable = nRepeats > 1
                    ? permutationVarImp.getRepeatedPermutationVarImp(metric, nSamples, nRepeats, features, seed)
                    : permutationVarImp.getPermutationVarImp(metric, nSamples, features, seed);
            varImpFrame = varimpToFrame(varImpTable, Key.make(model._key + "permutationVarImp"));
            Scope.track(varImpFrame);
        } finally {
            // Exit the scope exactly once on every path, keeping the result frame's keys alive.
            Scope.exit(varImpFrame != null ? varImpFrame.keys() : new Key[0]);
        }
        return new ValFrame(varImpFrame);
    }

    /**
     * Checks the user-supplied arguments before any computation starts.
     *
     * @param features requested feature names, or {@code null} for no restriction
     * @throws IllegalArgumentException describing the first invalid argument found
     */
    private static void validateArguments(Model model, Frame frame, long nSamples, int nRepeats,
                                          String[] features) {
        // Valid n_samples: exactly -1, or 2..numRows inclusive.
        if (nSamples < -1 || nSamples == 0 || nSamples == 1 || nSamples > frame.numRows()) {
            throw new IllegalArgumentException("Argument n_samples has to be either -1 to use the whole frame or greater than or equal to 2 and lower than or equal to the number of rows of the provided frame!");
        }
        if (nRepeats < 1) {
            throw new IllegalArgumentException("Argument n_repeats must be greater than 0!");
        }
        if (features == null) {
            return;
        }
        String[] notInFrame = missingFrom(features, frame.names());
        if (notInFrame.length > 0) {
            throw new IllegalArgumentException("Features " + String.join(", ", notInFrame) + " are not present in the provided frame!");
        }
        // Prefer the original (pre-expansion) training column names when the model recorded them.
        String[] trainingNames = model._output._origNames == null ? model._output._names : model._output._origNames;
        String[] notUsedForTraining = missingFrom(features, trainingNames);
        if (notUsedForTraining.length > 0) {
            throw new IllegalArgumentException("Features " + String.join(", ", notUsedForTraining) + " weren't used for training!");
        }
    }

    /** Returns the entries of {@code wanted} that do not occur in {@code available}, in their original order. */
    private static String[] missingFrom(String[] wanted, String[] available) {
        return Arrays.stream(wanted)
                .filter(name -> !ArrayUtils.contains(available, name))
                .toArray(String[]::new);
    }

    /**
     * Converts an importance table into a frame with a leading {@code "Variable"} column built from
     * the table's row headers, followed by one numeric column per table column.
     *
     * <p>Every table cell is assumed to hold a numeric value. A {@code null} cell raises a
     * {@link NullPointerException}.
     */
    private static Frame varimpToFrame(TwoDimTable twoDimTable, Key key) {
        int colCount = twoDimTable.getColDim();
        int rowCount = twoDimTable.getRowDim();

        String[] names = new String[colCount + 1];
        names[0] = "Variable";
        System.arraycopy(twoDimTable.getColHeaders(), 0, names, 1, colCount);

        Vec[] vecs = new Vec[names.length];
        vecs[0] = Vec.makeVec(twoDimTable.getRowHeaders(), Vec.newKey());
        for (int col = 0; col < colCount; col++) {
            // Fresh buffer per column so no Vec can observe a later column's values.
            double[] values = new double[rowCount];
            for (int row = 0; row < rowCount; row++) {
                values[row] = ((Number) twoDimTable.get(row, col)).doubleValue();
            }
            vecs[col + 1] = Vec.makeVec(values, Vec.newKey());
        }
        return new Frame(key, names, vecs);
    }
}
