package com.github.chen0040.svmext.classifiers;

import com.github.chen0040.data.frame.BasicDataFrame;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.TupleTwo;
import com.github.chen0040.svmext.Learner;
import com.github.chen0040.svmext.regression.SVR;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/* loaded from: input_file:com/github/chen0040/svmext/classifiers/OneVsOneSVC.class */
/**
 * Multi-class classifier assembled from pairs of {@link SVR} regressors.
 *
 * <p>For every unordered pair of class labels two regressors are created, one named after each
 * label. During {@link #fit(DataFrame)} each regressor is trained on a binary target that is
 * {@code 1.0} for rows carrying its label and {@code 0.0} for all other rows. During
 * {@link #classify(DataRow)} each pair casts one vote for the label whose regressor produced the
 * higher output, and the label with the most votes wins.
 */
public class OneVsOneSVC implements Learner {
    /** Name of the numeric target column written into the binary training rows. */
    private static final String BINARY_LABEL = "success";

    /** One regressor pair per unordered pair of class labels; rebuilt on every fit. */
    protected List<TupleTwo<SVR, SVR>> classifiers;

    /** Fraction of the data batches merged into the training set of each regressor pair. */
    private double alpha;

    /** Whether {@link #fit(DataFrame)} shuffles the supplied data frame before splitting it. */
    private boolean shuffleData;

    /** Known class labels; filled from the training data when left empty. */
    private final List<String> classLabels;

    /**
     * Creates a classifier with a predefined set of class labels.
     *
     * @param list the class labels; copied into an internal list
     */
    public OneVsOneSVC(List<String> list) {
        this();
        this.classLabels.addAll(list);
    }

    /** Creates a classifier whose class labels are discovered from the data on the first fit. */
    public OneVsOneSVC() {
        this.alpha = 0.1d;
        this.shuffleData = false;
        this.classLabels = new ArrayList<>();
        this.classifiers = new ArrayList<>();
    }

    /**
     * @return {@code true} if the data frame is shuffled at the start of {@link #fit(DataFrame)}
     */
    public boolean isShuffleData() {
        return this.shuffleData;
    }

    /**
     * @param shuffleData whether to shuffle the data frame at the start of {@link #fit(DataFrame)}
     */
    public void setShuffleData(boolean shuffleData) {
        this.shuffleData = shuffleData;
    }

    /**
     * @return the fraction of data batches merged into each regressor pair's training set
     */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * @param alpha the fraction of data batches merged into each regressor pair's training set
     */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /**
     * Rebuilds {@link #classifiers} with one regressor pair per unordered pair of class labels.
     * If no class labels are known yet, they are taken from the distinct categorical targets of
     * the rows in {@code dataFrame}.
     *
     * @param dataFrame the training data used to discover class labels when none are set
     */
    protected void createClassifiers(DataFrame dataFrame) {
        this.classifiers = new ArrayList<>();
        if (this.classLabels.isEmpty()) {
            this.classLabels.addAll(
                    dataFrame.stream()
                            .map(DataRow::categoricalTarget)
                            .distinct()
                            .collect(Collectors.toList()));
        }
        int labelCount = this.classLabels.size();
        for (int first = 0; first < labelCount - 1; first++) {
            for (int second = first + 1; second < labelCount; second++) {
                this.classifiers.add(new TupleTwo<>(
                        createClassifier(this.classLabels.get(first)),
                        createClassifier(this.classLabels.get(second))));
            }
        }
    }

    /**
     * Creates an untrained regressor named after the class label it will recognise.
     *
     * @param classLabel the class label, stored as the regressor's name
     * @return the new regressor
     */
    protected SVR createClassifier(String classLabel) {
        SVR svr = new SVR();
        svr.setName(classLabel);
        return svr;
    }

    /**
     * Returns the output of {@code svr} for {@code dataRow}.
     *
     * @param dataRow the row to score
     * @param svr the regressor to apply
     * @return the value returned by {@code svr.transform(dataRow)}
     */
    protected double getClassifierScore(DataRow dataRow, SVR svr) {
        return svr.transform(dataRow);
    }

    /**
     * Distributes the rows of {@code dataFrame} round-robin over {@code batchCount} new frames.
     *
     * @param dataFrame the data to split
     * @param batchCount the number of frames to create; must be positive when
     *     {@code dataFrame} contains rows
     * @return the list of {@code batchCount} frames
     */
    protected List<DataFrame> split(DataFrame dataFrame, int batchCount) {
        List<DataFrame> batches = new ArrayList<>();
        for (int b = 0; b < batchCount; b++) {
            batches.add(new BasicDataFrame());
        }
        int rowIndex = 0;
        for (DataRow row : dataFrame) {
            batches.get(rowIndex % batchCount).addRow(row);
            rowIndex++;
        }
        return batches;
    }

    /**
     * Builds one merged frame per input batch. The frame at position {@code k} contains copies of
     * the rows of batches {@code k, k+1, ..., k+batchesPerFrame-1}, wrapping around the end of
     * the list.
     *
     * @param list the batches to merge
     * @param batchesPerFrame the number of consecutive batches copied into each merged frame
     * @return a list with the same size as {@code list}
     */
    protected List<DataFrame> remerge(List<DataFrame> list, int batchesPerFrame) {
        List<DataFrame> merged = new ArrayList<>();
        for (int start = 0; start < list.size(); start++) {
            BasicDataFrame frame = new BasicDataFrame();
            for (int offset = 0; offset < batchesPerFrame; offset++) {
                for (DataRow row : list.get((start + offset) % list.size())) {
                    frame.addRow(row.makeCopy());
                }
            }
            merged.add(frame);
        }
        return merged;
    }

    /**
     * Classifies {@code dataRow} and returns the position of the predicted label in
     * {@link #getClassLabels()}.
     *
     * @param dataRow the row to classify
     * @return the index of the predicted label, or {@code -1} if the predicted label is not in
     *     the label list (for example the {@code "NA"} fallback of {@link #classify(DataRow)})
     */
    @Override // com.github.chen0040.svmext.Learner
    public double transform(DataRow dataRow) {
        return this.classLabels.indexOf(classify(dataRow));
    }

    /**
     * Trains every regressor pair on its own subset of {@code dataFrame}.
     *
     * <p>The data is split into as many batches as there are regressor pairs; each pair is then
     * trained on {@code max(1, (int) (alpha * batchCount))} consecutive batches. If
     * {@link #isShuffleData()} is set, {@code dataFrame} itself is shuffled first. When fewer
     * than two class labels are available there is nothing to train and the method returns
     * without fitting anything.
     *
     * @param dataFrame the training data; each row must carry a categorical target
     */
    @Override // com.github.chen0040.svmext.Learner
    public void fit(DataFrame dataFrame) {
        createClassifiers(dataFrame);
        if (this.shuffleData) {
            dataFrame.shuffle();
        }
        if (this.classifiers.isEmpty()) {
            // No label pairs: splitting into zero batches would divide by zero.
            return;
        }
        List<DataFrame> batches = split(dataFrame, this.classifiers.size());
        // Multiply first, then truncate: casting alpha alone would turn any fraction into 0.
        int batchesPerClassifier = Math.max(1, (int) (this.alpha * batches.size()));
        List<DataFrame> trainingSets = remerge(batches, batchesPerClassifier);
        for (int i = 0; i < this.classifiers.size(); i++) {
            TupleTwo<SVR, SVR> pair = this.classifiers.get(i);
            SVR first = pair._1();
            SVR second = pair._2();
            DataFrame trainingSet = trainingSets.get(i);
            first.fit(createBinaryBatch(trainingSet, first.getName()));
            second.fit(createBinaryBatch(trainingSet, second.getName()));
        }
    }

    /**
     * Copies {@code dataFrame}, setting the {@code BINARY_LABEL} target of every copied row to
     * {@code 1.0} when its categorical target equals {@code classLabel} and {@code 0.0} otherwise.
     *
     * @param dataFrame the rows to copy
     * @param classLabel the label treated as the positive class
     * @return a new frame holding the relabelled copies
     */
    private DataFrame createBinaryBatch(DataFrame dataFrame, String classLabel) {
        BasicDataFrame binaryBatch = new BasicDataFrame();
        for (DataRow row : dataFrame) {
            String target = row.categoricalTarget();
            DataRow copy = row.makeCopy();
            copy.setTargetCell(BINARY_LABEL, target.equals(classLabel) ? 1.0d : 0.0d);
            binaryBatch.addRow(copy);
        }
        return binaryBatch;
    }

    /**
     * Predicts the class label of {@code dataRow} by majority vote over all regressor pairs.
     * The row itself is not modified; a copy is scored.
     *
     * @param dataRow the row to classify
     * @return the label with the most votes, or {@code "NA"} if no pair cast a vote; on a tie the
     *     label encountered first while iterating the vote map is returned
     */
    public String classify(DataRow dataRow) {
        DataRow copy = dataRow.makeCopy();
        if (copy.getTargetColumnNames().isEmpty()) {
            copy.setTargetColumnNames(Collections.singletonList(BINARY_LABEL));
        }
        String bestLabel = null;
        int bestVotes = 0;
        for (Map.Entry<String, Integer> entry : score(copy).entrySet()) {
            int votes = entry.getValue();
            if (votes > bestVotes) {
                bestVotes = votes;
                bestLabel = entry.getKey();
            }
        }
        return bestLabel != null ? bestLabel : "NA";
    }

    /** Discards all regressor pairs and all known class labels. */
    public void reset() {
        this.classifiers.clear();
        this.classLabels.clear();
    }

    /**
     * @return the internal, mutable list of class labels
     */
    public List<String> getClassLabels() {
        return this.classLabels;
    }

    /**
     * Counts the votes each class label receives for {@code dataRow}. Every regressor pair votes
     * for the label whose regressor returned the strictly higher output; pairs with equal
     * outputs do not vote.
     *
     * @param dataRow the row to score
     * @return a map from class label to vote count, containing only labels with at least one vote
     */
    public Map<String, Integer> score(DataRow dataRow) {
        Map<String, Integer> votes = new HashMap<>();
        for (TupleTwo<SVR, SVR> pair : this.classifiers) {
            SVR first = pair._1();
            SVR second = pair._2();
            double firstScore = getClassifierScore(dataRow, first);
            double secondScore = getClassifierScore(dataRow, second);
            if (firstScore != secondScore) {
                String winner = firstScore > secondScore ? first.getName() : second.getName();
                votes.merge(winner, 1, Integer::sum);
            }
        }
        return votes;
    }
}
