package org.fnlp.ml.classifier.hier;

import java.io.IOException;
import org.fnlp.ml.classifier.hier.inf.MultiLinearMax;
import org.fnlp.ml.classifier.linear.inf.Inferencer;
import org.fnlp.ml.eval.Evaluation;
import org.fnlp.ml.feature.BaseGenerator;
import org.fnlp.ml.loss.Loss;
import org.fnlp.ml.types.Instance;
import org.fnlp.ml.types.InstanceSet;
import org.fnlp.ml.types.sv.HashSparseVector;
import org.fnlp.util.MyArrays;
import org.fnlp.util.MyHashSparseArrays;

/* loaded from: input_file:org/fnlp/ml/classifier/hier/PATrainer.class */
/**
 * Passive-Aggressive (PA) trainer for a {@link Linear} classifier whose labels may be
 * organised in a {@link Tree} hierarchy.
 *
 * <p>For every training instance the current model's best label is compared with the gold
 * label. When the gold score does not beat the predicted score by at least the label
 * distance (the tree distance, or 0/1 without a tree), the weights of every node on the gold
 * label's path are moved towards the instance's feature vector. At the same time, the
 * weights on the predicted label's path are moved away from it. The step size is capped
 * by {@code c}.
 */
public class PATrainer {
    /** Number of recent per-epoch accuracies kept for the convergence test. */
    private static final int HISTORY_NUM = 5;
    /** Variance threshold below which the recent accuracies count as converged. */
    private static final float EPS = 1.0E-10f;

    /**
     * Weight vectors. They are indexed by label, or by the node ids returned from
     * {@code Tree.getPath} when a tree is used.
     */
    private HashSparseVector[] weights;
    /** Classifier built at the end of {@link #train}; {@code null} until then. */
    private Linear classifier;
    private MultiLinearMax msolver;
    private BaseGenerator featureGen;
    /** Stored for compatibility; the update rule in this class computes its own margin loss. */
    private Loss loss;
    /** Maximum number of training epochs. */
    private int maxIter = Integer.MAX_VALUE;
    /** Label hierarchy, or {@code null} for flat labels. */
    private Tree tree;
    /** Upper bound on the PA step size. */
    private float c;
    /** When true, the model is written to ./tmp/model.gz after every epoch. */
    public boolean interim = false;
    /** When true, weights are trimmed after each of the first two epochs. */
    public boolean optim = false;
    /** True when training continues from existing weights instead of calling {@code Mean.mean}. */
    private boolean incremental = false;

    /**
     * Creates a trainer that continues training from the weights of an existing model.
     *
     * @param linear  existing classifier; its inferencer must be a {@link MultiLinearMax}
     * @param loss    loss object (stored, not used by the update rule)
     * @param maxIter maximum number of training epochs
     * @param c       upper bound on the step size
     * @param tree    label hierarchy, or {@code null} for flat labels
     */
    public PATrainer(Linear linear, Loss loss, int maxIter, float c, Tree tree) {
        this.msolver = (MultiLinearMax) linear.inf;
        this.msolver.isUseTarget(true);
        this.featureGen = linear.gen;
        this.loss = loss;
        this.maxIter = maxIter;
        this.tree = tree;
        this.c = c;
        this.incremental = true;
        this.weights = linear.weights;
    }

    /**
     * Creates a trainer that starts from freshly initialised weights (see {@link #train}).
     *
     * @param inferencer    inferencer; must be a {@link MultiLinearMax}
     * @param baseGenerator feature generator
     * @param loss          loss object (stored, not used by the update rule)
     * @param maxIter       maximum number of training epochs
     * @param c             upper bound on the step size
     * @param tree          label hierarchy, or {@code null} for flat labels
     */
    public PATrainer(Inferencer inferencer, BaseGenerator baseGenerator, Loss loss, int maxIter, float c, Tree tree) {
        this.msolver = (MultiLinearMax) inferencer;
        this.featureGen = baseGenerator;
        this.loss = loss;
        this.maxIter = maxIter;
        this.tree = tree;
        this.c = c;
    }

    /** Returns the classifier produced by the last {@link #train} call, or {@code null}. */
    public Linear getClassifier() {
        return this.classifier;
    }

    /**
     * Trains the weights with PA updates.
     *
     * <p>Training stops after {@code maxIter} epochs, or earlier once the variance of the last
     * {@code HISTORY_NUM} epoch accuracies drops below {@code EPS}. Progress is printed to
     * stdout.
     *
     * @param instanceSet training data; shuffled in place every epoch
     * @param evaluation  optional evaluator run after every epoch, may be {@code null}
     * @return the trained classifier (also available from {@link #getClassifier()})
     * @throws IllegalStateException if an instance carries no gold-label prediction in its
     *     temp data after inference
     */
    public Linear train(InstanceSet instanceSet, Evaluation evaluation) {
        System.out.println("Sample Size: " + instanceSet.size());
        System.out.println("Class Size: " + instanceSet.getAlphabetFactory().DefaultLabelAlphabet().size());
        if (!this.incremental) {
            this.weights = Mean.mean(instanceSet, this.tree);
            this.msolver.setWeight(this.weights);
        }
        // The update reads the gold-label prediction from the instance's temp data. The code
        // re-enables this flag after every evaluation/save, so training presumably needs it.
        // The second constructor never sets it, so enable it here for both construction paths.
        this.msolver.isUseTarget(true);

        float[] accuracyHistory = new float[HISTORY_NUM];
        int size = instanceSet.size();
        int progressStep = size / 10;
        System.out.println("Begin Training...");
        long trainStart = System.currentTimeMillis();
        int loop = 0;
        while (loop++ < this.maxIter) {
            System.out.print("Loop: " + loop);
            float errors = 0.0f;
            instanceSet.shuffle();
            long loopStart = System.currentTimeMillis();
            for (int k = 0; k < size; k++) {
                if (updateOnInstance(instanceSet.getInstance(k))) {
                    errors += 1.0f;
                }
                if (progressStep == 0 || k % progressStep == 0) {
                    System.out.print('.');
                }
            }
            float accuracy = 1.0f - (errors / size);
            System.out.print("\t Accuracy:" + accuracy);
            System.out.println("\t Time(s):" + ((System.currentTimeMillis() - loopStart) / 1000));
            if (this.optim && loop <= 2) {
                trimWeights();
            }
            if (this.interim) {
                saveInterimModel(instanceSet);
                this.msolver.isUseTarget(true);
            }
            if (evaluation != null) {
                System.out.print("Test:\t");
                evaluation.eval(new Linear(this.weights, this.msolver), 2);
                this.msolver.isUseTarget(true);
            }
            accuracyHistory[loop % HISTORY_NUM] = accuracy;
            // Only test once the window is full. Before that it still holds its zero initial
            // values, so zero-accuracy epochs would be falsely reported as converged.
            if (loop >= HISTORY_NUM && MyArrays.viarance(accuracyHistory) < EPS) {
                System.out.println("convergence!");
                break;
            }
        }
        System.out.println("Training End");
        System.out.println("Training Time(s):" + ((System.currentTimeMillis() - trainStart) / 1000));
        this.classifier = new Linear(this.weights, this.msolver, this.featureGen, instanceSet.getPipes(), instanceSet.getAlphabetFactory());
        return this.classifier;
    }

    /**
     * Runs inference on one instance and applies a PA update if the margin is violated.
     *
     * @return true if the margin was violated (counted as a training error)
     */
    private boolean updateOnInstance(Instance instance) {
        int gold = ((Integer) instance.getTarget()).intValue();
        Predict best = this.msolver.getBest(instance, 1);
        // The gold label's prediction is expected in the temp data after getBest.
        Predict oracle = (Predict) instance.getTempData();
        instance.deleteTempData();
        if (oracle == null) {
            throw new IllegalStateException(
                    "No gold-label prediction in instance temp data; is isUseTarget(true) set on the inferencer?");
        }
        int predicted = best.getLabel(0).intValue();
        int dist = this.tree == null ? (predicted == gold ? 0 : 1) : this.tree.dist(predicted, gold);
        float margin = dist - (oracle.getScore(0) - best.getScore(0));
        if (margin > 0.0f) {
            float step = Math.min(this.c, margin / (this.featureGen.getVector(instance).l2Norm2() * dist));
            if (this.tree != null) {
                for (int node : this.tree.getPath(gold)) {
                    this.weights[node].plus(this.featureGen.getVector(instance), step);
                }
                for (int node : this.tree.getPath(predicted)) {
                    this.weights[node].plus(this.featureGen.getVector(instance), -step);
                }
            } else {
                this.weights[gold].plus(this.featureGen.getVector(instance), step);
                this.weights[predicted].plus(this.featureGen.getVector(instance), -step);
            }
            return true;
        }
        return false;
    }

    /** Trims every weight vector via {@code MyHashSparseArrays.trim(w, 0.99f)} and reports feature counts. */
    private void trimWeights() {
        int before = 0;
        int after = 0;
        for (HashSparseVector w : this.weights) {
            before += w.size();
            MyHashSparseArrays.trim(w, 0.99f);
            after += w.size();
        }
        // Message reads: "Optimization: original feature count ... new feature count ...".
        System.out.println("优化：\t原特征数：" + before + "\t新特征数：" + after);
    }

    /** Best-effort snapshot of the current model to ./tmp/model.gz. */
    private void saveInterimModel(InstanceSet instanceSet) {
        try {
            new Linear(this.weights, this.msolver, this.featureGen, instanceSet.getPipes(), instanceSet.getAlphabetFactory()).saveTo("./tmp/model.gz");
        } catch (IOException e) {
            // Snapshots are optional: report the failure with its cause and keep training.
            System.err.println("write model error! " + e);
        }
    }
}
