package org.tribuo.multilabel.baseline;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.multilabel.ImmutableMultiLabelInfo;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/* loaded from: input_file:org/tribuo/multilabel/baseline/IndependentMultiLabelTrainer.class */
public class IndependentMultiLabelTrainer implements Trainer<MultiLabel> {

    @Config(mandatory = true, description = "Trainer to use for each individual label.")
    private Trainer<Label> innerTrainer;
    private int trainInvocationCounter = 0;

    private IndependentMultiLabelTrainer() {
    }

    public IndependentMultiLabelTrainer(Trainer<Label> trainer) {
        this.innerTrainer = trainer;
    }

    public Model<MultiLabel> train(Dataset<MultiLabel> dataset, Map<String, Provenance> map) {
        if (dataset.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        ImmutableMultiLabelInfo immutableMultiLabelInfo = (ImmutableMultiLabelInfo) dataset.getOutputIDInfo();
        ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        DatasetProvenance provenance = dataset.getProvenance();
        MutableDataset mutableDataset = new MutableDataset(provenance, new LabelFactory());
        Iterator<MultiLabel> it = immutableMultiLabelInfo.getDomain().iterator();
        while (it.hasNext()) {
            Label label = new Label(it.next().getLabelString());
            mutableDataset.clear();
            arrayList2.add(label);
            Iterator it2 = dataset.iterator();
            while (it2.hasNext()) {
                Example example = (Example) it2.next();
                mutableDataset.add(new BinaryExample(example, example.getOutput().createLabel(label)));
            }
            arrayList.add(this.innerTrainer.train(mutableDataset));
        }
        ModelProvenance modelProvenance = new ModelProvenance(IndependentMultiLabelModel.class.getName(), OffsetDateTime.now(), provenance, m11getProvenance(), map);
        this.trainInvocationCounter++;
        return new IndependentMultiLabelModel(arrayList2, arrayList, modelProvenance, featureIDMap, immutableMultiLabelInfo);
    }

    public int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    public String toString() {
        return "IndependentMultiLabelTrainer(innerTrainer=" + this.innerTrainer.toString() + ")";
    }

    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    public TrainerProvenance m11getProvenance() {
        return new TrainerProvenanceImpl(this);
    }
}
