package org.tribuo.transform.transformations;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.provenance.ObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.Provenance;
import com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.tribuo.transform.TransformStatistics;
import org.tribuo.transform.Transformation;
import org.tribuo.transform.TransformationProvenance;
import org.tribuo.transform.Transformer;

/* loaded from: input_file:org/tribuo/transform/transformations/MeanStdDevTransformation.class */
public final class MeanStdDevTransformation implements Transformation {
    private static final String TARGET_MEAN = "targetMean";
    private static final String TARGET_STDDEV = "targetStdDev";

    @Config(mandatory = true, description = "Mean value after transformation.")
    private double targetMean;

    @Config(mandatory = true, description = "Standard deviation after transformation.")
    private double targetStdDev;
    private MeanStdDevTransformationProvenance provenance;

    /**
     * Accumulates the running mean and the running sum of squared deviations of the
     * values passed to {@link #observeValue(double)}, and builds the transformer that
     * rescales values to the requested target mean and standard deviation.
     */
    private static class MeanStdDevStatistics implements TransformStatistics {
        private final double targetMean;
        private final double targetStdDev;
        // Running mean of the observed values.
        private double mean = 0.0d;
        // Running sum of squared deviations from the mean.
        private double sumSquares = 0.0d;
        // Number of calls to observeValue so far.
        private long count = 0;

        /**
         * Creates an empty accumulator.
         *
         * @param targetMean   mean handed on to the generated transformer.
         * @param targetStdDev standard deviation handed on to the generated transformer.
         */
        public MeanStdDevStatistics(double targetMean, double targetStdDev) {
            this.targetMean = targetMean;
            this.targetStdDev = targetStdDev;
        }

        /**
         * Folds one value into the running statistics using Welford's single-pass update.
         *
         * @param d the observed value.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public void observeValue(double d) {
            this.count++;
            // Deviation from the mean *before* it is updated.
            double delta = d - this.mean;
            this.mean += delta / this.count;
            // Multiplied by the deviation from the mean *after* the update.
            this.sumSquares += delta * (d - this.mean);
        }

        /**
         * Does nothing: sparse observations do not contribute to these statistics.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public void observeSparse() {
        }

        /**
         * Does nothing: sparse observations do not contribute to these statistics.
         *
         * @param i the number of sparse observations (ignored).
         */
        @Override // org.tribuo.transform.TransformStatistics
        public void observeSparse(int i) {
        }

        /**
         * Builds a transformer from the observed mean and the sample standard deviation
         * {@code sqrt(sumSquares / (count - 1))}.
         * <p>
         * With fewer than two observed values that expression does not yield a positive
         * number (it is NaN after exactly one finite value), so the returned transformer
         * will not produce meaningful output in that case.
         *
         * @return the transformer for the statistics gathered so far.
         */
        @Override // org.tribuo.transform.TransformStatistics
        public Transformer generateTransformer() {
            double observedStdDev = Math.sqrt(this.sumSquares / (this.count - 1));
            return new MeanStdDevTransformer(this.mean, observedStdDev, this.targetMean, this.targetStdDev);
        }

        @Override
        public String toString() {
            // Every field is comma separated, including the boundary between count and targetMean.
            return "MeanStdDevStatistics(mean=" + this.mean + ",sumSquares=" + this.sumSquares + ",count=" + this.count + ",targetMean=" + this.targetMean + ",targetStdDev=" + this.targetStdDev + ")";
        }
    }

    /**
     * Provenance for {@link MeanStdDevTransformation}, recording the configured target
     * mean and target standard deviation.
     */
    public static final class MeanStdDevTransformationProvenance implements TransformationProvenance {
        private static final long serialVersionUID = 1;
        private final DoubleProvenance targetMean;
        private final DoubleProvenance targetStdDev;

        /**
         * Captures the current target values of the supplied transformation.
         *
         * @param host the transformation whose configuration is recorded.
         */
        MeanStdDevTransformationProvenance(MeanStdDevTransformation host) {
            this.targetMean = new DoubleProvenance(MeanStdDevTransformation.TARGET_MEAN, host.targetMean);
            this.targetStdDev = new DoubleProvenance(MeanStdDevTransformation.TARGET_STDDEV, host.targetStdDev);
        }

        /**
         * Rebuilds the provenance from a map keyed by field name.
         *
         * @param map the map holding the {@code targetMean} and {@code targetStdDev} entries.
         */
        public MeanStdDevTransformationProvenance(Map<String, Provenance> map) {
            this.targetMean = ObjectProvenance.checkAndExtractProvenance(map, MeanStdDevTransformation.TARGET_MEAN, DoubleProvenance.class, MeanStdDevTransformationProvenance.class.getSimpleName());
            this.targetStdDev = ObjectProvenance.checkAndExtractProvenance(map, MeanStdDevTransformation.TARGET_STDDEV, DoubleProvenance.class, MeanStdDevTransformationProvenance.class.getSimpleName());
        }

        /**
         * @return the fully qualified name of {@link MeanStdDevTransformation}.
         */
        public String getClassName() {
            return MeanStdDevTransformation.class.getName();
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof MeanStdDevTransformationProvenance)) {
                return false;
            }
            MeanStdDevTransformationProvenance other = (MeanStdDevTransformationProvenance) obj;
            return this.targetMean.equals(other.targetMean) && this.targetStdDev.equals(other.targetStdDev);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.targetMean, this.targetStdDev);
        }

        /**
         * @return an unmodifiable map from field name to the recorded provenance value.
         */
        public Map<String, Provenance> getConfiguredParameters() {
            // Typed map rather than a raw HashMap so the puts are checked by the compiler.
            Map<String, Provenance> params = new HashMap<>();
            params.put(MeanStdDevTransformation.TARGET_MEAN, this.targetMean);
            params.put(MeanStdDevTransformation.TARGET_STDDEV, this.targetStdDev);
            return Collections.unmodifiableMap(params);
        }
    }

    /**
     * Standardises a value with the observed mean and standard deviation, then rescales
     * it to the target standard deviation and shifts it to the target mean.
     */
    private static class MeanStdDevTransformer implements Transformer {
        private static final long serialVersionUID = 1;
        private final double observedMean;
        private final double observedStdDev;
        private final double targetMean;
        private final double targetStdDev;

        /**
         * @param observedMean   mean measured on the data.
         * @param observedStdDev standard deviation measured on the data.
         * @param targetMean     mean of the output.
         * @param targetStdDev   standard deviation of the output.
         */
        public MeanStdDevTransformer(double observedMean, double observedStdDev, double targetMean, double targetStdDev) {
            this.observedMean = observedMean;
            this.observedStdDev = observedStdDev;
            this.targetMean = targetMean;
            this.targetStdDev = targetStdDev;
        }

        /**
         * Computes {@code ((d - observedMean) / observedStdDev) * targetStdDev + targetMean}.
         *
         * @param d the input value.
         * @return the rescaled value.
         */
        @Override // org.tribuo.transform.Transformer
        public double transform(double d) {
            final double centred = d - observedMean;
            final double standardised = centred / observedStdDev;
            final double rescaled = standardised * targetStdDev;
            return rescaled + targetMean;
        }

        public String toString() {
            return "MeanStdDevTransformer(observedMean=" + observedMean
                    + ",observedStdDev=" + observedStdDev
                    + ",targetMean=" + targetMean
                    + ",targetStdDev=" + targetStdDev
                    + ")";
        }
    }

    /**
     * Creates a transformation with the default targets: mean 0 and standard deviation 1.
     * This constructor does not call {@link #postConfig()}.
     */
    public MeanStdDevTransformation() {
        this.targetStdDev = 1.0;
        this.targetMean = 0.0;
    }

    /**
     * Creates a transformation that rescales values to the supplied mean and standard
     * deviation, validating the standard deviation via {@link #postConfig()}.
     *
     * @param targetMean   mean of the transformed values.
     * @param targetStdDev standard deviation of the transformed values.
     * @throws IllegalArgumentException if {@link #postConfig()} rejects the standard deviation.
     */
    public MeanStdDevTransformation(double targetMean, double targetStdDev) {
        // The defaults are not assigned first: they would be overwritten immediately.
        this.targetMean = targetMean;
        this.targetStdDev = targetStdDev;
        postConfig();
    }

    /**
     * Validates the configured target standard deviation. Called by the two-argument
     * constructor; presumably also invoked by the configuration system once the
     * {@code @Config} fields are populated (confirm against the configuration framework).
     *
     * @throws IllegalArgumentException if the target standard deviation is NaN or below 1e-12.
     */
    public void postConfig() {
        // Written as a negated >= so that NaN, for which every ordered comparison is false, is rejected too.
        if (!(this.targetStdDev >= 1.0E-12d)) {
            throw new IllegalArgumentException("Target standard deviation must be positive, found " + this.targetStdDev);
        }
    }

    /**
     * Creates a fresh statistics accumulator carrying this transformation's target
     * mean and standard deviation.
     *
     * @return a new, empty statistics object.
     */
    @Override // org.tribuo.transform.Transformation
    public TransformStatistics createStats() {
        final double mean = targetMean;
        final double stdDev = targetStdDev;
        return new MeanStdDevStatistics(mean, stdDev);
    }

    /**
     * Returns the provenance of this transformation, creating it on first use and
     * caching it afterwards. The cached value reflects the target mean and standard
     * deviation at the time of the first call.
     *
     * @return the provenance describing the configured targets.
     */
    @Override
    public TransformationProvenance getProvenance() {
        if (this.provenance == null) {
            this.provenance = new MeanStdDevTransformationProvenance(this);
        }
        return this.provenance;
    }

    /**
     * Alias kept for callers compiled against the decompiler-mangled method name.
     *
     * @return the same value as {@link #getProvenance()}.
     * @deprecated use {@link #getProvenance()}.
     */
    @Deprecated
    public TransformationProvenance m56getProvenance() {
        return getProvenance();
    }

    /**
     * @return a description of this transformation listing its target mean and standard deviation.
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("MeanStdDevTransformation(targetMean=");
        text.append(targetMean);
        text.append(",targetStdDev=");
        text.append(targetStdDev);
        text.append(")");
        return text.toString();
    }
}
