package weka.classifiers.mi;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import weka.classifiers.RandomizableClassifier;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

/**
 * EM-DD multi-instance classifier.
 *
 * <p>The hypothesis is a target point plus one scale per attribute. The probability that an
 * instance is positive is {@code exp(-sum_k (x_k - point_k)^2 * scale_k^2)}, and a bag's
 * probability is that of its closest instance. Training alternates an E-step (pick, for every
 * bag, the instance closest to the current hypothesis) with an M-step (minimise the negative
 * log-likelihood of the bag labels over those instances). EM is restarted from every instance of
 * up to three randomly chosen positive bags (seeded by {@code getSeed()}). The hypothesis with
 * the fewest training errors is kept; ties keep the earliest one.
 *
 * <p>Reference: Q. Zhang and S. A. Goldman, "EM-DD: An Improved Multiple-Instance Learning
 * Technique", Advances in Neural Information Processing Systems 14, 2001.
 */
public class MIEMDD extends RandomizableClassifier implements OptionHandler, MultiInstanceCapabilitiesHandler, TechnicalInformationHandler {
    static final long serialVersionUID = 3899547154866223734L;

    /** Index of the class attribute in the bag-level training data. */
    protected int m_ClassIndex;

    /**
     * Learned hypothesis, interleaved per attribute k:
     * {@code m_Par[2k]} is the target point and {@code m_Par[2k + 1]} the scale.
     */
    protected double[] m_Par;

    /** Number of class values in the training data. */
    protected int m_NumClasses;

    /** Class value of each training bag; the value 1 is treated as the positive class. */
    protected int[] m_Classes;

    /** Filtered training values, indexed as {@code [bag][attribute][instance within bag]}. */
    protected double[][][] m_Data;

    /** String-free copy of the relational (bag) header; supplies attribute names to toString(). */
    protected Instances m_Attributes;

    /** E-step result: per bag, the instance closest to the current hypothesis ({@code [bag][attribute]}). */
    protected double[][] m_emData;

    public static final int FILTER_NORMALIZE = 0;
    public static final int FILTER_STANDARDIZE = 1;
    public static final int FILTER_NONE = 2;
    public static final Tag[] TAGS_FILTER = {
        new Tag(FILTER_NORMALIZE, "Normalize training data"),
        new Tag(FILTER_STANDARDIZE, "Standardize training data"),
        new Tag(FILTER_NONE, "No normalization/standardization")
    };

    /** Normalize/Standardize filter fitted on the training instances, or null when disabled. */
    protected Filter m_Filter = null;

    /** Selected transformation; one of the FILTER_* constants (default: standardize). */
    protected int m_filterType = FILTER_STANDARDIZE;

    /** Missing-value replacement, applied after {@link #m_Filter}. */
    protected ReplaceMissingValues m_Missing = new ReplaceMissingValues();

    /** Maximum number of EM iterations per starting hypothesis. */
    private static final int MAX_EM_ITERATIONS = 10;

    /** Maximum number of positive bags whose instances are used as EM starting points. */
    private static final int NUM_STARTING_BAGS = 3;

    /**
     * Negative log-likelihood of the bag labels, and its gradient, as a function of the
     * interleaved point/scale vector, evaluated on the instances chosen by the last E-step.
     */
    private class OptEng extends Optimization {

        protected double objectiveFunction(double[] x) {
            int[] classes = MIEMDD.this.m_Classes;
            double[][] emData = MIEMDD.this.m_emData;
            double nll = 0.0;
            for (int bag = 0; bag < classes.length; bag++) {
                double prob = Math.exp(-MIEMDD.weightedSquaredDistance(emData[bag], x));
                // Clamp at m_Zero so the logarithm stays finite.
                if (classes[bag] == 1) {
                    nll -= Math.log(Math.max(prob, m_Zero));
                } else {
                    nll -= Math.log(Math.max(1.0 - prob, m_Zero));
                }
            }
            return nll;
        }

        protected double[] evaluateGradient(double[] x) {
            int[] classes = MIEMDD.this.m_Classes;
            double[][] emData = MIEMDD.this.m_emData;
            double[] grad = new double[x.length];
            for (int bag = 0; bag < classes.length; bag++) {
                double[] inst = emData[bag];
                double prob = Math.exp(-MIEMDD.weightedSquaredDistance(inst, x));
                for (int k = 0; k < inst.length; k++) {
                    double diff = x[2 * k] - inst[k];
                    double scale = x[2 * k + 1];
                    // Partial derivatives of the weighted squared distance.
                    double dPoint = 2.0 * diff * scale * scale;
                    double dScale = 2.0 * diff * diff * scale;
                    if (classes[bag] == 1) {
                        // d/dx of -log(exp(-dist)) = d/dx of dist
                        grad[2 * k] += dPoint;
                        grad[2 * k + 1] += dScale;
                    } else {
                        // d/dx of -log(1 - exp(-dist))
                        grad[2 * k] -= dPoint * prob / (1.0 - prob);
                        grad[2 * k + 1] -= dScale * prob / (1.0 - prob);
                    }
                }
            }
            return grad;
        }

        public String getRevision() {
            return RevisionUtils.extract("$Revision: 8109 $");
        }
    }

    /**
     * Weighted squared distance between an instance and the target point of a hypothesis.
     *
     * @param values attribute values of one instance
     * @param par interleaved point/scale vector (length at least {@code 2 * values.length})
     * @return {@code sum_k (values[k] - par[2k])^2 * par[2k+1]^2}
     */
    private static double weightedSquaredDistance(double[] values, double[] par) {
        double sum = 0.0;
        for (int k = 0; k < values.length; k++) {
            double diff = values[k] - par[2 * k];
            sum += diff * diff * par[2 * k + 1] * par[2 * k + 1];
        }
        return sum;
    }

    public String globalInfo() {
        return "EMDD model builds heavily upon Dietterich's Diverse Density (DD) algorithm.\nIt is a general framework for MI learning of converting the MI problem to a single-instance setting using EM. In this implementation, we use most-likely cause DD model and only use up to 3 randomly selected positive bags as initial starting points of EM.\n\nFor more information see:\n\n" + getTechnicalInformation().toString();
    }

    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation info = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        info.setValue(TechnicalInformation.Field.AUTHOR, "Qi Zhang and Sally A. Goldman");
        info.setValue(TechnicalInformation.Field.TITLE, "EM-DD: An Improved Multiple-Instance Learning Technique");
        info.setValue(TechnicalInformation.Field.BOOKTITLE, "Advances in Neural Information Processing Systems 14");
        info.setValue(TechnicalInformation.Field.YEAR, "2001");
        info.setValue(TechnicalInformation.Field.PAGES, "1073-1080");
        info.setValue(TechnicalInformation.Field.PUBLISHER, "MIT Press");
        return info;
    }

    /** Lists the -N filter option followed by the superclass options. */
    public Enumeration listOptions() {
        Vector result = new Vector();
        result.addElement(new Option("\tWhether to 0=normalize/1=standardize/2=neither.\n\t(default 1=standardize)", "N", 1, "-N <num>"));
        Enumeration inherited = super.listOptions();
        while (inherited.hasMoreElements()) {
            result.addElement(inherited.nextElement());
        }
        return result.elements();
    }

    /**
     * Parses {@code -N <num>} (filter type; standardize when absent), then the superclass options.
     */
    public void setOptions(String[] options) throws Exception {
        String filterOption = Utils.getOption('N', options);
        int filterType = filterOption.length() != 0 ? Integer.parseInt(filterOption) : FILTER_STANDARDIZE;
        setFilterType(new SelectedTag(filterType, TAGS_FILTER));
        super.setOptions(options);
    }

    /** Returns the superclass options followed by {@code -N <filterType>}. */
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        for (String option : super.getOptions()) {
            options.add(option);
        }
        options.add("-N");
        options.add("" + m_filterType);
        return options.toArray(new String[options.size()]);
    }

    public String filterTypeTipText() {
        return "The filter type for transforming the training data.";
    }

    public SelectedTag getFilterType() {
        return new SelectedTag(m_filterType, TAGS_FILTER);
    }

    /** Sets the filter type; tags not drawn from {@link #TAGS_FILTER} are ignored. */
    public void setFilterType(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_FILTER) {
            m_filterType = selectedTag.getSelectedTag().getID();
        }
    }

    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.RELATIONAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.enable(Capabilities.Capability.BINARY_CLASS);
        capabilities.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        capabilities.enable(Capabilities.Capability.ONLY_MULTIINSTANCE);
        return capabilities;
    }

    public Capabilities getMultiInstanceCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.MISSING_VALUES);
        capabilities.disableAllClasses();
        capabilities.enable(Capabilities.Capability.NO_CLASS);
        return capabilities;
    }

    /**
     * Trains the EM-DD model.
     *
     * @param instances multi-instance data; the bag (relational) attribute is expected at index 1
     * @throws IllegalArgumentException if no positive bag remains after removing missing classes
     * @throws Exception if the capabilities check, filtering or optimisation fails
     */
    public void buildClassifier(Instances instances) throws Exception {
        getCapabilities().testWithFail(instances);
        Instances train = new Instances(instances);
        train.deleteWithMissingClass();
        m_ClassIndex = train.classIndex();
        m_NumClasses = train.numClasses();
        int numBags = train.numInstances();
        int numAttributes = train.attribute(1).relation().numAttributes();

        extractBags(train);

        if (m_Debug) {
            System.out.println("\n\nIteration History...");
        }
        m_emData = new double[numBags][numAttributes];
        m_Par = new double[2 * numAttributes];

        // Every bound is NaN; presumably Optimization treats this as "unbounded" (TODO confirm).
        double[][] bounds = new double[2][2 * numAttributes];
        for (int v = 0; v < 2 * numAttributes; v++) {
            bounds[0][v] = Double.NaN;
            bounds[1][v] = Double.NaN;
        }

        double bestErrors = Double.MAX_VALUE;
        double bestNll = Double.MAX_VALUE;
        double[] bestPar = new double[2 * numAttributes];
        int[] startBags = selectStartingBags(new Random(getSeed()));
        for (int startBag : startBags) {
            if (m_Debug) {
                System.out.println("\nH0 at " + startBag);
            }
            for (int inst = 0; inst < m_Data[startBag][0].length; inst++) {
                // Allocate a fresh start vector per restart. Reusing one array let a later start
                // overwrite a previously stored best hypothesis through aliasing.
                double[] start = new double[2 * numAttributes];
                for (int att = 0; att < numAttributes; att++) {
                    start[2 * att] = m_Data[startBag][att][inst];
                    start[2 * att + 1] = 1.0;
                }
                double nll = runEM(start, bounds); // also sets m_Par
                int errors = countTrainingErrors(train);
                if (errors < bestErrors) {
                    bestPar = m_Par;
                    bestErrors = errors;
                    bestNll = nll;
                    if (m_Debug) {
                        System.out.println("error= " + errors + "  nll= " + bestNll);
                    }
                }
            }
            if (m_Debug) {
                System.out.println(startBag + ":  -------------<Converged>--------------");
                System.out.println("current minimum error= " + bestErrors + "  nll= " + bestNll);
            }
        }
        m_Par = bestPar;
    }

    /**
     * Fills {@link #m_Classes}, {@link #m_Attributes} and {@link #m_Data}.
     *
     * All bag members are pooled into one dataset. The optional Normalize/Standardize filter and
     * the missing-value filter are fitted on that pool, and the filtered values are split back
     * per bag.
     */
    private void extractBags(Instances train) throws Exception {
        int numBags = train.numInstances();
        Instances bagHeader = train.attribute(1).relation();
        int numAttributes = bagHeader.numAttributes();
        int[] bagSizes = new int[numBags];
        Instances pooled = new Instances(bagHeader, 0);
        m_Data = new double[numBags][numAttributes][];
        m_Classes = new int[numBags];
        m_Attributes = pooled.stringFreeStructure();
        if (m_Debug) {
            System.out.println("\n\nExtracting data...");
        }
        for (int bag = 0; bag < numBags; bag++) {
            Instance bagInstance = train.instance(bag);
            m_Classes[bag] = (int) bagInstance.classValue();
            Instances members = bagInstance.relationalValue(1);
            for (int j = 0; j < members.numInstances(); j++) {
                pooled.add(members.instance(j));
            }
            bagSizes[bag] = members.numInstances();
        }

        m_Filter = createFilter();
        if (m_Filter != null) {
            m_Filter.setInputFormat(pooled);
            pooled = Filter.useFilter(pooled, m_Filter);
        }
        m_Missing.setInputFormat(pooled);
        Instances cleaned = Filter.useFilter(pooled, m_Missing);

        int offset = 0; // index of the current bag's first member in the pooled data
        for (int bag = 0; bag < numBags; bag++) {
            for (int att = 0; att < cleaned.numAttributes(); att++) {
                double[] column = new double[bagSizes[bag]];
                for (int j = 0; j < bagSizes[bag]; j++) {
                    column[j] = cleaned.instance(offset + j).value(att);
                }
                m_Data[bag][att] = column;
            }
            offset += bagSizes[bag];
        }
    }

    /** Creates the (unfitted) filter selected by {@link #m_filterType}, or null for none. */
    private Filter createFilter() {
        switch (m_filterType) {
            case FILTER_STANDARDIZE:
                return new Standardize();
            case FILTER_NORMALIZE:
                return new Normalize();
            default:
                return null;
        }
    }

    /**
     * Randomly picks up to {@link #NUM_STARTING_BAGS} distinct positive bags.
     *
     * The draw covers every bag, including the last one, and it always terminates.
     *
     * @throws IllegalArgumentException if there is no positive bag
     */
    private int[] selectStartingBags(Random random) {
        int numBags = m_Classes.length;
        int numPositive = 0;
        for (int cls : m_Classes) {
            if (cls != 0) {
                numPositive++;
            }
        }
        if (numPositive == 0) {
            throw new IllegalArgumentException("MIEMDD: training data contains no positive bags.");
        }
        int[] chosen = new int[Math.min(NUM_STARTING_BAGS, numPositive)];
        int count = 0;
        while (count < chosen.length) {
            int candidate = random.nextInt(numBags);
            if (m_Classes[candidate] == 0) {
                continue;
            }
            boolean duplicate = false;
            for (int k = 0; k < count; k++) {
                if (chosen[k] == candidate) {
                    duplicate = true;
                    break;
                }
            }
            if (!duplicate) {
                chosen[count++] = candidate;
            }
        }
        return chosen;
    }

    /**
     * Runs EM-DD from one starting hypothesis.
     *
     * Stops when the negative log-likelihood no longer decreases or after
     * {@link #MAX_EM_ITERATIONS} iterations. Sets {@link #m_Par} to the better of the last two
     * hypotheses visited.
     *
     * @param start initial interleaved point/scale vector
     * @param bounds bounds handed to the optimiser
     * @return the negative log-likelihood of the hypothesis stored in {@link #m_Par}
     */
    private double runEM(double[] start, double[][] bounds) throws Exception {
        double[] current = start;
        double[] previous = start;
        double previousNll = Double.MAX_VALUE;
        double nll = Double.MAX_VALUE / 10.0;
        int iteration = 0;
        while (nll < previousNll && iteration < MAX_EM_ITERATIONS) {
            iteration++;
            previousNll = nll;
            if (m_Debug) {
                System.out.println("\niteration: " + iteration);
            }

            // E-step: pick, per bag, the instance closest to the current hypothesis.
            for (int bag = 0; bag < m_Data.length; bag++) {
                int closest = findInstance(bag, current);
                for (int att = 0; att < m_Data[bag].length; att++) {
                    m_emData[bag][att] = m_Data[bag][att][closest];
                }
            }
            if (m_Debug) {
                System.out.println("E-step for new H' finished");
            }

            // M-step: minimise the negative log-likelihood over the selected instances.
            OptEng optimizer = new OptEng();
            double[] optimum = optimizer.findArgmin(current, bounds);
            while (optimum == null) {
                // A null result apparently means the optimiser stopped before converging (see the
                // debug message); resume from the values it reached.
                double[] resume = optimizer.getVarbValues();
                if (m_Debug) {
                    System.out.println("200 iterations finished, not enough!");
                }
                optimum = optimizer.findArgmin(resume, bounds);
            }
            nll = optimizer.getMinFunction();
            previous = current;
            current = optimum;
        }
        if (nll > previousNll) {
            m_Par = previous;
            return previousNll;
        }
        m_Par = current;
        return nll;
    }

    /** Counts training bags misclassified by the current {@link #m_Par} at threshold 0.5. */
    private int countTrainingErrors(Instances train) throws Exception {
        int errors = 0;
        for (int i = 0; i < train.numInstances(); i++) {
            double pPositive = distributionForInstance(train.instance(i))[1];
            if (pPositive >= 0.5 && m_Classes[i] == 0) {
                errors++;
            } else if (pPositive < 0.5 && m_Classes[i] == 1) {
                errors++;
            }
        }
        return errors;
    }

    /**
     * Returns the index of the instance in bag {@code i} with the smallest weighted squared
     * distance to the hypothesis {@code dArr} (0 if no distance is below Double.MAX_VALUE).
     */
    protected int findInstance(int i, double[] dArr) {
        double[][] bag = m_Data[i];
        int bagSize = bag[0].length;
        double minDistance = Double.MAX_VALUE;
        int closest = 0;
        for (int j = 0; j < bagSize; j++) {
            double distance = 0.0;
            for (int att = 0; att < bag.length; att++) {
                double diff = bag[att][j] - dArr[2 * att];
                distance += diff * diff * dArr[2 * att + 1] * dArr[2 * att + 1];
            }
            if (distance < minDistance) {
                minDistance = distance;
                closest = j;
            }
        }
        return closest;
    }

    /**
     * Returns {@code {P(negative), P(positive)}} for a bag.
     *
     * P(positive) is {@code exp(-d)}, where d is the smallest weighted squared distance between
     * any (filtered) member and the learned hypothesis. An empty bag gets P(positive) = 0.
     */
    public double[] distributionForInstance(Instance instance) throws Exception {
        Instances bag = instance.relationalValue(1);
        if (m_Filter != null) {
            bag = Filter.useFilter(bag, m_Filter);
        }
        bag = Filter.useFilter(bag, m_Missing);
        int numAttributes = bag.numAttributes();

        double minDistance = Double.POSITIVE_INFINITY;
        double pPositive = 0.0; // previously -1, which produced an invalid distribution for empty bags
        for (int j = 0; j < bag.numInstances(); j++) {
            Instance member = bag.instance(j);
            double[] values = new double[numAttributes];
            for (int att = 0; att < numAttributes; att++) {
                values[att] = member.value(att);
            }
            double distance = weightedSquaredDistance(values, m_Par);
            if (distance < minDistance) {
                minDistance = distance;
                pPositive = Math.exp(-distance);
            }
        }
        return new double[] {1.0 - pPositive, pPositive};
    }

    /** Lists the learned point and scale per attribute, or a placeholder before training. */
    public String toString() {
        if (m_Par == null) {
            return "MIEMDD: No model built yet.";
        }
        StringBuilder text = new StringBuilder("MIEMDD\nCoefficients...\nVariable       Point       Scale\n");
        for (int att = 0; att < m_Par.length / 2; att++) {
            text.append(m_Attributes.attribute(att).name())
                .append(' ')
                .append(Utils.doubleToString(m_Par[2 * att], 12, 4))
                .append(' ')
                .append(Utils.doubleToString(m_Par[2 * att + 1], 12, 4))
                .append('\n');
        }
        return text.toString();
    }

    public String getRevision() {
        return RevisionUtils.extract("$Revision: 8109 $");
    }

    public static void main(String[] strArr) {
        runClassifier(new MIEMDD(), strArr);
    }
}
