package org.tribuo.ensemble;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.dataset.DatasetView;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/* loaded from: input_file:org/tribuo/ensemble/BaggingTrainer.class */
/**
 * A trainer that builds an ensemble by training {@link #numMembers} models with
 * {@link #innerTrainer}, each on its own bootstrap view of the training data
 * (see {@link DatasetView#createBootstrapView}), and wraps them in a
 * {@link WeightedEnsembleModel} that aggregates their outputs with {@link #combiner}.
 * <p>
 * The mutable state touched by {@link #train} (the RNG and the invocation counter)
 * is only accessed while holding this object's monitor; the member models themselves
 * are trained outside the lock using an RNG split off for that call.
 *
 * @param <T> the output type produced by the ensemble and its members.
 */
public class BaggingTrainer<T extends Output<T>> implements Trainer<T> {
    private static final Logger logger = Logger.getLogger(BaggingTrainer.class.getName());

    @Config(mandatory = true, description = "The trainer to use for each ensemble member.")
    protected Trainer<T> innerTrainer;

    @Config(mandatory = true, description = "The number of ensemble members to train.")
    protected int numMembers;

    @Config(mandatory = true, description = "The seed for the RNG.")
    protected long seed;

    @Config(mandatory = true, description = "The combination function to aggregate each ensemble member's outputs.")
    protected EnsembleCombiner<T> combiner;

    /** RNG seeded from {@link #seed} in {@link #postConfig()}; split once per {@link #train} call. */
    protected SplittableRandom rng;

    /** Number of times {@link #train} has been entered on this instance. */
    protected int trainInvocationCounter;

    /**
     * No-argument constructor that leaves every field at its default value.
     * <p>
     * {@link #rng} stays {@code null} until {@link #postConfig()} is called, so
     * {@link #train} must not be used before then.
     * NOTE(review): presumably this exists for the configuration system that
     * populates the {@code @Config} fields - confirm.
     */
    protected BaggingTrainer() {
    }

    /**
     * Creates a bagging trainer seeded with {@link Trainer#DEFAULT_SEED}.
     *
     * @param trainer          the trainer used to fit each ensemble member.
     * @param ensembleCombiner the function used to combine the members' outputs.
     * @param i                the number of ensemble members to train.
     */
    public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> ensembleCombiner, int i) {
        this(trainer, ensembleCombiner, i, Trainer.DEFAULT_SEED);
    }

    /**
     * Creates a bagging trainer with an explicit RNG seed.
     *
     * @param trainer          the trainer used to fit each ensemble member.
     * @param ensembleCombiner the function used to combine the members' outputs.
     * @param i                the number of ensemble members to train.
     * @param j                the seed for the RNG.
     */
    public BaggingTrainer(Trainer<T> trainer, EnsembleCombiner<T> ensembleCombiner, int i, long j) {
        this.innerTrainer = trainer;
        this.combiner = ensembleCombiner;
        this.numMembers = i;
        this.seed = j;
        postConfig();
    }

    /**
     * (Re)creates {@link #rng} from the current value of {@link #seed}.
     */
    public synchronized void postConfig() {
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * Returns the name given to ensembles built by this trainer.
     *
     * @return the ensemble name, {@code "bagging-ensemble"} for this class.
     */
    protected String ensembleName() {
        return "bagging-ensemble";
    }

    /**
     * Describes this trainer by its inner trainer, combiner, member count and seed.
     *
     * @return a human readable description of this trainer's configuration.
     */
    @Override
    public String toString() {
        return "BaggingTrainer(innerTrainer=" + this.innerTrainer + ",combiner=" + this.combiner + ",numMembers=" + this.numMembers + ",seed=" + this.seed + ")";
    }

    /**
     * Trains {@link #numMembers} models via {@link #trainSingleModel} and combines
     * them into a single {@link WeightedEnsembleModel}.
     *
     * @param dataset the training data.
     * @param map     the run provenance, passed to every member's trainer and
     *                recorded in the ensemble's provenance.
     * @return the trained ensemble model.
     */
    @Override
    public Model<T> train(Dataset<T> dataset, Map<String, Provenance> map) {
        final SplittableRandom localRng;
        final TrainerProvenance trainerProvenance;
        // Touch shared mutable state under the lock only; the provenance is captured
        // before the counter is bumped, and the (potentially slow) member training
        // below runs on the split-off RNG without holding the monitor.
        synchronized (this) {
            localRng = this.rng.split();
            trainerProvenance = getProvenance();
            this.trainInvocationCounter++;
        }

        final ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        final ImmutableOutputInfo<T> outputIDInfo = dataset.getOutputIDInfo();

        // Clamp the capacity hint so a non-positive numMembers still yields an
        // empty member list rather than an IllegalArgumentException from ArrayList.
        final List<Model<T>> members = new ArrayList<>(Math.max(this.numMembers, 0));
        for (int i = 0; i < this.numMembers; i++) {
            logger.info("Building model " + i);
            members.add(trainSingleModel(dataset, featureIDMap, outputIDInfo, localRng, map));
        }

        final EnsembleModelProvenance provenance = new EnsembleModelProvenance(
                WeightedEnsembleModel.class.getName(),
                OffsetDateTime.now(),
                dataset.getProvenance(),
                trainerProvenance,
                map,
                ListProvenance.createListProvenance(members));
        return new WeightedEnsembleModel<>(ensembleName(), provenance, featureIDMap, outputIDInfo, members, this.combiner);
    }

    /**
     * Trains one ensemble member on a bootstrap view of {@code dataset}.
     * <p>
     * The view is requested with {@code dataset.size()} as its size and is seeded
     * with the next int drawn from {@code splittableRandom}, so each call advances
     * that RNG by one draw.
     *
     * @param dataset             the full training data to sample from.
     * @param immutableFeatureMap the feature map shared by all members.
     * @param immutableOutputInfo the output info shared by all members.
     * @param splittableRandom    the RNG supplying the bootstrap seed.
     * @param map                 the run provenance passed to the inner trainer.
     * @return the model produced by {@link #innerTrainer} on the bootstrap view.
     */
    protected Model<T> trainSingleModel(Dataset<T> dataset, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, SplittableRandom splittableRandom, Map<String, Provenance> map) {
        final DatasetView<T> bootstrap = DatasetView.createBootstrapView(
                dataset, dataset.size(), splittableRandom.nextInt(), immutableFeatureMap, immutableOutputInfo);
        return this.innerTrainer.train(bootstrap, map);
    }

    /**
     * Returns how many times {@link #train} has been entered on this instance.
     * <p>
     * Synchronized so the read pairs with the locked increment in {@link #train}.
     *
     * @return the train invocation count.
     */
    @Override
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Builds a provenance object describing this trainer's current state.
     *
     * @return a new {@link TrainerProvenanceImpl} constructed from this trainer.
     */
    @Override
    public TrainerProvenance getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /**
     * Alias retained for callers that still use the decompiler-generated name.
     *
     * @return the result of {@link #getProvenance()}.
     * @deprecated use {@link #getProvenance()}.
     */
    @Deprecated
    public TrainerProvenance m21getProvenance() {
        return getProvenance();
    }
}
