package org.wlld.randomForest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.wlld.tools.ArithUtil;

/* loaded from: input_file:org/wlld/randomForest/RandomForest.class */
/**
 * Random-forest classifier made of {@code Tree}s.
 *
 * <p>Each tree is built on its own {@code DataTable} containing the key column plus a random
 * subset of the other columns (see {@link #init(DataTable)}). A prediction
 * ({@link #forest(Object)}) is a vote in which every tree adds its reported trust to the type it
 * predicts.
 *
 * <p>An instance created with the no-argument constructor has no tree array. On such an instance,
 * every method other than the trust getters/setters throws {@link IllegalStateException}.
 */
public class RandomForest {
    /** The trees; {@code null} only when created with the no-argument constructor. */
    private Tree[] forest;
    /** Source of randomness for choosing each tree's column subset. */
    private final Random random = new Random();
    /**
     * Per-tree trust threshold. {@link #forest(Object)} returns 0 when the winning type's summed
     * trust is below {@code forest.length * trustTh}.
     */
    private double trustTh = 0.1d;
    /** Passed to every {@code Tree} constructor in {@link #init(DataTable)}; its meaning is defined by {@code Tree}. */
    private double trustPunishment = 0.1d;

    /** @return the value handed to each {@code Tree} constructor by {@link #init(DataTable)} */
    public double getTrustPunishment() {
        return this.trustPunishment;
    }

    /**
     * Sets the value handed to each {@code Tree} constructor. This only affects trees built by
     * later calls to {@link #init(DataTable)}.
     *
     * @param trustPunishment the new value (not validated here)
     */
    public void setTrustPunishment(double trustPunishment) {
        this.trustPunishment = trustPunishment;
    }

    /** @return the per-tree trust threshold used by {@link #forest(Object)} */
    public double getTrustTh() {
        return this.trustTh;
    }

    /**
     * Sets the per-tree trust threshold used by {@link #forest(Object)}.
     *
     * @param trustTh the new threshold (not validated here)
     */
    public void setTrustTh(double trustTh) {
        this.trustTh = trustTh;
    }

    /**
     * Creates a forest without a tree array.
     *
     * <p>Nothing in this class assigns the tree array afterwards. Such an instance therefore cannot
     * be initialized, trained, exported or queried; use {@link #RandomForest(int)} instead.
     */
    public RandomForest() {
    }

    /**
     * Creates a forest with room for {@code treeCount} trees. Call {@link #init(DataTable)} to
     * build them.
     *
     * @param treeCount number of trees; must be positive
     * @throws Exception if {@code treeCount <= 0}
     */
    public RandomForest(int treeCount) throws Exception {
        if (treeCount <= 0) {
            throw new Exception("Number of trees must be greater than 0");
        }
        this.forest = new Tree[treeCount];
    }

    /**
     * Exports the forest as an {@code RfModel}.
     *
     * <p>The model's node map sends each tree's index in the forest to that tree's root node.
     *
     * @return the exported model
     * @throws IllegalStateException if this instance has no tree array
     */
    public RfModel getModel() {
        Tree[] trees = requireForest();
        RfModel rfModel = new RfModel();
        // Left raw: the value type is whatever Tree.getRootNode() returns, which
        // RfModel.setNodeMap is expected to accept. Parameterize once that type is pinned down.
        HashMap nodeMap = new HashMap();
        for (int index = 0; index < trees.length; index++) {
            nodeMap.put(Integer.valueOf(index), trees[index].getRootNode());
        }
        rfModel.setNodeMap(nodeMap);
        return rfModel;
    }

    /**
     * Classifies {@code obj} by a trust-weighted vote of all trees.
     *
     * <p>Each tree's {@code judge(obj)} yields a type and a trust value. Trusts are summed per type,
     * and the type with the largest strictly positive total wins. Ties go to whichever type the
     * hash map iterates first.
     *
     * <p>0 is returned in two cases:
     * <ul>
     *   <li>no total is positive, or
     *   <li>the winning total is below {@code forest.length * trustTh}.
     * </ul>
     * NOTE(review): 0 presumably means "no confident prediction". Confirm it cannot collide with a
     * real type id.
     *
     * @param obj the sample to classify
     * @return the winning type, or 0 as described above
     * @throws Exception propagated from {@code Tree.judge}
     * @throws IllegalStateException if this instance has no tree array
     */
    public int forest(Object obj) throws Exception {
        Tree[] trees = requireForest();
        Map<Integer, Double> trustByType = new HashMap<>();
        for (Tree tree : trees) {
            TreeWithTrust judgement = tree.judge(obj);
            int type = judgement.getType();
            double trust = judgement.getTrust();
            Double runningTotal = trustByType.get(type);
            trustByType.put(type,
                    runningTotal == null ? trust : ArithUtil.add(runningTotal.doubleValue(), trust));
        }
        int bestType = 0;
        double bestTrust = 0.0d;
        for (Map.Entry<Integer, Double> entry : trustByType.entrySet()) {
            double total = entry.getValue().doubleValue();
            if (total > bestTrust) {
                bestType = entry.getKey().intValue();
                bestTrust = total;
            }
        }
        // Reject weak winners: the summed trust must reach trustTh per tree on average.
        if (bestTrust < ArithUtil.mul(trees.length, this.trustTh)) {
            return 0;
        }
        return bestType;
    }

    /**
     * Builds every tree of the forest, each on its own random column subset of {@code dataTable}.
     *
     * <p>Each tree receives {@code (int) log2(dataTable.getSize())} randomly chosen non-key columns,
     * plus the key column. The log is computed via {@code ArithUtil.div}. Each tree is also given
     * the current {@link #getTrustPunishment()}.
     *
     * @param dataTable the source table; its {@code getSize()} must exceed 4
     * @throws Exception if {@code dataTable.getSize() <= 4}, if there are too few non-key columns,
     *     or if a table/tree cannot be built
     * @throws IllegalStateException if this instance has no tree array
     */
    public void init(DataTable dataTable) throws Exception {
        Tree[] trees = requireForest();
        if (dataTable.getSize() <= 4) {
            throw new Exception("Number of feature categories must be greater than 3");
        }
        int featureCount = (int) ArithUtil.div(Math.log(dataTable.getSize()), Math.log(2.0d));
        for (int index = 0; index < trees.length; index++) {
            trees[index] = new Tree(getRandomData(dataTable, featureCount), this.trustPunishment);
        }
    }

    /**
     * Trains every tree by calling its {@code study()}. Expects {@link #init(DataTable)} to have
     * been called first.
     *
     * @throws Exception propagated from {@code Tree.study}
     * @throws IllegalStateException if this instance has no tree array
     */
    public void study() throws Exception {
        for (Tree tree : requireForest()) {
            tree.study();
        }
    }

    /**
     * Inserts {@code obj} into every tree's own {@code DataTable}.
     *
     * <p>NOTE(review): presumably each table keeps only its own column subset of {@code obj};
     * confirm in {@code DataTable.insert}.
     *
     * @param obj the record to insert
     * @throws IllegalStateException if this instance has no tree array
     */
    public void insert(Object obj) {
        for (Tree tree : requireForest()) {
            tree.getDataTable().insert(obj);
        }
    }

    /** Returns the tree array, failing fast if the no-argument constructor left it unset. */
    private Tree[] requireForest() {
        if (this.forest == null) {
            throw new IllegalStateException(
                    "RandomForest has no trees; construct it with RandomForest(int)");
        }
        return this.forest;
    }

    /**
     * Creates a new {@code DataTable} over a random column subset of {@code dataTable}.
     *
     * <p>The subset is the key column plus {@code featureCount} distinct non-key columns, drawn
     * uniformly without replacement.
     *
     * @param dataTable the source table
     * @param featureCount number of non-key columns to draw
     * @return the new table, with the same key as {@code dataTable}
     * @throws Exception if fewer than {@code featureCount} non-key columns exist, or if the table
     *     cannot be built
     */
    private DataTable getRandomData(DataTable dataTable, int featureCount) throws Exception {
        Set<String> columns = dataTable.getKeyType();
        String key = dataTable.getKey();
        List<String> candidates = new ArrayList<>();
        for (String column : columns) {
            if (!column.equals(key)) {
                candidates.add(column);
            }
        }
        if (featureCount > candidates.size()) {
            throw new Exception("Cannot pick " + featureCount + " features from "
                    + candidates.size() + " non-key columns");
        }
        Set<String> chosen = new HashSet<>();
        for (int drawn = 0; drawn < featureCount; drawn++) {
            // Remove each pick so the same column cannot be drawn twice.
            chosen.add(candidates.remove(this.random.nextInt(candidates.size())));
        }
        chosen.add(key);
        DataTable subset = new DataTable(chosen);
        subset.setKey(key);
        return subset;
    }
}
