package org.maochen.nlp.ml.util;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.maochen.nlp.ml.Tuple;

/* loaded from: input_file:org/maochen/nlp/ml/util/TrainingDataUtils.class */
/**
 * Static helpers for preparing labelled training data: down-sampling to a balanced label
 * distribution and randomly splitting a data set into two partitions.
 */
public class TrainingDataUtils {

    /**
     * Randomly down-samples the data so that every label occurs equally often.
     *
     * <p>The number of tuples kept per label equals the size of the rarest label. Which tuples
     * survive, and the order of the result, are random. The input list is never modified.
     *
     * @param list the labelled tuples to balance; must not be {@code null}
     * @return a new mutable list holding the same number of tuples for each label; empty when
     *     {@code list} is empty
     */
    public static List<Tuple> createBalancedTrainingData(List<Tuple> list) {
        List<Tuple> shuffled = new ArrayList<>(list);
        if (shuffled.isEmpty()) {
            // Nothing to balance, and there is no rarest label to derive a quota from.
            return shuffled;
        }
        // Shuffle first so that the tuples kept for each label are a random sample.
        Collections.shuffle(shuffled);

        Map<String, Long> countsByLabel = shuffled.stream()
                .collect(Collectors.groupingBy(tuple -> tuple.label, Collectors.counting()));
        // The rarest label determines how many tuples every label may keep.
        long quota = Collections.min(countsByLabel.values());

        Map<String, Long> keptByLabel = new HashMap<>();
        List<Tuple> balanced = new ArrayList<>();
        for (Tuple tuple : shuffled) {
            long kept = keptByLabel.getOrDefault(tuple.label, 0L);
            if (kept < quota) {
                keptByLabel.put(tuple.label, kept + 1);
                balanced.add(tuple);
            }
        }
        return balanced;
    }

    /**
     * Randomly splits the data into a small and a large partition.
     *
     * <p>A proportion above {@code 0.5} is mirrored to {@code 1 - d}, so the left partition is
     * always the smaller one and holds {@code floor(min(d, 1 - d) * list.size())} tuples chosen
     * uniformly at random. Within each partition the tuples keep their relative order from
     * {@code list}. The input list is never modified.
     *
     * @param list the tuples to split; must not be {@code null}
     * @param d the split proportion, between {@code 0.0} and {@code 1.0} inclusive
     * @return a pair whose left element is the small partition and whose right element holds
     *     all remaining tuples
     * @throws IllegalArgumentException if {@code d} is NaN or outside {@code [0.0, 1.0]}
     */
    public static Pair<List<Tuple>, List<Tuple>> splitData(List<Tuple> list, double d) {
        if (Double.isNaN(d) || d < 0.0d || d > 1.0d) {
            throw new IllegalArgumentException("Proportion should between 0.0 - 1.0");
        }
        double smallShare = d > 0.5d ? 1.0d - d : d;
        int total = list.size();
        int smallSize = (int) Math.floor(smallShare * total);

        // Shuffling the full index range gives a uniform sample without replacement: every
        // index (including the last one) is eligible and the selection always terminates.
        List<Integer> indices = IntStream.range(0, total)
                .boxed()
                .collect(Collectors.toCollection(ArrayList::new));
        Collections.shuffle(indices);
        boolean[] selected = new boolean[total];
        for (int index : indices.subList(0, smallSize)) {
            selected[index] = true;
        }

        List<Tuple> small = new ArrayList<>(smallSize);
        List<Tuple> large = new ArrayList<>(total - smallSize);
        int position = 0;
        for (Tuple tuple : list) {
            if (selected[position++]) {
                small.add(tuple);
            } else {
                large.add(tuple);
            }
        }
        return new ImmutablePair<>(small, large);
    }
}
