package org.tribuo.multilabel.example;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.PropertyException;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.SkeletalConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.tribuo.ConfigurableDataSource;
import org.tribuo.Example;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.classification.Label;
import org.tribuo.impl.ArrayExample;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.multilabel.MultiLabelFactory;
import org.tribuo.provenance.ConfiguredDataSourceProvenance;
import org.tribuo.provenance.DataSourceProvenance;

/**
 * A {@link ConfigurableDataSource} which generates multi-label examples from noisy,
 * thresholded functions of four uniformly sampled features.
 * <p>
 * Each example has four features {@code X_0}..{@code X_3}, where {@code x_i} is sampled as
 * {@code xMin[i] + u * (xMax[i] - xMin[i])} with {@code u} drawn from {@link Random#nextDouble()}.
 * Three scores are then computed, each with additive noise {@code nextGaussian() * variance}:
 * <ul>
 *     <li>{@code y_0 = w0[0]*x_0 + w0[1]*x_1 + w0[2]*x_0*x_1 + w0[3]*x_1^3} (weights {@code yZeroWeights})</li>
 *     <li>{@code y_1 = w1[0]*x_0 + w1[1]*x_1 + w1[2]*x_0*x_1 + w1[3]*x_1^3} (weights {@code yOneWeights})</li>
 *     <li>{@code y_2 = w2[0]*x_0 + w2[1]*x_2 + w2[2]*x_0*x_1 + w2[3]*x_1*x_2^2} (weights {@code yTwoWeights})</li>
 * </ul>
 * Score {@code k} is negated when {@code negate[k]} is true, and label {@code Y_k} is emitted when
 * the (possibly negated) score is strictly greater than {@code threshold[k]}.
 * <p>
 * All examples are generated eagerly in {@link #postConfig()} from a {@link Random} seeded with
 * {@code seed}, so a given configuration always yields the same examples.
 */
public final class MultiLabelGaussianDataSource implements ConfigurableDataSource<MultiLabel> {

    private static final String[] FEATURE_NAMES = {"X_0", "X_1", "X_2", "X_3"};
    private static final String[] LABEL_NAMES = {"Y_0", "Y_1", "Y_2"};

    /** Number of features; every weight and bound array must have exactly this many entries. */
    private static final int NUM_FEATURES = FEATURE_NAMES.length;

    /** Number of labels; {@code threshold} and {@code negate} must have exactly this many entries. */
    private static final int NUM_LABELS = LABEL_NAMES.length;

    @Config(mandatory = true, description = "The number of samples to draw.")
    private int numSamples;

    @Config(description = "The feature weights. Must be a 4 element array.")
    private float[] yZeroWeights = new float[]{1.0f, 1.0f, 1.0f, 1.0f};

    @Config(description = "The feature weights. Must be a 4 element array.")
    private float[] yOneWeights = new float[]{1.0f, 1.0f, 1.0f, 1.0f};

    @Config(description = "The feature weights. Must be a 4 element array.")
    private float[] yTwoWeights = new float[]{1.0f, -3.0f, 1.0f, 3.0f};

    @Config(description = "The threshold for each class.")
    private float[] threshold = new float[]{0.0f, 0.0f, 2.0f};

    @Config(description = "Negate the computed value before thresholding it.")
    private boolean[] negate = new boolean[]{false, true, false};

    // NOTE: this value multiplies a standard normal draw, so it acts as the noise *standard
    // deviation* (the resulting noise variance is variance^2). Kept as-is so seeded datasets
    // remain reproducible.
    @Config(description = "The variance of the noise gaussian.")
    private float variance = 0.1f;

    @Config(description = "The minimum values of the Xs.")
    private float[] xMin = new float[]{-2.0f, -2.0f, -2.0f, -2.0f};

    @Config(description = "The maximum values of the Xs.")
    private float[] xMax = new float[]{2.0f, 2.0f, 2.0f, 2.0f};

    @Config(description = "The RNG seed.")
    private long seed = 12345L;

    private final MultiLabelFactory factory = new MultiLabelFactory();

    /** The generated examples; populated (as an unmodifiable list) by {@link #postConfig()}. */
    private List<Example<MultiLabel>> examples;

    /**
     * Provenance for {@link MultiLabelGaussianDataSource}, recording its configured fields.
     */
    public static class MultiLabelGaussianDataSourceProvenance extends SkeletalConfiguredObjectProvenance implements ConfiguredDataSourceProvenance {
        private static final long serialVersionUID = 1L;

        /**
         * Builds the provenance from a live data source.
         * @param host The data source to describe.
         */
        MultiLabelGaussianDataSourceProvenance(MultiLabelGaussianDataSource host) {
            super(host, "DataSource");
        }

        /**
         * Deserialization constructor.
         * @param map The provenance fields.
         */
        public MultiLabelGaussianDataSourceProvenance(Map<String, Provenance> map) {
            this(extractProvenanceInfo(map));
        }

        private MultiLabelGaussianDataSourceProvenance(SkeletalConfiguredObjectProvenance.ExtractedInfo info) {
            super(info);
        }

        /**
         * Extracts the class name and host type from the supplied map. The map is copied first so
         * the caller's map is not modified by the extraction.
         * @param map The provenance fields.
         * @return The extracted information.
         */
        protected static SkeletalConfiguredObjectProvenance.ExtractedInfo extractProvenanceInfo(Map<String, Provenance> map) {
            Map<String, Provenance> configuredParameters = new HashMap<>(map);
            String provenanceClassName = MultiLabelGaussianDataSourceProvenance.class.getSimpleName();
            String className = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "class-name", StringProvenance.class, provenanceClassName).getValue();
            String hostTypeStringName = ObjectProvenance.checkAndExtractProvenance(configuredParameters, "host-short-name", StringProvenance.class, provenanceClassName).getValue();
            return new SkeletalConfiguredObjectProvenance.ExtractedInfo(className, hostTypeStringName, configuredParameters, Collections.emptyMap());
        }
    }

    /**
     * No-argument constructor; fields keep their defaults until assigned. Used by
     * {@link #makeDefaultSource(int, long)} (and presumably by the OLCUT configuration system,
     * which is expected to call {@link #postConfig()} afterwards — confirm).
     */
    private MultiLabelGaussianDataSource() {}

    /**
     * Generates a data source with the supplied parameters. All arrays are defensively copied.
     * @param numSamples The number of samples to draw; must be non-negative.
     * @param yZeroWeights Weights for the {@code Y_0} score (4 elements).
     * @param yOneWeights Weights for the {@code Y_1} score (4 elements).
     * @param yTwoWeights Weights for the {@code Y_2} score (4 elements).
     * @param threshold Per-label thresholds (3 elements).
     * @param negate Per-label flags negating the score before thresholding (3 elements).
     * @param variance Scale of the additive Gaussian noise; must be positive.
     * @param xMin Per-feature minimums (4 elements).
     * @param xMax Per-feature maximums (4 elements), each no smaller than the matching minimum.
     * @param seed The RNG seed.
     * @throws PropertyException If any parameter is null or fails validation.
     */
    public MultiLabelGaussianDataSource(int numSamples, float[] yZeroWeights, float[] yOneWeights, float[] yTwoWeights,
                                        float[] threshold, boolean[] negate, float variance,
                                        float[] xMin, float[] xMax, long seed) {
        this.numSamples = numSamples;
        this.yZeroWeights = defensiveCopy(yZeroWeights);
        this.yOneWeights = defensiveCopy(yOneWeights);
        this.yTwoWeights = defensiveCopy(yTwoWeights);
        this.threshold = defensiveCopy(threshold);
        this.negate = defensiveCopy(negate);
        this.variance = variance;
        this.xMin = defensiveCopy(xMin);
        this.xMax = defensiveCopy(xMax);
        this.seed = seed;
        postConfig();
    }

    /**
     * Validates the configuration and eagerly generates all examples.
     * @throws PropertyException If the configuration is invalid.
     */
    @Override
    public void postConfig() {
        validateConfig();
        float[] range = new float[NUM_FEATURES];
        for (int i = 0; i < NUM_FEATURES; i++) {
            range[i] = xMax[i] - xMin[i];
        }
        // java.util.Random is used because it provides nextGaussian().
        Random rng = new Random(seed);
        List<Example<MultiLabel>> generated = new ArrayList<>(numSamples);
        for (int i = 0; i < numSamples; i++) {
            generated.add(sampleExample(rng, range));
        }
        examples = Collections.unmodifiableList(generated);
    }

    /** Checks every configured field, throwing {@link PropertyException} on the first problem found. */
    private void validateConfig() {
        if (numSamples < 0) {
            throw new PropertyException("", "numSamples", "numSamples must be non-negative, found " + numSamples);
        }
        checkNotNull(yZeroWeights, "yZeroWeights");
        checkNotNull(yOneWeights, "yOneWeights");
        checkNotNull(yTwoWeights, "yTwoWeights");
        checkNotNull(threshold, "threshold");
        checkNotNull(negate, "negate");
        checkNotNull(xMin, "xMin");
        checkNotNull(xMax, "xMax");
        checkLength("yZeroWeights", yZeroWeights.length, NUM_FEATURES, "yZeroWeights");
        checkLength("yOneWeights", yOneWeights.length, NUM_FEATURES, "yOneWeights");
        checkLength("yTwoWeights", yTwoWeights.length, NUM_FEATURES, "yTwoWeights");
        checkLength("threshold", threshold.length, NUM_LABELS, "values for threshold");
        checkLength("negate", negate.length, NUM_LABELS, "values for negate");
        checkLength("xMin", xMin.length, NUM_FEATURES, "feature minimums");
        checkLength("xMax", xMax.length, NUM_FEATURES, "feature maximums");
        for (int i = 0; i < NUM_FEATURES; i++) {
            // Written as !(min <= max) so NaN bounds are rejected as well.
            if (!(xMin[i] <= xMax[i])) {
                throw new PropertyException("", "xMin", "Feature minimums must be below the maximums, found min = " + Arrays.toString(xMin) + " and max = " + Arrays.toString(xMax));
            }
        }
        // Written as !(variance > 0) so NaN is rejected as well.
        if (!(variance > 0.0f)) {
            throw new PropertyException("", "variance", "Variance must be positive, found variance = " + variance);
        }
    }

    /** Throws a {@link PropertyException} naming {@code fieldName} if {@code value} is null. */
    private static void checkNotNull(Object value, String fieldName) {
        if (value == null) {
            throw new PropertyException("", fieldName, fieldName + " must not be null");
        }
    }

    /** Throws a {@link PropertyException} if {@code actual != expected}, describing the entries as {@code description}. */
    private static void checkLength(String fieldName, int actual, int expected, String description) {
        if (actual != expected) {
            throw new PropertyException("", fieldName, "Must supply " + expected + " " + description + ", found " + actual);
        }
    }

    /**
     * Draws one example. The RNG call order (4 x nextDouble for the features, then nextGaussian
     * for Y_0, Y_1, Y_2 in turn) must not change, or seeded datasets would differ.
     */
    private Example<MultiLabel> sampleExample(Random rng, float[] range) {
        double[] xs = new double[NUM_FEATURES];
        for (int j = 0; j < xs.length; j++) {
            xs[j] = (rng.nextDouble() * range[j]) + xMin[j];
        }
        double yZero = (rng.nextGaussian() * variance) + (yZeroWeights[0] * xs[0]) + (yZeroWeights[1] * xs[1])
                + (yZeroWeights[2] * xs[0] * xs[1]) + (yZeroWeights[3] * Math.pow(xs[1], 3.0d));
        double yOne = (rng.nextGaussian() * variance) + (yOneWeights[0] * xs[0]) + (yOneWeights[1] * xs[1])
                + (yOneWeights[2] * xs[0] * xs[1]) + (yOneWeights[3] * Math.pow(xs[1], 3.0d));
        double yTwo = (rng.nextGaussian() * variance) + (yTwoWeights[0] * xs[0]) + (yTwoWeights[1] * xs[2])
                + (yTwoWeights[2] * xs[0] * xs[1]) + (yTwoWeights[3] * xs[1] * xs[2] * xs[2]);

        double[] scores = {yZero, yOne, yTwo};
        Set<Label> labels = new HashSet<>();
        for (int k = 0; k < NUM_LABELS; k++) {
            double score = negate[k] ? -scores[k] : scores[k];
            if (score > threshold[k]) {
                labels.add(new Label(LABEL_NAMES[k]));
            }
        }
        return new ArrayExample<>(new MultiLabel(labels), FEATURE_NAMES, xs);
    }

    /** Returns a copy of {@code array}, or null if it is null (null is rejected by validation). */
    private static float[] defensiveCopy(float[] array) {
        return array == null ? null : array.clone();
    }

    /** Returns a copy of {@code array}, or null if it is null (null is rejected by validation). */
    private static boolean[] defensiveCopy(boolean[] array) {
        return array == null ? null : array.clone();
    }

    @Override
    public OutputFactory<MultiLabel> getOutputFactory() {
        return factory;
    }

    @Override
    public ConfiguredDataSourceProvenance getProvenance() {
        return new MultiLabelGaussianDataSourceProvenance(this);
    }

    /**
     * Legacy alias for {@link #getProvenance()}, retained so existing callers keep compiling.
     * @return The provenance of this data source.
     * @deprecated Use {@link #getProvenance()}.
     */
    @Deprecated
    public DataSourceProvenance m31getProvenance() {
        return getProvenance();
    }

    /**
     * Iterates the generated examples. The backing list is unmodifiable, so {@code remove} is unsupported.
     * @return An iterator over the examples.
     */
    @Override
    public Iterator<Example<MultiLabel>> iterator() {
        return examples.iterator();
    }

    /**
     * Generates a {@link MutableDataset} from a freshly constructed data source.
     * @param numSamples The number of samples to draw.
     * @param yZeroWeights Weights for the {@code Y_0} score (4 elements).
     * @param yOneWeights Weights for the {@code Y_1} score (4 elements).
     * @param yTwoWeights Weights for the {@code Y_2} score (4 elements).
     * @param threshold Per-label thresholds (3 elements).
     * @param negate Per-label negation flags (3 elements).
     * @param variance Scale of the additive Gaussian noise.
     * @param xMin Per-feature minimums (4 elements).
     * @param xMax Per-feature maximums (4 elements).
     * @param seed The RNG seed.
     * @return A dataset containing the generated examples.
     */
    public static MutableDataset<MultiLabel> generateDataset(int numSamples, float[] yZeroWeights, float[] yOneWeights,
                                                             float[] yTwoWeights, float[] threshold, boolean[] negate,
                                                             float variance, float[] xMin, float[] xMax, long seed) {
        return new MutableDataset<>(new MultiLabelGaussianDataSource(numSamples, yZeroWeights, yOneWeights,
                yTwoWeights, threshold, negate, variance, xMin, xMax, seed));
    }

    /**
     * Creates a data source using the default weights, thresholds, negation flags, noise and bounds.
     * @param numSamples The number of samples to draw.
     * @param seed The RNG seed.
     * @return A fully generated data source.
     */
    public static MultiLabelGaussianDataSource makeDefaultSource(int numSamples, long seed) {
        MultiLabelGaussianDataSource source = new MultiLabelGaussianDataSource();
        source.numSamples = numSamples;
        source.seed = seed;
        source.postConfig();
        return source;
    }
}
