package org.tribuo.multilabel.baseline;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SplittableRandom;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.hash.HashedFeatureMap;
import org.tribuo.multilabel.ImmutableMultiLabelInfo;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.provenance.DatasetProvenance;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/multilabel/baseline/ClassifierChainTrainer.class */
/**
 * A multi-label trainer which builds a classifier chain: one binary {@link Label} model per label,
 * trained in a chain order where every model after the first also sees features recording the
 * true value of each earlier label in the chain.
 * <p>
 * The chain order is either the explicit {@code labelOrder}, or (when {@code randomOrder} is set)
 * a shuffle of the training label domain drawn from an RNG seeded with {@code seed}.
 * Datasets whose feature map is a {@link HashedFeatureMap} are rejected.
 */
public final class ClassifierChainTrainer implements Trainer<MultiLabel> {
    private static final Logger logger = Logger.getLogger(ClassifierChainTrainer.class.getName());

    /** Prefix of the feature names which encode earlier labels in the chain. */
    public static final String CC_PREFIX = "CC_FEATURES";
    /** Suffix of the chain feature emitted when the earlier label is not the negative label. */
    public static final String CC_POSITIVE = "POSITIVE";
    /** Suffix of the chain feature emitted when the earlier label is the negative label. */
    public static final String CC_NEGATIVE = "NEGATIVE";
    /** Separator between the prefix, label name and suffix of a chain feature name. */
    public static final String CC_SEPARATOR = "_";

    @Config(mandatory = true, description = "The trainer to use.")
    private Trainer<Label> innerTrainer;

    @Config(mandatory = false, description = "Label order.")
    private List<String> labelOrder = Collections.emptyList();

    @Config(mandatory = false, description = "Randomise the label chain order.")
    private boolean randomOrder = false;

    @Config(mandatory = false, description = "RNG seed for random label orders.")
    private long seed = 12345L;

    // Both fields are updated together under this instance's monitor in train and setInvocationCount.
    private int trainInvocationCounter = 0;
    private SplittableRandom rng;

    /**
     * No-argument constructor; presumably used reflectively by the OLCUT configuration system,
     * which populates the {@code @Config} fields (confirm against the config framework).
     */
    private ClassifierChainTrainer() {
    }

    /**
     * Builds a classifier chain trainer that uses a random label order on each training run.
     *
     * @param trainer The binary trainer used for every link of the chain.
     * @param seed    The seed for the RNG that draws the random label orders.
     */
    public ClassifierChainTrainer(Trainer<Label> trainer, long seed) {
        this.innerTrainer = trainer;
        this.randomOrder = true;
        this.seed = seed;
        postConfig();
    }

    /**
     * Builds a classifier chain trainer that uses a fixed label order.
     *
     * @param trainer    The binary trainer used for every link of the chain.
     * @param labelOrder The chain order, as label names; copied defensively.
     */
    public ClassifierChainTrainer(Trainer<Label> trainer, List<String> labelOrder) {
        this.innerTrainer = trainer;
        this.labelOrder = Collections.unmodifiableList(new ArrayList<>(labelOrder));
        postConfig();
    }

    /**
     * Validates the configuration and (re)creates the RNG from {@code seed}.
     *
     * @throws PropertyException If {@code randomOrder} is false and {@code labelOrder} is empty.
     */
    public void postConfig() {
        if (!this.randomOrder && this.labelOrder.isEmpty()) {
            throw new PropertyException("", "randomOrder", "Either randomOrder must be true, or labelOrder must be non-empty");
        }
        this.rng = new SplittableRandom(this.seed);
    }

    /**
     * Trains a classifier chain with no extra run provenance.
     *
     * @param dataset The multi-label training data.
     * @return The trained classifier chain.
     */
    public ClassifierChainModel train(Dataset<MultiLabel> dataset) {
        return train(dataset, Collections.emptyMap());
    }

    /**
     * Trains a classifier chain, without resetting the invocation count.
     *
     * @param dataset The multi-label training data.
     * @param map     Extra provenance recorded on the model.
     * @return The trained classifier chain.
     */
    public ClassifierChainModel train(Dataset<MultiLabel> dataset, Map<String, Provenance> map) {
        return train(dataset, map, -1);
    }

    /**
     * Trains a classifier chain.
     * <p>
     * One binary model is trained per label in the chain order. Before training each subsequent
     * link, every example gains a feature named
     * {@code CC_FEATURES_<previous label>_POSITIVE} or {@code ..._NEGATIVE} describing the previous
     * link's label, and is relabelled with the next label.
     *
     * @param dataset         The multi-label training data.
     * @param map             Extra provenance recorded on the model.
     * @param invocationCount If not -1, the invocation count is reset to this value before training.
     * @return The trained classifier chain.
     * @throws IllegalArgumentException If the dataset contains unknown outputs, has an empty label
     *                                  domain, or {@code labelOrder} is not a total ordering of it.
     * @throws IllegalStateException    If the dataset uses a {@link HashedFeatureMap}.
     */
    public ClassifierChainModel train(Dataset<MultiLabel> dataset, Map<String, Provenance> map, int invocationCount) {
        if (dataset.getOutputInfo().getUnknownCount() > 0) {
            throw new IllegalArgumentException("The supplied Dataset contained unknown Outputs, and this Trainer is supervised.");
        }
        SplittableRandom localRNG;
        TrainerProvenance trainerProvenance;
        // Split the RNG, snapshot provenance and bump the counter atomically so concurrent calls
        // each see a distinct invocation.
        synchronized (this) {
            if (invocationCount != -1) {
                setInvocationCount(invocationCount);
            }
            localRNG = this.rng.split();
            trainerProvenance = m16getProvenance();
            this.trainInvocationCounter++;
        }

        ImmutableMultiLabelInfo labelInfo = (ImmutableMultiLabelInfo) dataset.getOutputIDInfo();
        List<Label> chainOrder = computeLabelOrder(labelInfo, localRNG);
        logger.log(Level.INFO, "Training with label order " + chainOrder);

        ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        if (featureIDMap instanceof HashedFeatureMap) {
            throw new IllegalStateException("Cannot use HashingTrainer wrapped around ClassifierChainTrainer.");
        }

        // Project every multi-label example onto the first label of the chain.
        DatasetProvenance datasetProvenance = dataset.getProvenance();
        MutableDataset<Label> binaryDataset = new MutableDataset<>(datasetProvenance, new LabelFactory());
        Label firstLabel = chainOrder.get(0);
        for (Example<MultiLabel> example : dataset) {
            binaryDataset.add(new BinaryExample(example, example.getOutput().createLabel(firstLabel)));
        }

        List<Model<Label>> models = new ArrayList<>(chainOrder.size());
        int lastLink = chainOrder.size() - 1;
        for (int link = 0; link <= lastLink; link++) {
            models.add(this.innerTrainer.train(binaryDataset));
            if (link != lastLink) {
                advanceChain(dataset, binaryDataset, chainOrder.get(link).getLabel(), chainOrder.get(link + 1));
            }
        }

        ModelProvenance modelProvenance = new ModelProvenance(ClassifierChainModel.class.getName(), OffsetDateTime.now(), datasetProvenance, trainerProvenance, map);
        return new ClassifierChainModel(Collections.unmodifiableList(chainOrder), models, modelProvenance, featureIDMap, labelInfo);
    }

    /**
     * Resolves the chain order for one training run.
     *
     * @param labelInfo The training label domain.
     * @param localRNG  The RNG used when {@code randomOrder} is set.
     * @return A mutable list of labels in chain order; never empty.
     * @throws IllegalArgumentException If the domain is empty, or if {@code labelOrder} contains
     *                                  duplicates or does not match the training label domain.
     */
    private List<Label> computeLabelOrder(ImmutableMultiLabelInfo labelInfo, SplittableRandom localRNG) {
        List<Label> order;
        if (this.randomOrder) {
            Set<MultiLabel> domain = labelInfo.getDomain();
            order = new ArrayList<>(domain.size());
            for (MultiLabel multiLabel : domain) {
                order.add(new Label(multiLabel.getLabelString()));
            }
            Util.shuffle(order, localRNG);
        } else {
            Set<String> uniqueNames = new HashSet<>(this.labelOrder);
            // A total ordering names every training label exactly once: no duplicates, no extras, no omissions.
            if (uniqueNames.size() != this.labelOrder.size() || uniqueNames.size() != labelInfo.size()) {
                throw new IllegalArgumentException("Must supply a total label ordering, labelOrder = " + this.labelOrder.toString() + ", train label domain = " + labelInfo.getDomain());
            }
            for (String name : uniqueNames) {
                if (labelInfo.getLabelCount(name) == 0) {
                    throw new IllegalArgumentException("Must supply a total label ordering, labelOrder = " + this.labelOrder.toString() + ", train label domain = " + labelInfo.getDomain());
                }
            }
            order = new ArrayList<>(this.labelOrder.size());
            for (String name : this.labelOrder) {
                order.add(new Label(name));
            }
        }
        if (order.isEmpty()) {
            throw new IllegalArgumentException("Cannot train a classifier chain on a dataset with an empty label domain.");
        }
        return order;
    }

    /**
     * Moves every example in {@code binaryDataset} one link along the chain. It appends a feature
     * recording the example's current binary label, then relabels the example with {@code nextLabel}
     * using the ground-truth multi-label from {@code dataset}.
     *
     * @param dataset          The original multi-label dataset.
     * @param binaryDataset    The binary dataset built from {@code dataset}, one example per original example.
     * @param currentLabelName The name of the label the examples are currently labelled with.
     * @param nextLabel        The next label in the chain.
     */
    private static void advanceChain(Dataset<MultiLabel> dataset, MutableDataset<Label> binaryDataset, String currentLabelName, Label nextLabel) {
        for (int idx = 0; idx < binaryDataset.size(); idx++) {
            BinaryExample binaryExample = (BinaryExample) binaryDataset.getExample(idx);
            Label current = binaryExample.m10getOutput();
            // Identity comparison against the shared constant, as in the original implementation;
            // presumably createLabel returns MultiLabel.NEGATIVE_LABEL itself — confirm.
            String suffix = current == MultiLabel.NEGATIVE_LABEL ? CC_NEGATIVE : CC_POSITIVE;
            binaryExample.add(new Feature(CC_PREFIX + CC_SEPARATOR + currentLabelName + CC_SEPARATOR + suffix, 1.0d));
            // binaryDataset was filled by iterating dataset in order, so index idx is assumed to refer
            // to the same example in both (relies on iteration order matching getExample indexing).
            binaryExample.setLabel(dataset.getExample(idx).getOutput().createLabel(nextLabel));
        }
        binaryDataset.regenerateOutputInfo();
        binaryDataset.regenerateFeatureInfo();
    }

    /**
     * Returns the number of training runs started so far (or since the last reset).
     *
     * @return The invocation count.
     */
    public synchronized int getInvocationCount() {
        return this.trainInvocationCounter;
    }

    /**
     * Resets the RNG to {@code seed} and advances it as if {@code i} training runs had occurred.
     *
     * @param i The invocation count to restore.
     * @throws IllegalArgumentException If {@code i} is negative.
     */
    public synchronized void setInvocationCount(int i) {
        if (i < 0) {
            throw new IllegalArgumentException("The supplied invocationCount is less than zero.");
        }
        this.rng = new SplittableRandom(this.seed);
        this.trainInvocationCounter = 0;
        while (this.trainInvocationCounter < i) {
            this.rng.split();
            this.trainInvocationCounter++;
        }
    }

    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    /**
     * Builds the provenance describing this trainer's current configuration.
     *
     * @return The trainer provenance.
     */
    public TrainerProvenance m16getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    /* renamed from: train, reason: collision with other method in class */
    /** Raw-typed bridge (decompiler artifact) delegating to {@code train(Dataset, Map, int)}. */
    @SuppressWarnings("unchecked")
    public /* bridge */ /* synthetic */ Model m13train(Dataset dataset, Map map, int i) {
        return train((Dataset<MultiLabel>) dataset, (Map<String, Provenance>) map, i);
    }

    /* renamed from: train, reason: collision with other method in class */
    /** Raw-typed bridge (decompiler artifact) delegating to {@code train(Dataset, Map)}. */
    @SuppressWarnings("unchecked")
    public /* bridge */ /* synthetic */ Model m14train(Dataset dataset, Map map) {
        return train((Dataset<MultiLabel>) dataset, (Map<String, Provenance>) map);
    }

    /* renamed from: train, reason: collision with other method in class */
    /** Raw-typed bridge (decompiler artifact) delegating to {@code train(Dataset)}. */
    @SuppressWarnings("unchecked")
    public /* bridge */ /* synthetic */ Model m15train(Dataset dataset) {
        return train((Dataset<MultiLabel>) dataset);
    }
}
