package ai.h2o.targetencoding;

import ai.h2o.targetencoding.TargetEncoderModel;
import ai.h2o.targetencoding.interaction.InteractionSupport;
import hex.ModelBuilder;
import hex.ModelCategory;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import water.DKV;
import water.Key;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Frame;
import water.fvec.Vec;
import water.logging.Logger;
import water.logging.LoggerFactory;
import water.util.IcedHashMap;

/* loaded from: input_file:ai/h2o/targetencoding/TargetEncoder.class */
public class TargetEncoder extends ModelBuilder<TargetEncoderModel, TargetEncoderModel.TargetEncoderParameters, TargetEncoderModel.TargetEncoderOutput> {
    /** Class logger; assigned in the static initializer at the bottom of this class. */
    private static final Logger logger;
    /** Model being built; reset to {@code null} and then assigned by the driver's {@code computeImpl()}. */
    private TargetEncoderModel _targetEncoderModel;
    /**
     * Groups of column names to encode (one {@code String[]} per group), resolved by {@code init(true)}
     * either from the {@code _columns_to_encode} parameter or from the categorical columns of the training frame.
     */
    private String[][] _columnsToEncode;
    /** Assertion-status flag set in the static initializer; {@code true} when assertions are disabled for this class. */
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:ai/h2o/targetencoding/TargetEncoder$TargetEncoderDriver.class */
    private class TargetEncoderDriver extends ModelBuilder<TargetEncoderModel, TargetEncoderModel.TargetEncoderParameters, TargetEncoderModel.TargetEncoderOutput>.Driver {
        /** Creates a driver, handing the enclosing {@link TargetEncoder} instance to the {@code Driver} super constructor. */
        private TargetEncoderDriver() {
            super(TargetEncoder.this);
        }

        /**
         * Builds the target encoder model.
         *
         * <p>Steps: run the expensive {@code init(true)} validation and abort if it reported errors,
         * create and write-lock the model, add one interaction column per group of columns to encode
         * on a copy of the training frame, compute the encoding frames for those columns, and store
         * them in the model output.
         *
         * <p>If an {@link Exception} escapes, the partially built model (when already created) is
         * registered with {@code Scope.track_generic} before the exception is rethrown. In every
         * case a created model is updated and unlocked exactly once, in the {@code finally} block.
         */
        public void computeImpl() {
            final TargetEncoder builder = TargetEncoder.this;
            builder._targetEncoderModel = null;
            try {
                builder.init(true);
                if (builder.error_count() > 0) {
                    throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(builder);
                }

                TargetEncoderModel.TargetEncoderOutput output = new TargetEncoderModel.TargetEncoderOutput(builder);
                builder._targetEncoderModel = new TargetEncoderModel(builder.dest(), (TargetEncoderModel.TargetEncoderParameters) builder._parms, output).delete_and_lock(builder._job);

                // Work on a new Frame object so that the interaction columns are not added to the frame returned by train().
                Frame workingFrame = new Frame(builder.train());
                String[][] columnGroups = builder._columnsToEncode;
                ColumnsToSingleMapping[] mappings = new ColumnsToSingleMapping[columnGroups.length];
                for (int i = 0; i < mappings.length; i++) {
                    String[] group = columnGroups[i];
                    int interactionIdx = InteractionSupport.addFeatureInteraction(workingFrame, group);
                    mappings[i] = new ColumnsToSingleMapping(group, workingFrame.name(interactionIdx), workingFrame.vec(interactionIdx).domain());
                }

                String[] singleColumns = Arrays.stream(mappings)
                        .map(ColumnsToSingleMapping::toSingle)
                        .toArray(String[]::new);
                IcedHashMap<String, Frame> encodingMap = prepareEncodingMap(workingFrame, singleColumns);

                // The encoding frames are handed over to the model output, so they are untracked from the current Scope.
                for (Map.Entry<String, Frame> entry : encodingMap.entrySet()) {
                    Scope.untrack(new Frame[]{entry.getValue()});
                }

                output.init(encodingMap, mappings);
                builder._job.update(1L);
            } catch (Exception e) {
                if (builder._targetEncoderModel != null) {
                    Scope.track_generic(builder._targetEncoderModel);
                }
                throw e;
            } finally {
                // Single cleanup point: publish the model state and release the write lock.
                if (builder._targetEncoderModel != null) {
                    builder._targetEncoderModel.update(builder._job);
                    builder._targetEncoderModel.unlock(builder._job);
                }
            }
        }

        /**
         * Delegates to {@code TargetEncoderHelper.filterOutNAsInColumn} for the given frame and column index.
         *
         * @param frame frame to filter
         * @param i     index of the column passed to the helper (the caller passes the response column index)
         * @return the frame returned by the helper
         */
        private Frame filterOutNAsFromTargetColumn(Frame frame, int i) {
            final Frame withoutNAs = TargetEncoderHelper.filterOutNAsInColumn(frame, i);
            return withoutNAs;
        }

        /**
         * Computes one encodings frame per column to encode.
         *
         * <p>The input frame is first filtered on the response column through
         * {@code filterOutNAsFromTargetColumn}. For each column, the categorical column is imputed
         * with the value {@code <column>_NA}, raw encodings are built, the configured data leakage
         * strategy is applied, and the result is stored in the DKV under the key
         * {@code <result key>_encodings_<column>}.
         *
         * @param frame  frame holding the response column, the optional fold column and the columns to encode
         * @param strArr names of the (single) columns to encode
         * @return map from column name to its encodings frame
         */
        private IcedHashMap<String, Frame> prepareEncodingMap(Frame frame, String[] strArr) {
            final TargetEncoderModel.TargetEncoderParameters parms = (TargetEncoderModel.TargetEncoderParameters) TargetEncoder.this._parms;
            Frame workingFrame = null;
            try {
                int targetIdx = frame.find(parms._response_column);
                int foldIdx = parms._fold_column == null ? -1 : frame.find(parms._fold_column);
                workingFrame = filterOutNAsFromTargetColumn(frame, targetIdx);

                IcedHashMap<String, Frame> encodingMap = new IcedHashMap<>();
                for (String column : strArr) {
                    int columnIdx = workingFrame.find(column);
                    TargetEncoderHelper.imputeCategoricalColumn(workingFrame, columnIdx, column + "_NA");

                    Frame rawEncodings = TargetEncoderHelper.buildEncodingsFrame(workingFrame, columnIdx, targetIdx, foldIdx, TargetEncoder.this.nclasses());
                    Frame encodings;
                    try {
                        encodings = applyLeakageStrategyToEncodings(rawEncodings, column, parms._data_leakage_handling, parms._fold_column);
                    } finally {
                        // The raw encodings are only an intermediate result: drop them even if the strategy step fails.
                        rawEncodings.delete();
                    }

                    // Re-register the frame under a key derived from the result key and the column name.
                    if (encodings._key != null) {
                        DKV.remove(encodings._key);
                    }
                    encodings._key = Key.make(TargetEncoder.this._result.toString() + "_encodings_" + column);
                    DKV.put(encodings);
                    encodingMap.put(column, encodings);
                }
                return encodingMap;
            } finally {
                // Single cleanup point for the filtered working frame, on success and on failure.
                if (workingFrame != null) {
                    workingFrame.delete();
                }
            }
        }

        /**
         * Groups the raw encodings by category according to the data leakage handling strategy.
         *
         * <ul>
         *   <li>{@code KFold}: for every unique value of the fold column, takes the rows that do not
         *       belong to that fold, groups them by category, adds a constant column named after the
         *       fold column holding the fold value, and row-binds the per-fold results together.</li>
         *   <li>{@code LeaveOneOut} / {@code None}: groups the encodings by category in one pass;
         *       the boolean passed to the helper is {@code true} when a fold column name is given.</li>
         * </ul>
         *
         * <p>Temporary frames are tracked in a dedicated {@code Scope}; the returned frame is
         * untracked before the scope is exited.
         *
         * @param frame                       raw encodings frame
         * @param str                         name of the encoded column
         * @param dataLeakageHandlingStrategy strategy to apply
         * @param str2                        name of the fold column, may be {@code null}
         * @return the grouped encodings frame
         * @throws IllegalStateException if the strategy is {@code null} or not supported
         */
        private Frame applyLeakageStrategyToEncodings(Frame frame, String str, TargetEncoderModel.DataLeakageHandlingStrategy dataLeakageHandlingStrategy, String str2) {
            if (dataLeakageHandlingStrategy == null) {
                // Switching on null would raise a NullPointerException; raise the documented exception instead.
                throw new IllegalStateException("null or unsupported leakageHandlingStrategy");
            }
            int encodedColumnIdx = frame.find(str);
            Frame grouped = null;

            // Enter the scope before the try block so that exit() is only ever called for a scope that was entered.
            Scope.enter();
            try {
                switch (dataLeakageHandlingStrategy) {
                    case KFold:
                        for (long foldValue : TargetEncoderHelper.getUniqueColumnValues(frame, frame.find(str2))) {
                            Frame outOfFold = getOutOfFoldEncodings(frame, str2, foldValue);
                            Scope.track(new Frame[]{outOfFold});
                            Frame foldEncodings = TargetEncoderHelper.register(TargetEncoderHelper.groupEncodingsByCategory(outOfFold, encodedColumnIdx));
                            Scope.track(new Frame[]{foldEncodings});
                            TargetEncoderHelper.addCon(foldEncodings, str2, foldValue);
                            if (grouped == null) {
                                grouped = foldEncodings;
                            } else {
                                Frame combined = TargetEncoderHelper.rBind(grouped, foldEncodings);
                                grouped.delete();
                                grouped = combined;
                            }
                            Scope.track(new Frame[]{grouped});
                        }
                        break;
                    case LeaveOneOut:
                    case None:
                        grouped = TargetEncoderHelper.groupEncodingsByCategory(frame, encodedColumnIdx, str2 != null);
                        break;
                    default:
                        throw new IllegalStateException("null or unsupported leakageHandlingStrategy");
                }
                // Keep the result alive past the scope exit below.
                Scope.untrack(new Frame[]{grouped});
                return grouped;
            } finally {
                Scope.exit(new Key[0]);
            }
        }

        /**
         * Looks up the column named {@code str} in the frame and delegates to
         * {@code TargetEncoderHelper.filterNotByValue} with that column index and the value {@code j}.
         *
         * @param frame encodings frame
         * @param str   name of the column to filter on (the caller passes the fold column name)
         * @param j     value handed to the helper (the caller passes the current fold value)
         * @return the frame returned by the helper
         */
        private Frame getOutOfFoldEncodings(Frame frame, String str, long j) {
            final int filterColumnIdx = frame.find(str);
            final Frame outOfFold = TargetEncoderHelper.filterNotByValue(frame, filterColumnIdx, j);
            return outOfFold;
        }
    }

    /**
     * Creates a builder for the given parameters and immediately runs the cheap validation pass
     * ({@code init(false)}).
     *
     * @param targetEncoderParameters parameters passed to the {@code ModelBuilder} super constructor
     */
    public TargetEncoder(TargetEncoderModel.TargetEncoderParameters targetEncoderParameters) {
        super(targetEncoderParameters);
        init(false);
    }

    /**
     * Creates a builder for the given parameters and model key, then immediately runs the cheap
     * validation pass ({@code init(false)}).
     *
     * @param targetEncoderParameters parameters passed to the {@code ModelBuilder} super constructor
     * @param key                     model key passed to the {@code ModelBuilder} super constructor
     */
    public TargetEncoder(TargetEncoderModel.TargetEncoderParameters targetEncoderParameters, Key<TargetEncoderModel> key) {
        super(targetEncoderParameters, key);
        init(false);
    }

    /**
     * Creates a builder with freshly constructed default parameters; unlike the other constructors
     * it does not call {@code init}.
     *
     * @param z flag forwarded unchanged to the {@code ModelBuilder} super constructor
     *          (NOTE(review): presumably the startup/registration flag — confirm against ModelBuilder)
     */
    public TargetEncoder(boolean z) {
        super(new TargetEncoderModel.TargetEncoderParameters(), z);
    }

    /**
     * Validates the builder parameters and, on the expensive pass, resolves the columns to encode.
     *
     * <p>On every call: forces {@code _ignore_const_cols} to {@code false}, extends the ignored
     * columns (expensive pass only), runs the superclass validation and asserts that
     * {@code _nfolds} is 0.
     *
     * <p>On the expensive pass ({@code z == true}) it additionally:
     * <ul>
     *   <li>defaults a missing data leakage handling strategy to {@code None};</li>
     *   <li>reports an error when {@code KFold} is requested without a fold column;</li>
     *   <li>if {@code _columns_to_encode} is not set, selects every categorical column of the
     *       training frame that is not listed by {@code getNonPredictors()} (one single-column
     *       group each) and warns about the non-categorical ones;</li>
     *   <li>otherwise validates the given groups: no duplicate column inside a group, and every
     *       column must exist in the training frame and be categorical.</li>
     * </ul>
     *
     * @param z {@code true} for the expensive validation pass
     */
    public void init(boolean z) {
        disableIgnoreConstColsFeature(z);
        ignoreUnusedColumns(z);
        super.init(z);

        final TargetEncoderModel.TargetEncoderParameters parms = (TargetEncoderModel.TargetEncoderParameters) this._parms;
        if (!$assertionsDisabled && parms._nfolds != 0) {
            throw new AssertionError("nfolds usage forbidden in TargetEncoder");
        }
        if (!z) {
            return;
        }

        if (parms._data_leakage_handling == null) {
            parms._data_leakage_handling = TargetEncoderModel.DataLeakageHandlingStrategy.None;
        }
        if (parms._data_leakage_handling == TargetEncoderModel.DataLeakageHandlingStrategy.KFold && parms._fold_column == null) {
            error("_fold_column", "Fold column is required when using KFold leakage handling strategy.");
        }

        Frame train = train();
        this._columnsToEncode = parms._columns_to_encode;

        if (this._columnsToEncode == null) {
            // No explicit columns: pick every categorical column that is not a non-predictor.
            List<String> nonPredictors = Arrays.asList(parms.getNonPredictors());
            List<String[]> detected = new ArrayList<>(train.numCols());
            for (int i = 0; i < train.numCols(); i++) {
                String name = train.name(i);
                if (nonPredictors.contains(name)) {
                    continue;
                }
                if (train.vec(i).isCategorical()) {
                    detected.add(new String[]{name});
                } else {
                    warn("_train", "Column `" + name + "` is not categorical and will therefore be ignored by target encoder.");
                }
            }
            // The elements are String[] groups, so the target array type must be String[][];
            // a String[] target would fail with ArrayStoreException / ClassCastException.
            this._columnsToEncode = detected.toArray(new String[0][]);
            return;
        }

        // Explicit columns: validate each group, checking each distinct column only once.
        HashSet<String> checkedColumns = new HashSet<>();
        for (String[] group : this._columnsToEncode) {
            if (group.length != new HashSet<>(Arrays.asList(group)).size()) {
                error("_columns_to_encode", "Columns interaction " + Arrays.toString(group) + " contains duplicate columns.");
            }
            for (String column : group) {
                if (!checkedColumns.add(column)) {
                    continue;
                }
                Vec vec = train.vec(column);
                if (vec == null) {
                    error("_columns_to_encode", "Column `" + column + "` from interaction " + Arrays.toString(group) + " is not categorical or is missing from the training frame.");
                } else if (!vec.isCategorical()) {
                    error("_columns_to_encode", "Column `" + column + "` from interaction " + Arrays.toString(group) + " must first be converted into categorical to be used by target encoder.");
                }
            }
        }
    }

    /**
     * Forces the {@code _ignore_const_cols} parameter to {@code false} and, on the expensive pass,
     * logs that decision at info level when info logging is enabled.
     *
     * @param z {@code true} for the expensive validation pass; only then is the message logged
     */
    private void disableIgnoreConstColsFeature(boolean z) {
        final TargetEncoderModel.TargetEncoderParameters parms = (TargetEncoderModel.TargetEncoderParameters) this._parms;
        parms._ignore_const_cols = false;
        if (!z) {
            return;
        }
        if (logger.isInfoEnabled()) {
            logger.info("We don't want to ignore any columns during target encoding transformation therefore `_ignore_const_cols` parameter was set to `false`");
        }
    }

    /**
     * Extends {@code _ignored_columns} with every training column that is neither a non-predictor
     * nor part of {@code _columns_to_encode}; columns that were already ignored stay ignored.
     *
     * <p>Does nothing unless this is the expensive pass, {@code _columns_to_encode} is set and the
     * parameters reference a training frame.
     *
     * @param z {@code true} for the expensive validation pass
     */
    private void ignoreUnusedColumns(boolean z) {
        if (!z) {
            return;
        }
        final TargetEncoderModel.TargetEncoderParameters parms = (TargetEncoderModel.TargetEncoderParameters) this._parms;
        if (parms._columns_to_encode == null) {
            return;
        }
        if (parms.train() == null) {
            return;
        }

        // Columns that must be kept: non-predictors plus everything that takes part in an encoding.
        HashSet<String> usedColumns = new HashSet<>(Arrays.asList(parms.getNonPredictors()));
        for (String[] group : parms._columns_to_encode) {
            usedColumns.addAll(Arrays.asList(group));
        }

        // Everything else in the training frame is ignored, together with the already ignored columns.
        HashSet<String> ignoredColumns = new HashSet<>(Arrays.asList(parms.train()._names));
        ignoredColumns.removeAll(usedColumns);
        String[] alreadyIgnored = parms._ignored_columns;
        if (alreadyIgnored != null) {
            ignoredColumns.addAll(new HashSet<>(Arrays.asList(alreadyIgnored)));
        }
        parms._ignored_columns = ignoredColumns.toArray(new String[0]);
    }

    /**
     * Always returns {@code false}.
     * NOTE(review): presumably tells ModelBuilder that n-fold cross-validation is not used here,
     * consistent with the {@code _nfolds == 0} assertion in {@code init} — confirm against ModelBuilder.
     */
    public boolean nFoldCV() {
        return false;
    }

    /** Returns a new {@link TargetEncoderDriver}, the driver whose {@code computeImpl()} builds the model. */
    protected ModelBuilder<TargetEncoderModel, TargetEncoderModel.TargetEncoderParameters, TargetEncoderModel.TargetEncoderOutput>.Driver trainModelImpl() {
        return new TargetEncoderDriver();
    }

    /** Returns a single-element array containing {@code ModelCategory.TargetEncoder}. */
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.TargetEncoder};
    }

    /** Always returns {@code true}; the driver reads the {@code _response_column} parameter when building encodings. */
    public boolean isSupervised() {
        return true;
    }

    /** Always returns {@code ModelBuilder.BuilderVisibility.Stable}. */
    public ModelBuilder.BuilderVisibility builderVisibility() {
        return ModelBuilder.BuilderVisibility.Stable;
    }

    /**
     * Always returns {@code true}.
     * NOTE(review): presumably advertises MOJO export support for this algorithm — confirm against ModelBuilder.
     */
    public boolean haveMojo() {
        return true;
    }

    // Static initialization: records whether assertions are disabled for this class
    // (read by the nfolds check in init) and creates the class logger.
    static {
        $assertionsDisabled = !TargetEncoder.class.desiredAssertionStatus();
        logger = LoggerFactory.getLogger(TargetEncoder.class);
    }
}
