package org.tweetyproject.machinelearning;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.StringTokenizer;
import libsvm.svm_node;
import libsvm.svm_problem;
import org.tweetyproject.commons.util.Pair;
import org.tweetyproject.machinelearning.Category;
import org.tweetyproject.machinelearning.Observation;

/* JADX WARN: Classes with same name are omitted:
  input_file:org.tweetyproject.machinelearning-1.19-SNAPSHOT.jar:org/tweetyproject/machinelearning/TrainingSet.class
 */
/* loaded from: input_file:org.tweetyproject.machinelearning-1.20.jar:org/tweetyproject/machinelearning/TrainingSet.class */
/**
 * A set of labelled training examples.
 *
 * <p>Each example is stored as a {@link Pair} whose first element is the observation and whose
 * second element is the category it belongs to. Because this class is a {@link HashSet}, adding a
 * pair equal to one already present (per {@code Pair.equals}) leaves the set unchanged.
 *
 * @param <S> the observation type
 * @param <T> the category type
 */
public class TrainingSet<S extends Observation, T extends Category> extends HashSet<Pair<S, T>> {
    private static final long serialVersionUID = 6814079760992723045L;

    /** Whitespace characters separating the label and the features on a libsvm line. */
    private static final String LIBSVM_FIELD_DELIMITERS = " \t\r\f";

    /**
     * Adds an observation together with its category.
     *
     * @param s the observation
     * @param t the category of {@code s}
     * @return {@code true} if this set did not already contain the pair {@code (s, t)}
     */
    public boolean add(S s, T t) {
        return add(new Pair<>(s, t));
    }

    /**
     * Returns the distinct categories occurring in this training set.
     *
     * @return a new set containing the category of every example
     */
    public Collection<T> getCategories() {
        Collection<T> categories = new HashSet<>();
        for (Pair<S, T> example : this) {
            categories.add(example.getSecond());
        }
        return categories;
    }

    /**
     * Returns the sub-training-set of all examples whose category equals {@code t}.
     *
     * @param t the category to filter by
     * @return a new training set containing only the examples labelled {@code t}
     */
    public TrainingSet<S, T> getObservations(T t) {
        TrainingSet<S, T> result = new TrainingSet<>();
        for (Pair<S, T> example : this) {
            if (example.getSecond().equals(t)) {
                result.add(example);
            }
        }
        return result;
    }

    /**
     * Converts this training set into a libsvm {@code svm_problem}.
     *
     * <p>Example {@code i} (in this set's iteration order) provides label {@code y[i]} via
     * {@code Category.asDouble()} and feature row {@code x[i]} via {@code Observation.toSvmNode()}.
     *
     * @return a freshly allocated problem with {@code l == size()}
     */
    public svm_problem toLibsvmProblem() {
        svm_problem problem = new svm_problem();
        problem.l = size();
        problem.y = new double[problem.l];
        // x is jagged (svm_node[][]): allocate only the outer dimension; each row is
        // supplied by the observation itself below.
        problem.x = new svm_node[problem.l][];
        int i = 0;
        for (Pair<S, T> example : this) {
            problem.y[i] = example.getSecond().asDouble();
            problem.x[i] = example.getFirst().toSvmNode();
            i++;
        }
        return problem;
    }

    /**
     * Loads a training set from a file in libsvm's text format.
     *
     * <p>Each non-blank line is {@code label feature feature ...}. The label and the features are
     * separated by whitespace, and each feature has the form {@code index:value}. Blank lines are
     * skipped. The file is read as UTF-8.
     *
     * @param file the libsvm training file
     * @return the examples read from the file (identical lines collapse into one example)
     * @throws NumberFormatException if a label or feature value is not a parseable double
     * @throws IOException if the file cannot be read or a feature lacks the {@code index:value} form
     */
    public static TrainingSet<DefaultObservation, DoubleCategory> loadLibsvmTrainingFile(File file) throws NumberFormatException, IOException {
        TrainingSet<DefaultObservation, DoubleCategory> trainingSet = new TrainingSet<>();
        // try-with-resources: the reader is closed even when parsing fails part-way through.
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8))) {
            String line;
            int lineNumber = 0;
            while ((line = reader.readLine()) != null) {
                lineNumber++;
                StringTokenizer fields = new StringTokenizer(line, LIBSVM_FIELD_DELIMITERS);
                if (!fields.hasMoreTokens()) {
                    continue; // blank or whitespace-only line, e.g. a trailing newline
                }
                DoubleCategory category = new DoubleCategory(Double.parseDouble(fields.nextToken()));
                DefaultObservation observation = new DefaultObservation();
                while (fields.hasMoreTokens()) {
                    String feature = fields.nextToken();
                    int colon = feature.indexOf(':');
                    if (colon < 0) {
                        throw new IOException("Malformed feature \"" + feature
                                + "\" (expected index:value) on line " + lineNumber + " of " + file);
                    }
                    // NOTE(review): the feature index is discarded and values are appended in
                    // file order, which is only correct for dense files whose indices run
                    // 1, 2, 3, ... without gaps. Confirm against DefaultObservation.toSvmNode().
                    observation.add(Double.valueOf(Double.parseDouble(feature.substring(colon + 1))));
                }
                trainingSet.add(observation, category);
            }
        }
        return trainingSet;
    }
}
