package water.rapids.ast.prims.advmath;

import hex.tfidf.DocumentFrequencyTask;
import hex.tfidf.InverseDocumentFrequencyTask;
import hex.tfidf.TermFrequencyTask;
import hex.tfidf.TfIdfPreprocessorTask;
import org.apache.log4j.Logger;
import water.Key;
import water.MRTask;
import water.Scope;
import water.fvec.CategoricalWrappedVec;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.Merge;
import water.rapids.Rapids;
import water.rapids.Val;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.ast.prims.string.AstToLower;
import water.rapids.vals.ValFrame;
import water.util.ArrayUtils;

/**
 * Rapids primitive {@code tf-idf}: computes TF-IDF values for the words of a document collection.
 *
 * <p>Arguments (after the primitive name itself): the input frame, the index of the numeric
 * document-ID column, the index of the string content column, a {@code preprocess} flag and a
 * {@code case_sensitive} flag. With {@code preprocess} set, the content column is expected to hold
 * whole document contents, which {@link TfIdfPreprocessorTask} turns into a (DocID, Words) frame;
 * otherwise each row is expected to already hold a single word. When {@code case_sensitive} is
 * false the content column is lower-cased first.
 *
 * <p>The result is the merge of the term-frequency frame (from {@link TermFrequencyTask}) with the
 * IDF frame (derived from {@link DocumentFrequencyTask} output), extended by a {@code TF-IDF}
 * column holding the row-wise product of the TF and IDF columns.
 */
public class AstTfIdf extends AstPrimitive<AstTfIdf> {
    private static final String IDF_COL_NAME = "IDF";
    private static final String TF_IDF_COL_NAME = "TF-IDF";
    private static final String[] PREPROCESSED_FRAME_COL_NAMES = {"DocID", "Words"};
    private static final Logger LOG = Logger.getLogger(AstTfIdf.class);

    /**
     * Multiplies, row by row, a term-frequency column by an IDF column and emits the product as a
     * single new numeric column.
     */
    private static class TfIdfTask extends MRTask<TfIdfTask> {
        /** Index of the term-frequency column; its values are read as integers via {@code at8}. */
        private final int _tfColIndex;
        /** Index of the inverse-document-frequency column. */
        private final int _idfColIndex;

        private TfIdfTask(int tfColIndex, int idfColIndex) {
            _tfColIndex = tfColIndex;
            _idfColIndex = idfColIndex;
        }

        @Override // water.MRTask
        public void map(Chunk[] cs, NewChunk nc) {
            Chunk tf = cs[_tfColIndex];
            Chunk idf = cs[_idfColIndex];
            for (int row = 0; row < tf._len; row++) {
                nc.addNum(tf.at8(row) * idf.atd(row));
            }
        }
    }

    /** Number of AST arguments, including the primitive name itself (1 + {@code args().length}). */
    @Override // water.rapids.ast.AstPrimitive
    public int nargs() {
        return 6;
    }

    @Override // water.rapids.ast.AstPrimitive
    public String[] args() {
        return new String[]{"frame", "doc_id_idx", "text_idx", "preprocess", "case_sensitive"};
    }

    /**
     * Computes the TF-IDF frame.
     *
     * @param env  evaluation environment used to evaluate the arguments
     * @param stk  stack helper; the input frame is tracked on it
     * @param asts the primitive name followed by the five arguments listed in {@link #args()}
     * @return a {@link ValFrame} wrapping the newly keyed TF-IDF frame
     * @throws IllegalArgumentException if the input frame is empty, a column index is out of
     *                                  range, or the columns have unexpected types
     */
    @Override // water.rapids.ast.AstPrimitive
    public Val apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame inputFrame = stk.track(asts[1].exec(env).getFrame());
        int docIdIdx = (int) asts[2].exec(env).getNum();
        int contentIdx = (int) asts[3].exec(env).getNum();
        boolean preprocess = asts[4].exec(env).getBool();
        boolean caseSensitive = asts[5].exec(env).getBool();

        // Check numCols() first: anyVec() has no vec to return for a frame without columns.
        if (inputFrame.numCols() == 0 || inputFrame.anyVec().length() <= 0) {
            throw new IllegalArgumentException("Empty input frame provided.");
        }

        Scope.enter();
        Frame tfIdfFrame = null;
        try {
            checkInputColumns(inputFrame, docIdIdx, contentIdx, preprocess);

            if (!caseSensitive) {
                // NOTE(review): the original content vec returned by replace() is handed to Scope,
                // i.e. scheduled for cleanup on Scope.exit. Confirm inputFrame is a per-call copy
                // and that this cannot remove the caller's column.
                Scope.track(inputFrame.replace(contentIdx,
                        AstToLower.toLowerStringCol(inputFrame.vec(contentIdx))));
            }

            Frame wordFrame;
            long documentsCnt;
            if (preprocess) {
                wordFrame = new TfIdfPreprocessorTask(docIdIdx, contentIdx)
                        .doAll(new byte[]{Vec.T_NUM, Vec.T_STR}, inputFrame)
                        .outputFrame(PREPROCESSED_FRAME_COL_NAMES, (String[][]) null);
                documentsCnt = inputFrame.numRows();
            } else {
                wordFrame = inputFrame.subframe(
                        ArrayUtils.select(inputFrame.names(), new int[]{docIdIdx, contentIdx}));
                documentsCnt = countDistinctDocuments(asts[1], docIdIdx);
            }
            // NOTE(review): in the non-preprocess branch wordFrame shares its vecs with inputFrame;
            // confirm that tracking it here cannot remove the caller's data on Scope.exit.
            Scope.track(wordFrame);

            Frame termFreq = TermFrequencyTask.compute(wordFrame);
            Scope.track(termFreq);
            Frame docFreq = DocumentFrequencyTask.compute(termFreq);
            Scope.track(docFreq);

            // Replace the last column of the document-frequency frame by its IDF transform.
            Vec idfVec = new InverseDocumentFrequencyTask(documentsCnt)
                    .doAll(new byte[]{Vec.T_NUM}, docFreq.lastVec())
                    .outputFrame()
                    .anyVec();
            Scope.track(idfVec);
            Scope.track(docFreq.remove(docFreq.numCols() - 1));
            docFreq.add(IDF_COL_NAME, idfVec);

            // Join on categorical versions of TF column 1 and IDF column 0 (presumably the word
            // columns); computeMap supplies the mapping between the two categorical domains.
            Scope.track(termFreq.replace(1, termFreq.vecs()[1].toCategoricalVec()));
            Scope.track(docFreq.replace(0, docFreq.vecs()[0].toCategoricalVec()));
            int[][] idMaps = new int[][]{
                    CategoricalWrappedVec.computeMap(termFreq.vec(1).domain(), docFreq.vec(0).domain())
            };
            Frame merged = Merge.merge(termFreq, docFreq, new int[]{1}, new int[]{0}, false, idMaps);
            Scope.track(merged.replace(1, merged.vecs()[1].toStringVec()));

            // The TF and IDF columns are the last two columns of the merged frame.
            int mergedCols = merged.numCols();
            Vec tfIdfVec = new TfIdfTask(mergedCols - 2, mergedCols - 1)
                    .doAll(new byte[]{Vec.T_NUM}, merged)
                    .outputFrame()
                    .anyVec();
            Scope.track(tfIdfVec);
            merged.add(TF_IDF_COL_NAME, tfIdfVec);
            merged._key = Key.make();

            if (LOG.isDebugEnabled()) {
                LOG.debug(merged.toTwoDimTable().toString());
            }
            // Assigned only once the frame is complete, so a failure above keeps nothing.
            tfIdfFrame = merged;
            return new ValFrame(tfIdfFrame);
        } finally {
            // Keep only the keys of the result frame; all other tracked temporaries are dropped.
            Scope.exit(tfIdfFrame != null ? tfIdfFrame.keys() : new Key[0]);
        }
    }

    /**
     * Validates the column arguments of {@code tf-idf}.
     *
     * @param fr         input frame
     * @param docIdIdx   index of the document-ID column, which must be numeric
     * @param contentIdx index of the content column, which must be a string column
     * @param preprocess only used to phrase the expected content in the error message
     * @throws IllegalArgumentException if an index is negative or not below the column count, or if
     *                                  a column has the wrong type
     */
    private static void checkInputColumns(Frame fr, int docIdIdx, int contentIdx, boolean preprocess) {
        int numCols = fr.numCols();
        if (docIdIdx < 0 || docIdIdx >= numCols || contentIdx < 0 || contentIdx >= numCols) {
            throw new IllegalArgumentException(
                    "Provided column index is out of bounds. Number of columns in the input frame: " + numCols);
        }
        Vec docIdVec = fr.vec(docIdIdx);
        Vec contentVec = fr.vec(contentIdx);
        if (!docIdVec.isNumeric() || !contentVec.isString()) {
            throw new IllegalArgumentException("Incorrect format of input frame. "
                    + "Following row format is expected: (numeric) documentID, (string) "
                    + (preprocess ? "documentContent." : "words.")
                    + " Got " + docIdVec.get_type_str() + " and " + contentVec.get_type_str() + " instead.");
        }
    }

    /**
     * Counts the distinct document IDs by evaluating a Rapids {@code unique} expression over the
     * document-ID column. The temporary result frame is registered with {@link Scope} (the scope
     * opened by {@code apply}) so it is cleaned up together with the other temporaries.
     *
     * @param frameAst AST of the frame argument
     * @param docIdIdx index of the document-ID column
     * @return number of distinct document IDs
     */
    private static long countDistinctDocuments(AstRoot frameAst, int docIdIdx) {
        // NOTE(review): this re-evaluates the frame argument from its textual form; it assumes
        // frameAst.toString() yields a re-parseable Rapids expression (e.g. a frame id) — confirm.
        String uniqueRapid = "(unique (cols " + frameAst.toString() + " [" + docIdIdx + "]) false)";
        Frame uniqueDocIds = Rapids.exec(uniqueRapid).getFrame();
        Scope.track(uniqueDocIds);
        return uniqueDocIds.anyVec().length();
    }

    @Override // water.rapids.ast.AstRoot
    public String str() {
        return "tf-idf";
    }
}
