package weka.classifiers.functions;

import java.util.Collections;
import java.util.Enumeration;
import java.util.Vector;
import org.apache.commons.math3.analysis.UnivariateFunction;
import org.apache.commons.math3.analysis.solvers.BisectionSolver;
import weka.classifiers.Classifier;
import weka.classifiers.SingleClassifierEnhancer;
import weka.classifiers.bayes.NaiveBayes;
import weka.classifiers.functions.explicitboundaries.combiners.PotentialFunction;
import weka.classifiers.functions.explicitboundaries.combiners.PotentialFunctionExp4;
import weka.classifiers.functions.explicitboundaries.combiners.PotentialFunctionTanh;
import weka.classifiers.functions.nearestCentroid.IClusterPrototype;
import weka.classifiers.functions.nearestCentroid.prototypes.CustomizablePrototype;
import weka.classifiers.functions.nearestCentroid.prototypes.MahalanobisPrototype;
import weka.classifiers.rules.ZeroR;
import weka.clusterers.ClassSpecificClusterer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.UtilsPT;
import weka.tools.SerialCopier;
import weka.tools.arrayFunctions.MeanFunction;
import weka.tools.arrayFunctions.MultivariateFunction;
import weka.tools.data.InstancesOperator;
import weka.tools.data.splitters.CopySplitter;
import weka.tools.data.splitters.DataSplitter;

/* loaded from: input_file:weka/classifiers/functions/BoundaryAndCentroidsClassifierMulticlass.class */
/**
 * Classifier that blends the class distribution of a wrapped base classifier with a
 * distribution derived from class-specific cluster prototypes.
 *
 * <p>Training ({@link #buildClassifier(Instances)}) splits the data with the configured
 * {@link DataSplitter}; the first partition trains the base classifier and the second one is
 * clustered per class. Every cluster gets its own {@link IClusterPrototype} and a multiplier
 * that rescales prototype distances before they are fed to the {@link PotentialFunction}.
 *
 * <p>Prediction ({@link #distributionForInstance(Instance)}) mixes both sources as
 * {@code proportion * base + (1 - proportion) * clusterResponse}. If any class in any
 * partition has at most one unique object, a {@link ZeroR} model is used instead.
 */
public class BoundaryAndCentroidsClassifierMulticlass extends SingleClassifierEnhancer {
    private static final long serialVersionUID = 8010565844371661224L;

    /** Template that is copied for every class-specific cluster. */
    protected IClusterPrototype prototypeProto;

    /** Built prototypes, indexed as {@code [class][cluster]}. */
    protected IClusterPrototype[][] classProtos;

    /** Distance multipliers found by bisection, indexed as {@code [class][cluster]}. */
    protected double[][] clusterPotentialArgumentMultipliers;

    /** Potential function applied to the rescaled prototype distance. */
    protected PotentialFunction potFunction;

    /** Weight of the base classifier in the final mixture; kept within {@code [0, 1]} by the setter. */
    protected double proportion;

    /** Set when the cluster data has a single attribute and a class index; only the base classifier is used then. */
    protected boolean classesOnly;

    /** Threshold factor: sum-normalisation is used only when the response sum exceeds {@code 5 * eps}. */
    protected double eps;

    /** Fallback model; non-null only when training data was too small for the regular model. */
    protected ZeroR defaultModel;

    /** Result of {@code InstancesOperator.classFreq}; this class only reads its length. */
    protected double[] classFreqs;

    /** Value of the {@code -UP} flag. It is stored and reported but not read anywhere in this class. */
    protected boolean usePriors;

    /** Clusterer that produces the class-specific clusters. */
    protected ClassSpecificClusterer classSpecClusterer;

    /** Splits training data into the base-classifier part (index 0) and the cluster part (index 1). */
    protected DataSplitter dataSplitter;

    /** Quantile of within-cluster prototype distances used to calibrate the multiplier. */
    protected double quantile;

    /** Target value of {@code 1 - potential} at the quantile distance. */
    protected double quantilePotentialVal;

    /** Lower end of the bisection search interval for the multiplier. */
    protected double minSearch;

    /** Upper end of the bisection search interval for the multiplier. */
    protected double maxSearch;

    /** Evaluation budget handed to the bisection solver. */
    protected int nBisectIterations;

    /** Combines the per-cluster responses of one class into a single class response. */
    protected MultivariateFunction clusterCombiner;

    /** When set, class responses are always normalised with soft-max. */
    protected boolean useSoftMax;

    /**
     * Creates the classifier around the given base classifier with default settings.
     *
     * @param classifier the base classifier to wrap
     */
    public BoundaryAndCentroidsClassifierMulticlass(Classifier classifier) {
        this.proportion = 0.5d;
        this.classesOnly = false;
        this.eps = Double.MIN_VALUE;
        this.usePriors = false;
        this.quantile = 0.9d;
        this.quantilePotentialVal = 0.1d;
        // NOTE(review): setOptions falls back to 0.001 for -MiPV while this default is 1.0E-6;
        // both values are kept as they were, confirm which one is intended.
        this.minSearch = 1.0E-6d;
        this.maxSearch = 5.0d;
        this.nBisectIterations = 1000;
        // NOTE(review): -USM is a plain flag, so any setOptions call without it switches this off.
        this.useSoftMax = true;
        setClassifier(classifier);
        this.prototypeProto = new CustomizablePrototype();
        this.potFunction = new PotentialFunctionTanh();
        this.classSpecClusterer = new ClassSpecificClusterer();
        this.dataSplitter = new CopySplitter();
        this.clusterCombiner = new MeanFunction();
    }

    /** Creates the classifier with {@link NaiveBayes} as the base classifier. */
    public BoundaryAndCentroidsClassifierMulticlass() {
        this(new NaiveBayes());
    }

    /**
     * Trains the base classifier and the cluster prototypes.
     *
     * <p>If any class in any partition has at most one unique object, only a {@link ZeroR}
     * model is built on the full data and the method returns.
     *
     * @param instances the training data
     * @throws IllegalStateException if the data splitter yields fewer than two partitions
     * @throws Exception if capability testing, splitting or any model building fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        if (!this.m_DoNotCheckCapabilities) {
            getCapabilities().testWithFail(instances);
        }
        this.classFreqs = InstancesOperator.classFreq(instances);
        this.defaultModel = null;
        this.dataSplitter.train(instances);
        Instances[] partitions = this.dataSplitter.split(instances);
        for (Instances partition : partitions) {
            for (int uniqueCount : InstancesOperator.uniqObjPerClass(partition)) {
                if (uniqueCount <= 1) {
                    this.defaultModel = new ZeroR();
                    this.defaultModel.buildClassifier(instances);
                    return;
                }
            }
        }
        // Checked only after the fallback scan so the ZeroR path behaves exactly as before.
        if (partitions.length < 2) {
            throw new IllegalStateException(
                    "Data splitter must produce at least two partitions, got " + partitions.length);
        }
        buildBaseClassifier(partitions[0]);
        buildClusters(partitions[1]);
    }

    /**
     * Builds the class-specific clusters, their prototypes and the distance multipliers.
     *
     * <p>When the data has exactly one attribute and a class index is set, {@link #classesOnly}
     * is switched on and no prototypes are built.
     *
     * @param instances the data used for clustering
     * @throws Exception if clustering, prototype building or multiplier calibration fails
     */
    protected void buildClusters(Instances instances) throws Exception {
        this.classSpecClusterer.buildClusterer(instances);
        this.classesOnly = instances.numAttributes() == 1 && instances.classIndex() >= 0;
        if (this.classesOnly) {
            return;
        }

        int numClasses = instances.numClasses();
        int[] clusterCounts = this.classSpecClusterer.numberOfClassSpecificClusters();
        Instances[] perClassData = InstancesOperator.classSpecSplit(instances);

        // Phase 1: allocate an empty member set and a fresh prototype copy per cluster.
        Instances[][] clusterMembers = new Instances[numClasses][];
        this.classProtos = new IClusterPrototype[numClasses][];
        for (int c = 0; c < numClasses; c++) {
            int clusters = clusterCounts[c];
            clusterMembers[c] = new Instances[clusters];
            this.classProtos[c] = new IClusterPrototype[clusters];
            for (int k = 0; k < clusters; k++) {
                clusterMembers[c][k] = new Instances(perClassData[c], 0);
                this.classProtos[c][k] = (IClusterPrototype) SerialCopier.makeCopy(this.prototypeProto);
            }
        }

        // Phase 2: put every instance into the highest-scoring cluster of its own class.
        for (int c = 0; c < numClasses; c++) {
            Instances classData = perClassData[c];
            int count = classData.numInstances();
            for (int n = 0; n < count; n++) {
                Instance member = classData.get(n);
                int cluster = Utils.maxIndex(
                        this.classSpecClusterer.classSpecificDistributionForInstance(member)[c]);
                clusterMembers[c][cluster].add(member);
            }
        }

        // Phase 3: fit every prototype on the members of its cluster.
        for (int c = 0; c < numClasses; c++) {
            for (int k = 0; k < clusterCounts[c]; k++) {
                this.classProtos[c][k].build(clusterMembers[c][k]);
            }
        }

        // Phase 4: calibrate one distance multiplier per cluster.
        this.clusterPotentialArgumentMultipliers = new double[numClasses][];
        for (int c = 0; c < numClasses; c++) {
            int clusters = clusterCounts[c];
            this.clusterPotentialArgumentMultipliers[c] = new double[clusters];
            for (int k = 0; k < clusters; k++) {
                Instances members = clusterMembers[c][k];
                double[] distances = new double[members.numInstances()];
                for (int n = 0; n < distances.length; n++) {
                    distances[n] = this.classProtos[c][k].distance(members.get(n));
                }
                double quantileDistance = UtilsPT.quantile(distances, this.quantile);
                this.clusterPotentialArgumentMultipliers[c][k] = solveMultiplier(quantileDistance);
            }
        }
    }

    /**
     * Searches, by bisection over {@code [minSearch, maxSearch]}, for the multiplier {@code m}
     * at which {@code 1 - potential(quantileDistance * m)} equals {@link #quantilePotentialVal}.
     */
    private double solveMultiplier(final double quantileDistance) {
        UnivariateFunction objective = new UnivariateFunction() {
            @Override
            public double value(double multiplier) {
                try {
                    return (1.0d - BoundaryAndCentroidsClassifierMulticlass.this.potFunction
                                    .getPotentialValue(quantileDistance * multiplier))
                            - BoundaryAndCentroidsClassifierMulticlass.this.quantilePotentialVal;
                } catch (Exception e) {
                    // Reporting 0.0 here would silently steer the search to an arbitrary value.
                    throw new IllegalStateException(
                            "Potential function failed for argument " + (quantileDistance * multiplier), e);
                }
            }
        };
        return new BisectionSolver().solve(this.nBisectIterations, objective, this.minSearch, this.maxSearch);
    }

    /**
     * Trains the wrapped base classifier.
     *
     * @param instances the data for the base classifier
     * @throws Exception if the base classifier cannot be built
     */
    protected void buildBaseClassifier(Instances instances) throws Exception {
        this.m_Classifier.buildClassifier(instances);
    }

    /**
     * Computes the mixture of the base-classifier distribution and the cluster response.
     *
     * @param instance the instance to score
     * @return the array returned by the base classifier, overwritten with the mixed values
     * @throws Exception if a prototype, the potential function or the base classifier fails
     */
    protected double[] getresponse(Instance instance) throws Exception {
        int numClasses = this.classFreqs.length;
        double[] classResponse = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            IClusterPrototype[] prototypes = this.classProtos[c];
            double[] clusterResponse = new double[prototypes.length];
            for (int k = 0; k < prototypes.length; k++) {
                double scaledDistance =
                        this.clusterPotentialArgumentMultipliers[c][k] * prototypes[k].distance(instance);
                clusterResponse[k] += 1.0d - this.potFunction.getPotentialValue(scaledDistance);
            }
            classResponse[c] = this.clusterCombiner.value(clusterResponse);
        }

        if (this.useSoftMax) {
            classResponse = UtilsPT.softMax(classResponse);
        } else {
            double total = Utils.sum(classResponse);
            if (total > 5.0d * this.eps) {
                Utils.normalize(classResponse, total);
            } else {
                // Sum too small to divide by: fall back to soft-max.
                classResponse = UtilsPT.softMax(classResponse);
            }
        }

        double[] mixed = this.m_Classifier.distributionForInstance(instance);
        for (int c = 0; c < numClasses; c++) {
            mixed[c] = (this.proportion * mixed[c]) + ((1.0d - this.proportion) * classResponse[c]);
        }
        return mixed;
    }

    /**
     * Returns the class distribution for the instance.
     *
     * <p>Uses the fallback model if one was built, the bare base classifier when
     * {@link #classesOnly} is set, and the mixed response otherwise.
     *
     * @param instance the instance to classify
     * @return the class distribution
     * @throws Exception if the selected model fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.defaultModel != null) {
            return this.defaultModel.distributionForInstance(instance);
        }
        if (this.classesOnly) {
            return this.m_Classifier.distributionForInstance(instance);
        }
        return getresponse(instance);
    }

    /** @return a short description of this classifier */
    public String globalInfo() {
        return "Class  that implements algorithm that combines centroid and linear based classifiers";
    }

    /** @return the prototype template */
    public IClusterPrototype getPrototypeProto() {
        return this.prototypeProto;
    }

    /** @param iClusterPrototype the prototype template to copy for each cluster */
    public void setPrototypeProto(IClusterPrototype iClusterPrototype) {
        this.prototypeProto = iClusterPrototype;
    }

    /** @return tip text for the prototype property */
    public String prototypeProtoTipText() {
        return "Prototype for Cluster prototype object";
    }

    /** @return the potential function */
    public PotentialFunction getPotFunction() {
        return this.potFunction;
    }

    /** @param potentialFunction the potential function to use */
    public void setPotFunction(PotentialFunction potentialFunction) {
        this.potFunction = potentialFunction;
    }

    /** @return tip text for the potential function property */
    public String potFunctionTipText() {
        return "Potential function to use";
    }

    /** @return the weight of the base classifier in the mixture */
    public double getProportion() {
        return this.proportion;
    }

    /**
     * Sets the weight of the base classifier; values above 1 become 1 and values below 0 become 0.
     *
     * @param d the requested weight
     */
    public void setProportion(double d) {
        if (d > 1.0d) {
            this.proportion = 1.0d;
            return;
        }
        if (d < 0.0d) {
            this.proportion = 0.0d;
            return;
        }
        this.proportion = d;
    }

    /** @return tip text for the proportion property */
    public String proportionTipText() {
        return "Proportion between Plane and cluster potentials";
    }

    /** @return the epsilon factor */
    public double getEps() {
        return this.eps;
    }

    /** @param d the epsilon factor */
    public void setEps(double d) {
        this.eps = d;
    }

    /** @return tip text for the epsilon property */
    public String epsTipText() {
        return "Epsilon factor for the algorithm";
    }

    /**
     * Lists the options of this classifier followed by those of the superclass.
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>();
        options.addElement(new Option("\tThe cluster prototype to use (default: weka.classifiers.functions.nearestCentroid.prototypes.CustomizablePrototype).\n", "P", 1, "-P"));
        options.addElement(new Option("\tThe class-specific clusterer to use (default: weka.clusterers.ClassSpecificClusterer).\n", "CSP", 1, "-CSP"));
        options.addElement(new Option("\tProportion between Centroid and Plane potentials(default: 0.5).\n", "PR", 1, "-PR"));
        options.addElement(new Option("\tThe Potential function to use (default: weka.classifiers.functions.explicitboundaries.combiners.PotentialFunctionTanh).\n", "PO", 1, "-PO"));
        options.addElement(new Option("\tEpsilon factor(default: Double.MIN_VALUE).\n", "EPS", 1, "-EPS"));
        options.addElement(new Option("\tDetermines whether the prior class probabilities are used(default: FALSE).\n", "UP", 0, "-UP"));
        options.addElement(new Option("\tThe DataSplitter to use (default:" + CopySplitter.class.getCanonicalName() + ".\n", "DS", 1, "-DS"));
        options.addElement(new Option("\tQuantile of points to use(default: 0.9).\n", "QA", 1, "-QA"));
        options.addElement(new Option("\tPotential Value for given quantile(default: 0.1).\n", "PQA", 1, "-PQA"));
        options.addElement(new Option("\tMin Multiplier Value(default: 1E-3).\n", "MiPV", 1, "-MiPV"));
        options.addElement(new Option("\tMax Multiplier Value(default: 5.0).\n", "MaPV", 1, "-MaPV"));
        options.addElement(new Option("\tBisection Iterations(default: 1000).\n", "BI", 1, "-BI"));
        options.addElement(new Option("\tThe Cluster Combiner to use (default:" + MeanFunction.class.getCanonicalName() + ".\n", "CC", 1, "-CC"));
        // -USM is parsed by setOptions and emitted by getOptions, so it has to be listed as well.
        options.addElement(new Option("\tDetermines whether softmax normalization of cluster potentials is used(default: FALSE).\n", "USM", 0, "-USM"));
        options.addAll(Collections.list(super.listOptions()));
        return options.elements();
    }

    /**
     * Parses the options of this classifier, then hands the remainder to the superclass.
     *
     * <p>An option that is absent is reset to its fallback value; the flags {@code -UP} and
     * {@code -USM} are switched off when absent.
     *
     * @param strArr the option array
     * @throws Exception if an option cannot be parsed
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        setProportion(UtilsPT.parseDoubleOption(strArr, "PR", 0.5d));
        // Fallbacks for -P and -PO match the constructor and the help text in listOptions.
        setPrototypeProto((IClusterPrototype) UtilsPT.parseObjectOptions(strArr, "P", new CustomizablePrototype(), IClusterPrototype.class));
        setClassSpecificClusterer((ClassSpecificClusterer) UtilsPT.parseObjectOptions(strArr, "CSP", new ClassSpecificClusterer(), ClassSpecificClusterer.class));
        setDataSplitter((DataSplitter) UtilsPT.parseObjectOptions(strArr, "DS", new CopySplitter(), DataSplitter.class));
        setPotFunction((PotentialFunction) UtilsPT.parseObjectOptions(strArr, "PO", new PotentialFunctionTanh(), PotentialFunction.class));
        setEps(UtilsPT.parseDoubleOption(strArr, "EPS", Double.MIN_VALUE));
        setUsePriors(Utils.getFlag("UP", strArr));
        setQuantile(UtilsPT.parseDoubleOption(strArr, "QA", 0.9d));
        setQuantilePotentialVal(UtilsPT.parseDoubleOption(strArr, "PQA", 0.1d));
        setMinSearch(UtilsPT.parseDoubleOption(strArr, "MiPV", 0.001d));
        setMaxSearch(UtilsPT.parseDoubleOption(strArr, "MaPV", 5.0d));
        setnBisectIterations(UtilsPT.parseIntegerOption(strArr, "BI", 1000));
        setClusterCombiner((MultivariateFunction) UtilsPT.parseObjectOptions(strArr, "CC", new MeanFunction(), MultivariateFunction.class));
        setUseSoftMax(Utils.getFlag("USM", strArr));
        super.setOptions(strArr);
    }

    /**
     * Returns the current settings as an option array, followed by the superclass options.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        options.add("-P");
        options.add(UtilsPT.getClassAndOptions(getPrototypeProto()));
        options.add("-CSP");
        options.add(UtilsPT.getClassAndOptions(getClassSpecificClusterer()));
        options.add("-PR");
        options.add("" + getProportion());
        options.add("-DS");
        options.add(UtilsPT.getClassAndOptions(getDataSplitter()));
        options.add("-PO");
        options.add(UtilsPT.getClassAndOptions(getPotFunction()));
        options.add("-EPS");
        options.add("" + getEps());
        if (isUsePriors()) {
            options.add("-UP");
        }
        options.add("-QA");
        options.add("" + getQuantile());
        options.add("-PQA");
        options.add("" + getQuantilePotentialVal());
        options.add("-MiPV");
        options.add("" + getMinSearch());
        options.add("-MaPV");
        options.add("" + getMaxSearch());
        options.add("-BI");
        options.add("" + getnBisectIterations());
        options.add("-CC");
        options.add(UtilsPT.getClassAndOptions(getClusterCombiner()));
        if (isUseSoftMax()) {
            options.add("-USM");
        }
        Collections.addAll(options, super.getOptions());
        return options.toArray(new String[0]);
    }

    /** @return the value of the use-priors flag */
    public boolean isUsePriors() {
        return this.usePriors;
    }

    /** @param z the new value of the use-priors flag */
    public void setUsePriors(boolean z) {
        this.usePriors = z;
    }

    /** @return tip text for the use-priors property */
    public String usePriorsTipText() {
        return "Determines whether prior class probabilities are used.";
    }

    /** @return the class-specific clusterer */
    public ClassSpecificClusterer getClassSpecificClusterer() {
        return this.classSpecClusterer;
    }

    /** @param classSpecificClusterer the class-specific clusterer to use */
    public void setClassSpecificClusterer(ClassSpecificClusterer classSpecificClusterer) {
        this.classSpecClusterer = classSpecificClusterer;
    }

    /** @return tip text for the class-specific clusterer property */
    public String classSpecificClustererTipText() {
        return "Class-specific clusterer to use";
    }

    /** @return the data splitter */
    public DataSplitter getDataSplitter() {
        return this.dataSplitter;
    }

    /** @param dataSplitter the data splitter to use */
    public void setDataSplitter(DataSplitter dataSplitter) {
        this.dataSplitter = dataSplitter;
    }

    /** @return tip text for the data splitter property */
    public String dataSplitterTipText() {
        return "Data splitter to use";
    }

    /** @return tip text for the quantile property */
    public String quantileTipText() {
        return "Distance quantile to use";
    }

    /** @return the distance quantile */
    public double getQuantile() {
        return this.quantile;
    }

    /** @param d the distance quantile */
    public void setQuantile(double d) {
        this.quantile = d;
    }

    /** @return tip text for the quantile potential value property */
    public String quantilePotentialValTipText() {
        return "Potential function value for given quantile";
    }

    /** @return the target potential value at the quantile distance */
    public double getQuantilePotentialVal() {
        return this.quantilePotentialVal;
    }

    /** @param d the target potential value at the quantile distance */
    public void setQuantilePotentialVal(double d) {
        this.quantilePotentialVal = d;
    }

    /** @return tip text for the minimum search value property */
    public String minSearchTipText() {
        return "Min Multiplier value to use";
    }

    /** @return the lower end of the multiplier search interval */
    public double getMinSearch() {
        return this.minSearch;
    }

    /** @param d the lower end of the multiplier search interval */
    public void setMinSearch(double d) {
        this.minSearch = d;
    }

    /** @return tip text for the maximum search value property */
    public String maxSearchTipText() {
        return "Max Multiplier value to use";
    }

    /** @return the upper end of the multiplier search interval */
    public double getMaxSearch() {
        return this.maxSearch;
    }

    /** @param d the upper end of the multiplier search interval */
    public void setMaxSearch(double d) {
        this.maxSearch = d;
    }

    /** @return tip text for the bisection iterations property */
    public String nBisectIterationsTipText() {
        return "Max number of bisection operations to perform";
    }

    /** @return the evaluation budget of the bisection solver */
    public int getnBisectIterations() {
        return this.nBisectIterations;
    }

    /** @param i the evaluation budget of the bisection solver */
    public void setnBisectIterations(int i) {
        this.nBisectIterations = i;
    }

    /** @return the function combining per-cluster responses */
    public MultivariateFunction getClusterCombiner() {
        return this.clusterCombiner;
    }

    /** @param multivariateFunction the function combining per-cluster responses */
    public void setClusterCombiner(MultivariateFunction multivariateFunction) {
        this.clusterCombiner = multivariateFunction;
    }

    /** @return tip text for the cluster combiner property */
    public String clusterCombinerTipText() {
        return "Method of combining cluster responses";
    }

    /** @return tip text for the soft-max property */
    public String useSoftMaxTipText() {
        return "Determines whether softmax cluster potential normalization is applied.";
    }

    /** @return whether soft-max normalisation is always applied */
    public boolean isUseSoftMax() {
        return this.useSoftMax;
    }

    /** @param z whether soft-max normalisation is always applied */
    public void setUseSoftMax(boolean z) {
        this.useSoftMax = z;
    }

    /**
     * Returns the capabilities of this classifier: numeric attributes with a binary, nominal
     * or empty nominal class. Everything inherited from the superclass is disabled first.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.EMPTY_NOMINAL_CLASS);
        return capabilities;
    }
}
