package org.tribuo.hash;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.util.Map;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.ImmutableDataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/**
 * A {@link Trainer} wrapper which hashes a dataset's feature names before delegating
 * training to an inner trainer.
 * <p>
 * The dataset is converted with {@link ImmutableDataset#hashFeatureMap} using the supplied
 * {@link Hasher}, and the inner trainer is run on the hashed copy. The model it returns must
 * carry a {@link HashedFeatureMap}; otherwise the inner trainer is rejected as not supporting
 * hashing.
 *
 * @param <T> the output type
 */
public final class HashingTrainer<T extends Output<T>> implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(HashingTrainer.class.getName());

    @Config(mandatory = true, description = "Trainer to use.")
    private Trainer<T> innerTrainer;

    @Config(mandatory = true, description = "Feature hashing function to use.")
    private Hasher hasher;

    /**
     * No-arg constructor. Presumably used by the OLCUT configuration system, which populates
     * the {@code @Config} fields reflectively — TODO confirm.
     */
    private HashingTrainer() {
    }

    /**
     * Creates a hashing trainer.
     *
     * @param trainer the trainer that is run on the hashed dataset
     * @param hasher  the hashing function applied to feature names
     * @throws NullPointerException if {@code trainer} or {@code hasher} is null
     */
    public HashingTrainer(Trainer<T> trainer, Hasher hasher) {
        // Fail fast here rather than with an NPE deep inside train()/getInvocationCount().
        this.innerTrainer = Objects.requireNonNull(trainer, "trainer");
        this.hasher = Objects.requireNonNull(hasher, "hasher");
    }

    /**
     * Hashes the dataset's feature map, then trains the inner trainer on the hashed dataset.
     * The feature counts before and after hashing are logged at {@code INFO}.
     *
     * @param dataset            the dataset to hash and train on
     * @param instanceProvenance provenance passed through unchanged to the inner trainer
     * @return the model produced by the inner trainer
     * @throws IllegalStateException if the inner trainer's model does not use a
     *                               {@link HashedFeatureMap}
     */
    @Override
    public Model<T> train(Dataset<T> dataset, Map<String, Provenance> instanceProvenance) {
        logger.log(Level.INFO, () -> "Before hashing, had " + dataset.getFeatureMap().size() + " features.");
        ImmutableDataset<T> hashedDataset = ImmutableDataset.hashFeatureMap(dataset, hasher);
        logger.log(Level.INFO, () -> "After hashing, had " + hashedDataset.getFeatureMap().size() + " features.");

        Model<T> model = innerTrainer.train(hashedDataset, instanceProvenance);
        // A trainer that rebuilds its own feature map would silently discard the hashing,
        // so reject any model whose feature map is not the hashed one.
        if (!(model.getFeatureIDMap() instanceof HashedFeatureMap)) {
            throw new IllegalStateException("Trainer " + innerTrainer.getClass().getName() + " does not support hashing.");
        }
        return model;
    }

    /**
     * Returns the inner trainer's invocation count.
     *
     * @return the inner trainer's invocation count
     */
    @Override
    public int getInvocationCount() {
        return innerTrainer.getInvocationCount();
    }

    /**
     * Builds the provenance for this trainer.
     *
     * @return a new {@link TrainerProvenanceImpl} describing this trainer
     */
    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}
