package org.tribuo.classification.explanations.lime;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SplittableRandom;
import java.util.logging.Logger;
import org.tribuo.CategoricalInfo;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Model;
import org.tribuo.MutableDataset;
import org.tribuo.OutputFactory;
import org.tribuo.Prediction;
import org.tribuo.RealInfo;
import org.tribuo.SparseModel;
import org.tribuo.SparseTrainer;
import org.tribuo.VariableIDInfo;
import org.tribuo.VariableInfo;
import org.tribuo.WeightedExamples;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.explanations.Explanation;
import org.tribuo.classification.explanations.TabularExplainer;
import org.tribuo.impl.ArrayExample;
import org.tribuo.interop.ExternalModel;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.provenance.SimpleDataSourceProvenance;
import org.tribuo.regression.RegressionFactory;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionEvaluator;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/classification/explanations/lime/LIMEBase.class */
/**
 * Base class for LIME (Local Interpretable Model-agnostic Explanations) explainers of Tribuo
 * classification models.
 * <p>
 * To explain an example, this class samples {@code numSamples} perturbed points using the
 * per-feature statistics stored in the inner model's feature map. It labels each point with the
 * inner model's per-class scores, packed into a multi-dimensional {@link Regressor}. Each point is
 * weighted by a kernel of its distance to the example, and a sparse regression model is fitted
 * to the weighted points.
 * <p>
 * Each explanation draws one seed from {@link #rng} via {@code nextLong()}. Because
 * {@link SplittableRandom} is not thread-safe, concurrent calls to {@code explain} on one instance
 * need external synchronisation.
 */
public class LIMEBase implements TabularExplainer<Regressor> {
    /** Multiplier applied to the feature count when computing {@link #kernelWidth}. */
    public static final double WIDTH_CONSTANT = 0.75d;
    /** Tolerance below which two values are treated as equal when measuring distance. */
    public static final double DISTANCE_DELTA = 1.0E-12d;

    /** Source of seeds for the per-explanation sampling RNG. */
    protected final SplittableRandom rng;
    /** The classification model being explained; it must generate probabilities. */
    protected final Model<Label> innerModel;
    /** Trainer for the sparse surrogate model; it must implement {@link WeightedExamples}. */
    protected final SparseTrainer<Regressor> explanationTrainer;
    /** Number of perturbed points sampled per explanation. */
    protected final int numSamples;
    /** Total number of observations recorded in the inner model's output domain. */
    protected final long numTrainingExamples;
    /** Kernel width, {@code (numFeatures * WIDTH_CONSTANT)^2}. */
    protected final double kernelWidth;
    /** The inner model's feature map, which supplies the per-feature sampling statistics. */
    private final ImmutableFeatureMap fMap;

    private static final Logger logger = Logger.getLogger(LIMEBase.class.getName());
    /** Factory for the surrogate model's regression outputs. */
    protected static final OutputFactory<Regressor> regressionFactory = new RegressionFactory();
    /** Evaluator used to score the surrogate model on the points it was trained on. */
    protected static final RegressionEvaluator evaluator = new RegressionEvaluator(true);

    /**
     * Constructs a LIME explainer.
     *
     * @param rng                Source of randomness; one seed is drawn from it per explanation.
     * @param innerModel         The classification model to explain.
     * @param explanationTrainer Trainer for the sparse surrogate model.
     * @param numSamples         Number of points to sample per explanation.
     * @throws IllegalArgumentException If the trainer does not implement {@link WeightedExamples},
     *                                  the model does not generate probabilities, or the model is
     *                                  an {@link ExternalModel}.
     */
    public LIMEBase(SplittableRandom rng, Model<Label> innerModel, SparseTrainer<Regressor> explanationTrainer, int numSamples) {
        if (!(explanationTrainer instanceof WeightedExamples)) {
            throw new IllegalArgumentException("SparseTrainer must implement WeightedExamples, found " + explanationTrainer.toString());
        }
        if (!innerModel.generatesProbabilities()) {
            throw new IllegalArgumentException("LIME requires the model generate probabilities.");
        }
        if (innerModel instanceof ExternalModel) {
            throw new IllegalArgumentException("LIME requires the model to have been trained in Tribuo. Found " + innerModel.getClass() + " which is an external model.");
        }
        this.rng = rng;
        this.innerModel = innerModel;
        this.explanationTrainer = explanationTrainer;
        this.numSamples = numSamples;
        this.numTrainingExamples = innerModel.getOutputIDInfo().getTotalObservations();
        this.fMap = innerModel.getFeatureIDMap();
        this.kernelWidth = Math.pow(fMap.size() * WIDTH_CONSTANT, 2.0d);
    }

    /**
     * Explains the inner model's prediction on {@code example}.
     *
     * @param example The example to explain.
     * @return A {@link LIMEExplanation} wrapping the fitted sparse surrogate model.
     */
    @Override
    public Explanation<Regressor> explain(Example<Label> example) {
        return explainWithSamples(example).getA();
    }

    /**
     * Explains {@code example} and also returns the sampled points used to fit the surrogate.
     *
     * @param example The example to explain.
     * @return A pair of the explanation and the weighted, labelled samples.
     */
    public Pair<LIMEExplanation, List<Example<Regressor>>> explainWithSamples(Example<Label> example) {
        // Label the example with the inner model's per-class scores, at full weight.
        Prediction<Label> prediction = innerModel.predict(example);
        Example<Regressor> labelledExample = new ArrayExample<>(transformOutput(prediction), example, 1.0f);

        List<Example<Regressor>> samples = sampleData(example);
        SparseModel<Regressor> surrogate = trainExplainer(labelledExample, samples);

        // Score the surrogate on exactly the points it was trained on: the samples plus the target.
        List<Prediction<Regressor>> surrogatePredictions = new ArrayList<>(surrogate.predict(samples));
        surrogatePredictions.add(surrogate.predict(labelledExample));

        LIMEExplanation explanation = new LIMEExplanation(surrogate, prediction,
                evaluator.evaluate(surrogate, surrogatePredictions, new SimpleDataSourceProvenance("LIMEColumnar sampled data", regressionFactory)));
        return new Pair<>(explanation, samples);
    }

    /**
     * Samples {@link #numSamples} points around {@code example}, labels them with the inner model,
     * and weights them by kernel distance to {@code example}.
     */
    private List<Example<Regressor>> sampleData(Example<Label> example) {
        List<Example<Regressor>> samples = new ArrayList<>();
        SparseVector exampleVector = SparseVector.createSparseVector(example, fMap, false);
        // One seed per explanation; the whole sampling pass uses this derived RNG.
        Random sampleRng = new Random(rng.nextLong());
        for (int i = 0; i < numSamples; i++) {
            Example<Label> sample = samplePoint(sampleRng, fMap, numTrainingExamples, exampleVector);
            Prediction<Label> samplePrediction = innerModel.predict(sample);
            double distance = measureDistance(fMap, numTrainingExamples, exampleVector,
                    SparseVector.createSparseVector(sample, fMap, false));
            double weight = kernelDist(distance, kernelWidth);
            samples.add(new ArrayExample<>(transformOutput(samplePrediction), sample, (float) weight));
        }
        return samples;
    }

    /**
     * Samples a single perturbed point using the feature statistics in {@code fMap}.
     * <p>
     * A categorical feature is drawn with {@link CategoricalInfo#frequencyBasedSample}. It is kept
     * only if the draw is non-zero (magnitude above 1e-10).
     * <p>
     * A real feature is present with probability {@code count / numTrainingExamples}. When present,
     * its value is {@code N(input value, variance)}.
     *
     * @param random              The RNG to draw from.
     * @param fMap                The feature map supplying per-feature statistics.
     * @param numTrainingExamples Total number of training observations.
     * @param input               The point being explained, as a sparse vector over {@code fMap}.
     * @return The sampled point, labelled with {@link LabelFactory#UNKNOWN_LABEL}.
     * @throws IllegalStateException If a feature is neither a {@link CategoricalInfo} nor a {@link RealInfo}.
     */
    public static Example<Label> samplePoint(Random random, ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input) {
        List<String> names = new ArrayList<>();
        List<Double> values = new ArrayList<>();
        for (VariableInfo info : fMap) {
            if (info instanceof CategoricalInfo) {
                double sample = ((CategoricalInfo) info).frequencyBasedSample(random, numTrainingExamples);
                if (Math.abs(sample) > 1.0E-10d) {
                    names.add(info.getName());
                    values.add(sample);
                }
            } else if (info instanceof RealInfo) {
                RealInfo realInfo = (RealInfo) info;
                double inputValue = input.get(((VariableIDInfo) info).getID());
                // Divide in floating point: int / long truncates to 0 for every feature that
                // was not observed in all training examples, so it would never be sampled.
                double presenceProbability = realInfo.getCount() / (double) numTrainingExamples;
                if (random.nextDouble() < presenceProbability) {
                    double sample = (random.nextGaussian() * Math.sqrt(realInfo.getVariance())) + inputValue;
                    names.add(info.getName());
                    values.add(sample);
                }
            } else {
                throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
            }
        }
        return new ArrayExample<>(LabelFactory.UNKNOWN_LABEL, names.toArray(new String[0]), Util.toPrimitiveDouble(values));
    }

    /**
     * Trains the sparse surrogate model on the target example plus the weighted samples.
     *
     * @param target  The example being explained, labelled with the inner model's scores.
     * @param samples The weighted, labelled samples.
     * @return The trained sparse regression model.
     */
    public SparseModel<Regressor> trainExplainer(Example<Regressor> target, List<Example<Regressor>> samples) {
        MutableDataset<Regressor> dataset = new MutableDataset<>(
                new SimpleDataSourceProvenance("explanationDataset", OffsetDateTime.now(), regressionFactory), regressionFactory);
        dataset.add(target);
        dataset.addAll(samples);
        return explanationTrainer.train(dataset);
    }

    /**
     * Exponential kernel: {@code sqrt(exp(-input^2 / width))}.
     *
     * @param input The distance.
     * @param width The kernel width.
     * @return The kernel weight, in (0, 1] for finite positive {@code width}.
     */
    public static double kernelDist(double input, double width) {
        return Math.sqrt(Math.exp(-(input * input) / width));
    }

    /**
     * Computes the distance between {@code input} and {@code sample}.
     * <p>
     * The result is the square root of the sum of per-feature contributions over every index
     * present in either vector. A feature present in both vectors contributes its normalised
     * difference. A feature present in only one vector contributes 1 if categorical, or
     * {@code (value / range)^2} if real.
     *
     * @param fMap                The feature map describing each index.
     * @param numTrainingExamples Total number of training observations.
     * @param input               The point being explained.
     * @param sample              The sampled point.
     * @return The distance, which is non-negative.
     * @throws IllegalStateException If a feature is neither a {@link CategoricalInfo} nor a {@link RealInfo}.
     */
    public static double measureDistance(ImmutableFeatureMap fMap, long numTrainingExamples, SparseVector input, SparseVector sample) {
        double score = 0.0d;
        VectorIterator inputItr = input.iterator();
        VectorIterator sampleItr = sample.iterator();
        VectorTuple inputTuple = nextOrNull(inputItr);
        VectorTuple sampleTuple = nextOrNull(sampleItr);
        // Two-pointer merge over ascending indices. Only the side with the smaller index advances,
        // so an index is never skipped past a matching index in the other vector.
        while (inputTuple != null && sampleTuple != null) {
            if (inputTuple.index == sampleTuple.index) {
                score += calculateSingleDistance(fMap, numTrainingExamples, inputTuple.index, inputTuple.value, sampleTuple.value);
                inputTuple = nextOrNull(inputItr);
                sampleTuple = nextOrNull(sampleItr);
            } else if (inputTuple.index < sampleTuple.index) {
                score += calculateSingleDistance(fMap, numTrainingExamples, inputTuple.index, inputTuple.value);
                inputTuple = nextOrNull(inputItr);
            } else {
                score += calculateSingleDistance(fMap, numTrainingExamples, sampleTuple.index, sampleTuple.value);
                sampleTuple = nextOrNull(sampleItr);
            }
        }
        while (inputTuple != null) {
            score += calculateSingleDistance(fMap, numTrainingExamples, inputTuple.index, inputTuple.value);
            inputTuple = nextOrNull(inputItr);
        }
        while (sampleTuple != null) {
            score += calculateSingleDistance(fMap, numTrainingExamples, sampleTuple.index, sampleTuple.value);
            sampleTuple = nextOrNull(sampleItr);
        }
        return Math.sqrt(score);
    }

    /** Returns the next tuple from {@code itr}, or {@code null} if it is exhausted. */
    private static VectorTuple nextOrNull(Iterator<VectorTuple> itr) {
        return itr.hasNext() ? itr.next() : null;
    }

    /** Distance contribution of a feature present in only one of the two vectors. */
    private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double value) {
        VariableInfo info = fMap.get(index);
        if (info instanceof CategoricalInfo) {
            return 1.0d;
        } else if (info instanceof RealInfo) {
            return normalisedSquaredDifference(value, (RealInfo) info, numTrainingExamples);
        } else {
            throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
        }
    }

    /** Distance contribution of a feature present in both vectors. */
    private static double calculateSingleDistance(ImmutableFeatureMap fMap, long numTrainingExamples, int index, double inputValue, double sampleValue) {
        VariableInfo info = fMap.get(index);
        if (info instanceof CategoricalInfo) {
            return Math.abs(inputValue - sampleValue) > DISTANCE_DELTA ? 1.0d : 0.0d;
        } else if (info instanceof RealInfo) {
            return normalisedSquaredDifference(inputValue - sampleValue, (RealInfo) info, numTrainingExamples);
        } else {
            throw new IllegalStateException("Unsupported info type, expected CategoricalInfo or RealInfo, found " + info.getClass().getName());
        }
    }

    /**
     * Returns {@code (difference / range)^2}, where {@code range} is the feature's observed range.
     * <p>
     * The range is widened to include 0 when the feature's count differs from
     * {@code numTrainingExamples}. This presumably accounts for the implicit zeros in examples that
     * lack the feature. A zero (or negative) range would divide by zero and produce NaN, so a
     * constant feature contributes 0 when the values match and 1 otherwise.
     */
    private static double normalisedSquaredDifference(double difference, RealInfo info, long numTrainingExamples) {
        double range = (numTrainingExamples != info.getCount())
                ? Math.max(info.getMax(), 0.0d) - Math.min(info.getMin(), 0.0d)
                : info.getMax() - info.getMin();
        if (range <= 0.0d) {
            return Math.abs(difference) > DISTANCE_DELTA ? 1.0d : 0.0d;
        }
        double scaled = difference / range;
        return scaled * scaled;
    }

    /**
     * Converts a classification prediction into a multi-dimensional {@link Regressor}. Each
     * dimension is named by a label and holds that label's score.
     *
     * @param prediction The classification prediction.
     * @return A regressor with one dimension per entry in {@link Prediction#getOutputScores()}.
     */
    public static Regressor transformOutput(Prediction<Label> prediction) {
        Map<String, Label> scores = prediction.getOutputScores();
        String[] names = new String[scores.size()];
        double[] values = new double[scores.size()];
        int i = 0;
        for (Map.Entry<String, Label> entry : scores.entrySet()) {
            names[i] = entry.getKey();
            values[i] = entry.getValue().getScore();
            i++;
        }
        return new Regressor(names, values);
    }

    /**
     * Raw-typed alias of {@link #explain(Example)}. It was produced by the decompiler renaming a
     * synthetic bridge method, and it overrides nothing.
     *
     * @param example The example to explain; assumed to carry a {@link Label} output.
     * @return The explanation from {@link #explain(Example)}.
     * @deprecated Call {@link #explain(Example)} directly.
     */
    @Deprecated
    @SuppressWarnings({"unchecked", "rawtypes"})
    public Explanation<Regressor> explain2(Example example) {
        return explain((Example<Label>) example);
    }
}
