package org.tribuo.multilabel.ensemble;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ListProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.WeightedEnsembleModel;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.baseline.ClassifierChainTrainer;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/* loaded from: input_file:org/tribuo/multilabel/ensemble/CCEnsembleTrainer.class */
/**
 * Trainer which builds an ensemble of classifier chains for multi-label outputs.
 * <p>
 * Each call to {@code train} creates one {@link ClassifierChainTrainer} around the
 * configured inner {@link Label} trainer, seeds it from this trainer's RNG, trains
 * {@code numMembers} chains with it, and wraps them in a {@link WeightedEnsembleModel}
 * named {@code "classifier-chain-ensemble"} which uses a {@code MultiLabelVotingCombiner}.
 * <p>
 * The RNG and the invocation counter are only mutated while holding this object's monitor.
 */
public final class CCEnsembleTrainer implements Trainer<MultiLabel> {
    private static final Logger logger = Logger.getLogger(CCEnsembleTrainer.class.getName());

    /**
     * Sentinel invocation count meaning "do not reset the RNG state, just advance it by one".
     */
    private static final int INCREMENT_ONLY = -1;

    @Config(mandatory = true, description = "The trainer to use.")
    private Trainer<Label> innerTrainer;

    @Config(mandatory = true, description = "Number of classifier chains to build.")
    private int numMembers;

    @Config(mandatory = true, description = "RNG seed for random label orders.")
    private long seed;

    /** Number of times the RNG has been split for a training run since it was last seeded. */
    private int trainInvocationCounter = 0;

    /** Source of per-invocation RNGs; created in {@link #postConfig()}. */
    private SplittableRandom rng;

    /**
     * No-arg constructor which leaves the fields unset.
     * NOTE(review): presumably used by the OLCUT configuration system, which would
     * populate the {@code @Config} fields and then call {@link #postConfig()} - confirm.
     */
    private CCEnsembleTrainer() {
    }

    /**
     * Builds a classifier chain ensemble trainer.
     *
     * @param innerTrainer The single-label trainer used inside every chain.
     * @param numMembers   The number of chains to build, must be at least 1.
     * @param seed         The seed for this trainer's RNG.
     * @throws IllegalArgumentException If {@code numMembers} is less than 1.
     */
    public CCEnsembleTrainer(Trainer<Label> innerTrainer, int numMembers, long seed) {
        if (numMembers < 1) {
            throw new IllegalArgumentException("Must have a positive number of ensemble members, found " + numMembers);
        }
        this.innerTrainer = innerTrainer;
        this.numMembers = numMembers;
        this.seed = seed;
        postConfig();
    }

    /**
     * Validates the configured member count and creates the RNG from the seed.
     *
     * @throws PropertyException If {@code numMembers} is less than 1.
     */
    public void postConfig() throws PropertyException {
        if (this.numMembers < 1) {
            throw new PropertyException("", "numMembers", "Must have a positive number of ensemble members, found " + this.numMembers);
        }
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * Trains an ensemble with no run provenance, advancing the invocation count by one.
     *
     * @param examples The training data.
     * @return The trained ensemble.
     */
    public WeightedEnsembleModel<MultiLabel> train(Dataset<MultiLabel> examples) {
        return train(examples, Collections.emptyMap());
    }

    /**
     * Trains an ensemble, advancing the invocation count by one.
     *
     * @param examples      The training data.
     * @param runProvenance Provenance for this training run, stored in the model provenance.
     * @return The trained ensemble.
     */
    public WeightedEnsembleModel<MultiLabel> train(Dataset<MultiLabel> examples, Map<String, Provenance> runProvenance) {
        return train(examples, runProvenance, INCREMENT_ONLY);
    }

    /**
     * Trains an ensemble of {@code numMembers} classifier chains.
     *
     * @param examples        The training data.
     * @param runProvenance   Provenance for this training run, stored in the model provenance.
     * @param invocationCount Either {@code -1} to continue from the current RNG state, or a
     *                        non-negative count which the RNG state is reset to before training.
     * @return The trained ensemble.
     * @throws IllegalArgumentException If {@code invocationCount} is below {@code -1}
     *                                  (raised by {@link #setInvocationCount(int)}).
     */
    public WeightedEnsembleModel<MultiLabel> train(Dataset<MultiLabel> examples, Map<String, Provenance> runProvenance, int invocationCount) {
        SplittableRandom localRng;
        TrainerProvenance trainerProvenance;
        // Take the per-run RNG and the provenance snapshot atomically, before the counter
        // is incremented, so the captured provenance matches the state this run started from.
        synchronized (this) {
            if (invocationCount != INCREMENT_ONLY) {
                setInvocationCount(invocationCount);
            }
            localRng = this.rng.split();
            trainerProvenance = m23getProvenance();
            this.trainInvocationCounter++;
        }

        ImmutableFeatureMap featureIDs = examples.getFeatureIDMap();
        ImmutableOutputInfo<MultiLabel> labelIDs = examples.getOutputIDInfo();

        // A single chain trainer is reused for every member.
        // NOTE(review): presumably it draws a fresh label order on each train call - confirm
        // against ClassifierChainTrainer.
        ClassifierChainTrainer chainTrainer = new ClassifierChainTrainer(this.innerTrainer, localRng.nextLong());
        ArrayList<Model<MultiLabel>> members = new ArrayList<>(this.numMembers);
        for (int memberIdx = 0; memberIdx < this.numMembers; memberIdx++) {
            logger.info("Building chain " + memberIdx);
            members.add(chainTrainer.train(examples));
        }

        EnsembleModelProvenance provenance = new EnsembleModelProvenance(
                WeightedEnsembleModel.class.getName(),
                OffsetDateTime.now(),
                examples.getProvenance(),
                trainerProvenance,
                runProvenance,
                ListProvenance.createListProvenance(members));
        return new WeightedEnsembleModel<>("classifier-chain-ensemble", provenance, featureIDs, labelIDs, members, new MultiLabelVotingCombiner());
    }

    /**
     * Returns how many times the RNG has been advanced for training since it was last seeded.
     * <p>
     * Synchronized because the counter is only ever written under this object's monitor.
     *
     * @return The invocation count.
     */
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Resets the RNG to the state it has after {@code invocationCount} training calls.
     * <p>
     * The RNG is recreated from the seed and then split {@code invocationCount} times,
     * replaying the splits that the training calls would have performed.
     *
     * @param invocationCount The number of invocations to replay, must be non-negative.
     * @throws IllegalArgumentException If {@code invocationCount} is negative.
     */
    public synchronized void setInvocationCount(int invocationCount) {
        if (invocationCount < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        for (this.trainInvocationCounter = 0; this.trainInvocationCounter < invocationCount; this.trainInvocationCounter++) {
            // The split result is discarded; only the parent RNG's state change matters.
            this.rng.split();
        }
    }

    /**
     * Creates the provenance describing this trainer's current configuration.
     *
     * @return A new {@link TrainerProvenanceImpl} built from this trainer.
     */
    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    public TrainerProvenance m23getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    // The three methods below are compiler-generated bridge methods as emitted by the
    // decompiler; each delegates to the typed train overload with the same arity.

    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m20train(Dataset dataset, Map map, int i) {
        return train((Dataset<MultiLabel>) dataset, (Map<String, Provenance>) map, i);
    }

    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m21train(Dataset dataset, Map map) {
        return train((Dataset<MultiLabel>) dataset, (Map<String, Provenance>) map);
    }

    /* renamed from: train, reason: collision with other method in class */
    public /* bridge */ /* synthetic */ Model m22train(Dataset dataset) {
        return train((Dataset<MultiLabel>) dataset);
    }
}
