package org.tribuo.classification.baseline;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import java.time.OffsetDateTime;
import java.util.Map;
import java.util.Objects;
import org.tribuo.Dataset;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.MutableOutputInfo;
import org.tribuo.Trainer;
import org.tribuo.classification.Label;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.provenance.impl.TrainerProvenanceImpl;

/* loaded from: input_file:org/tribuo/classification/baseline/DummyClassifierTrainer.class */
public final class DummyClassifierTrainer implements Trainer<Label> {

    @Config(mandatory = true, description = "Type of dummy classifier.")
    private DummyType dummyType;

    @Config(description = "Label to use for the constant classifier.")
    private String constantLabel;

    @Config(description = "Seed for the RNG.")
    private long seed = 1;
    private int invocationCount = 0;

    /* loaded from: input_file:org/tribuo/classification/baseline/DummyClassifierTrainer$DummyType.class */
    public enum DummyType {
        STRATIFIED,
        MOST_FREQUENT,
        UNIFORM,
        CONSTANT
    }

    private DummyClassifierTrainer() {
    }

    public void postConfig() {
        if (this.dummyType == DummyType.CONSTANT && this.constantLabel == null) {
            throw new PropertyException("", "constantLabel", "Please supply a label string when using the type CONSTANT.");
        }
    }

    public Model<Label> train(Dataset<Label> dataset, Map<String, Provenance> map) {
        ModelProvenance modelProvenance = new ModelProvenance(DummyClassifierModel.class.getName(), OffsetDateTime.now(), dataset.getProvenance(), m11getProvenance(), map);
        ImmutableFeatureMap featureIDMap = dataset.getFeatureIDMap();
        this.invocationCount++;
        switch (this.dummyType) {
            case CONSTANT:
                MutableOutputInfo generateMutableOutputInfo = dataset.getOutputInfo().generateMutableOutputInfo();
                Label label = new Label(this.constantLabel);
                generateMutableOutputInfo.observe(label);
                return new DummyClassifierModel(modelProvenance, featureIDMap, generateMutableOutputInfo.generateImmutableOutputInfo(), label);
            case MOST_FREQUENT:
                return new DummyClassifierModel(modelProvenance, featureIDMap, dataset.getOutputIDInfo());
            case UNIFORM:
            case STRATIFIED:
                return new DummyClassifierModel(modelProvenance, featureIDMap, dataset.getOutputIDInfo(), this.dummyType, this.seed);
            default:
                throw new IllegalStateException("Unknown dummyType " + this.dummyType);
        }
    }

    public int getInvocationCount() {
        return this.invocationCount;
    }

    public String toString() {
        switch (this.dummyType) {
            case CONSTANT:
                return "DummyClassifierTrainer(dummyType=" + this.dummyType + ",constantLabel=" + this.constantLabel + ")";
            case MOST_FREQUENT:
                return "DummyClassifierTrainer(dummyType=" + this.dummyType + ")";
            case UNIFORM:
            case STRATIFIED:
                return "DummyClassifierTrainer(dummyType=" + this.dummyType + ",seed=" + this.seed + ")";
            default:
                return "DummyClassifierTrainer(dummyType=" + this.dummyType + ")";
        }
    }

    /* renamed from: getProvenance, reason: merged with bridge method [inline-methods] */
    public TrainerProvenance m11getProvenance() {
        return new TrainerProvenanceImpl(this);
    }

    public static DummyClassifierTrainer createStratifiedTrainer(long j) {
        DummyClassifierTrainer dummyClassifierTrainer = new DummyClassifierTrainer();
        dummyClassifierTrainer.dummyType = DummyType.STRATIFIED;
        dummyClassifierTrainer.seed = j;
        return dummyClassifierTrainer;
    }

    public static DummyClassifierTrainer createConstantTrainer(String str) {
        DummyClassifierTrainer dummyClassifierTrainer = new DummyClassifierTrainer();
        dummyClassifierTrainer.dummyType = DummyType.CONSTANT;
        dummyClassifierTrainer.constantLabel = str;
        return dummyClassifierTrainer;
    }

    public static DummyClassifierTrainer createUniformTrainer(long j) {
        DummyClassifierTrainer dummyClassifierTrainer = new DummyClassifierTrainer();
        dummyClassifierTrainer.dummyType = DummyType.UNIFORM;
        dummyClassifierTrainer.seed = j;
        return dummyClassifierTrainer;
    }

    public static DummyClassifierTrainer createMostFrequentTrainer() {
        DummyClassifierTrainer dummyClassifierTrainer = new DummyClassifierTrainer();
        dummyClassifierTrainer.dummyType = DummyType.MOST_FREQUENT;
        return dummyClassifierTrainer;
    }
}
