package org.tribuo.multilabel;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.OptionalDouble;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.classification.Classifiable;
import org.tribuo.classification.Label;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;

/* loaded from: input_file:org/tribuo/multilabel/MultiLabel.class */
public class MultiLabel implements Classifiable<MultiLabel> {
    private static final long serialVersionUID = 1L;

    /** Label name used to represent "this label is not present". */
    public static final String NEGATIVE_LABEL_STRING = "ML##NEGATIVE";

    /** Label returned by {@link #createLabel(Label)} when the queried name is absent from this set. */
    public static final Label NEGATIVE_LABEL = new Label(NEGATIVE_LABEL_STRING);

    /** Combined string for the whole label set (see constructors for how it is produced). */
    private final String label;
    /** Overall score of this multi-label; {@code Double.NaN} when no score was supplied. */
    private final double score;
    /** Unmodifiable set of the individual labels. */
    private final Set<Label> labels;
    /** Unmodifiable set of the individual label names, used for name-based lookups and equality. */
    private final Set<String> labelStrings;

    /**
     * Builds an unscored multi-label from a set of labels.
     *
     * @param labels the labels; copied, so later changes to the argument are not observed
     */
    public MultiLabel(Set<Label> labels) {
        this(labels, Double.NaN);
    }

    /**
     * Builds a multi-label from a set of labels with an overall score.
     *
     * @param labels the labels; copied, so later changes to the argument are not observed
     * @param score the overall score ({@code Double.NaN} for "unscored")
     */
    public MultiLabel(Set<Label> labels, double score) {
        this.label = MultiLabelFactory.generateLabelString(labels);
        this.score = score;
        this.labels = Collections.unmodifiableSet(new HashSet<>(labels));
        Set<String> names = new HashSet<>(labels.size());
        for (Label l : labels) {
            names.add(l.getLabel());
        }
        this.labelStrings = Collections.unmodifiableSet(names);
    }

    /**
     * Builds a single-element, unscored multi-label from a label name.
     *
     * @param str the label name
     */
    public MultiLabel(String str) {
        this(new Label(str));
    }

    /**
     * Builds a single-element multi-label wrapping {@code label}; the overall score is NaN.
     *
     * @param label the single label
     */
    public MultiLabel(Label label) {
        this.label = label.getLabel();
        this.score = Double.NaN;
        this.labels = Collections.singleton(label);
        this.labelStrings = Collections.singleton(label.getLabel());
    }

    /**
     * Returns {@code label} if a label with the same name is in this set, otherwise {@link #NEGATIVE_LABEL}.
     *
     * @param label the label to look up by name
     * @return {@code label} or {@link #NEGATIVE_LABEL}
     */
    public Label createLabel(Label label) {
        return this.labelStrings.contains(label.getLabel()) ? label : NEGATIVE_LABEL;
    }

    /** @return the combined label string for this set */
    public String getLabelString() {
        return this.label;
    }

    /** @return the overall score, or {@code Double.NaN} if none was supplied */
    public double getScore() {
        return this.score;
    }

    /**
     * Looks up the score stored on the label in this set whose name matches {@code label}'s name.
     *
     * @param label the label whose name is searched for
     * @return the matching label's score, or empty if no label with that name is present
     */
    public OptionalDouble getLabelScore(Label label) {
        String wanted = label.getLabel();
        Label match = null;
        for (Label candidate : this.labels) {
            if (candidate.getLabel().equals(wanted)) {
                match = candidate;
            }
        }
        return match != null ? OptionalDouble.of(match.getScore()) : OptionalDouble.empty();
    }

    /** @return a fresh, mutable copy of the label set */
    public Set<Label> getLabelSet() {
        return new HashSet<>(this.labels);
    }

    /** @return a fresh, mutable copy of the label names */
    public Set<String> getNameSet() {
        return new HashSet<>(this.labelStrings);
    }

    /**
     * @param str a label name
     * @return true if a label with this name is in the set
     */
    public boolean contains(String str) {
        return this.labelStrings.contains(str);
    }

    /**
     * @param label a label
     * @return true if the label set contains {@code label} according to {@link Label}'s {@code equals}
     */
    public boolean contains(Label label) {
        return this.labels.contains(label);
    }

    /**
     * Two multi-labels are equal when they contain the same label names; scores are ignored.
     * Use {@link #fullEquals(MultiLabel)} to also compare scores.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        MultiLabel other = (MultiLabel) obj;
        // labelStrings is assigned in every constructor and never null.
        return this.labelStrings.equals(other.labelStrings);
    }

    /**
     * Compares label names, per-label scores and the overall score.
     * Scores are compared with {@link Double#compare}, so two NaN scores are treated as equal.
     *
     * @param multiLabel the other multi-label, may be null
     * @return true if names, per-label scores and overall score all match
     */
    public boolean fullEquals(MultiLabel multiLabel) {
        if (this == multiLabel) {
            return true;
        }
        if (multiLabel == null || getClass() != multiLabel.getClass()
                || Double.compare(this.score, multiLabel.score) != 0) {
            return false;
        }
        Map<String, Double> mine = scoresByName(this.labels);
        Map<String, Double> theirs = scoresByName(multiLabel.labels);
        if (mine.size() != theirs.size()) {
            return false;
        }
        for (Map.Entry<String, Double> entry : mine.entrySet()) {
            Double otherScore = theirs.get(entry.getKey());
            if (otherScore == null || Double.compare(entry.getValue(), otherScore) != 0) {
                return false;
            }
        }
        return true;
    }

    /** Maps each label's name to its score. */
    private static Map<String, Double> scoresByName(Set<Label> labelSet) {
        Map<String, Double> out = new HashMap<>();
        for (Label l : labelSet) {
            out.put(l.getLabel(), l.getScore());
        }
        return out;
    }

    /** Hash based only on label names, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Objects.hash(this.labelStrings);
    }

    /**
     * Renders as {@code (LabelSet={l1,l2,...})}, followed by {@code ,OverallScore=<score>} before the
     * closing parenthesis when the overall score is not NaN. An empty set renders as {@code (LabelSet={})}.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("(LabelSet={");
        // Joining avoids the previous unconditional deleteCharAt, which removed '{' for an empty set.
        sb.append(this.labels.stream().map(Label::toString).collect(Collectors.joining(",")));
        sb.append('}');
        if (!Double.isNaN(this.score)) {
            sb.append(",OverallScore=");
            sb.append(this.score);
        }
        sb.append(')');
        return sb.toString();
    }

    /**
     * Returns a copy with the same labels and overall score.
     *
     * @return a new multi-label equal (by {@link #fullEquals}) to this one
     */
    public MultiLabel copy() {
        return new MultiLabel(this.labels, this.score);
    }

    /**
     * Alias of {@link #copy()} kept for callers that reference the decompiler-generated name.
     *
     * @return a copy of this multi-label
     */
    public MultiLabel m4copy() {
        return copy();
    }

    /**
     * Produces a comma-separated list of {@code <label>=true} entries, optionally followed by
     * {@code :<overall score>}.
     * NOTE(review): each entry is formatted from {@code Label.toString()}; if that includes the
     * per-label score, the output may not round-trip through {@link #parseString(String)} — confirm.
     *
     * @param z whether to append the overall score
     * @return the serialized form
     */
    public String getSerializableForm(boolean z) {
        String str = this.labels.stream()
                .map(l -> String.format("%s=%b", l, true))
                .collect(Collectors.joining(","));
        return z ? str + ":" + this.score : str;
    }

    /**
     * Encodes this multi-label as a dense vector indexed by the ids in {@code immutableOutputInfo}.
     * Absent labels are 0; present labels use their score, or 1.0 when the score is NaN.
     *
     * @param immutableOutputInfo must be an {@link ImmutableMultiLabelInfo}
     * @return the dense encoding
     * @throws IllegalStateException if the info is not an {@link ImmutableMultiLabelInfo}
     * @throws IllegalArgumentException if a label is unknown to the info, or two labels share an id
     */
    public DenseVector convertToDenseVector(ImmutableOutputInfo<MultiLabel> immutableOutputInfo) {
        ImmutableMultiLabelInfo info = requireMultiLabelInfo(immutableOutputInfo);
        Set<Integer> seenIds = new HashSet<>(this.labels.size());
        double[] values = new double[info.size()];
        for (Label l : this.labels) {
            int id = lookupId(info, l);
            if (!seenIds.add(id)) {
                throw new IllegalArgumentException("Duplicate label ids found for id " + id + ", mapping to Label '" + l.getLabel() + "'");
            }
            values[id] = encodedScore(l);
        }
        return DenseVector.createDenseVector(values);
    }

    /**
     * Encodes this multi-label as a sparse vector indexed by the ids in {@code immutableOutputInfo}.
     * Only present labels get an entry; each uses its score, or 1.0 when the score is NaN.
     *
     * @param immutableOutputInfo must be an {@link ImmutableMultiLabelInfo}
     * @return the sparse encoding
     * @throws IllegalStateException if the info is not an {@link ImmutableMultiLabelInfo}
     * @throws IllegalArgumentException if a label is unknown to the info, or two labels share an id
     */
    public SparseVector convertToSparseVector(ImmutableOutputInfo<MultiLabel> immutableOutputInfo) {
        ImmutableMultiLabelInfo info = requireMultiLabelInfo(immutableOutputInfo);
        Map<Integer, Double> values = new HashMap<>();
        for (Label l : this.labels) {
            int id = lookupId(info, l);
            if (values.put(id, encodedScore(l)) != null) {
                throw new IllegalArgumentException("Duplicate label ids found for id " + id + ", mapping to Label '" + l.getLabel() + "'");
            }
        }
        return SparseVector.createSparseVector(info.size(), values);
    }

    /** Checks that {@code info} is a multi-label info object and returns it cast. */
    private static ImmutableMultiLabelInfo requireMultiLabelInfo(ImmutableOutputInfo<MultiLabel> info) {
        if (!(info instanceof ImmutableMultiLabelInfo)) {
            throw new IllegalStateException("Unexpected info type, found " + info.getClass().getName() + ", expected " + ImmutableMultiLabelInfo.class.getName());
        }
        return (ImmutableMultiLabelInfo) info;
    }

    /** Resolves the id of {@code l} in {@code info}, failing if the info returns -1 (unknown). */
    private static int lookupId(ImmutableMultiLabelInfo info, Label l) {
        int id = info.getID(l.getLabel());
        if (id == -1) {
            throw new IllegalArgumentException("Unknown label '" + l.getLabel() + "' which was not recognised by the supplied info object, info = " + info.toString());
        }
        return id;
    }

    /** The value stored for a present label: its score, or 1.0 when unscored (NaN). */
    private static double encodedScore(Label l) {
        double s = l.getScore();
        return Double.isNaN(s) ? 1.0d : s;
    }

    /**
     * Parses a comma-separated list of elements (see {@link #parseElement(String)}).
     *
     * @param str the text to parse
     * @return the parsed multi-label
     */
    public static MultiLabel parseString(String str) {
        return parseString(str, ',');
    }

    /**
     * Parses a list of elements separated by {@code c} (see {@link #parseElement(String)}).
     * Trailing empty elements are dropped, following {@link String#split(String)}.
     *
     * @param str the text to parse
     * @param c the separator character; must not be '='
     * @return the parsed multi-label
     * @throws IllegalArgumentException if {@code c} is '=' or an element cannot be parsed
     */
    public static MultiLabel parseString(String str, char c) {
        if (c == '=') {
            throw new IllegalArgumentException("Can't split on an equals symbol");
        }
        // String.split takes a regex; quote the separator so '|', '.', '*' etc. split literally.
        String[] tokens = str.split(Pattern.quote(String.valueOf(c)));
        List<Pair<String, Boolean>> elements = new ArrayList<>(tokens.length);
        for (String token : tokens) {
            elements.add(parseElement(token));
        }
        return createFromPairList(elements);
    }

    /**
     * Parses one element of the form {@code name=bool} or {@code name}.
     * A bare name means true; the boolean part uses {@link Boolean#parseBoolean}, so anything other
     * than "true" (ignoring case) is false. The empty string parses to {@code ("", false)}.
     *
     * @param str the element text
     * @return the (name, present) pair
     * @throws IllegalArgumentException if the element contains more than one '='
     */
    public static Pair<String, Boolean> parseElement(String str) {
        if (str.isEmpty()) {
            return new Pair<>("", false);
        }
        String[] parts = str.split("=");
        if (parts.length == 2) {
            return new Pair<>(parts[0], Boolean.parseBoolean(parts[1]));
        }
        if (parts.length == 1) {
            return new Pair<>(parts[0], true);
        }
        throw new IllegalArgumentException("Failed to parse element " + str);
    }

    /**
     * Builds a multi-label containing a {@link Label} for every pair whose boolean is true.
     *
     * @param list (name, present) pairs
     * @return the resulting multi-label, empty if no pair is true
     */
    public static MultiLabel createFromPairList(List<Pair<String, Boolean>> list) {
        Set<Label> present = new HashSet<>();
        for (Pair<String, Boolean> pair : list) {
            if (pair.getB()) {
                present.add(new Label(pair.getA()));
            }
        }
        return new MultiLabel(present);
    }

    /**
     * @return the number of label names present in both multi-labels
     */
    public static int intersectionSize(MultiLabel multiLabel, MultiLabel multiLabel2) {
        Set<String> common = new HashSet<>(multiLabel.labelStrings);
        common.retainAll(multiLabel2.labelStrings);
        return common.size();
    }

    /**
     * @return the number of distinct label names present in either multi-label
     */
    public static int unionSize(MultiLabel multiLabel, MultiLabel multiLabel2) {
        Set<String> all = new HashSet<>(multiLabel.labelStrings);
        all.addAll(multiLabel2.labelStrings);
        return all.size();
    }

    /**
     * Jaccard similarity of the two label-name sets: |intersection| / |union|.
     * Returns {@code Double.NaN} when both sets are empty (0/0).
     *
     * @return the Jaccard score in [0, 1], or NaN for two empty sets
     */
    public static double jaccardScore(MultiLabel multiLabel, MultiLabel multiLabel2) {
        // Cast before dividing: int/int division truncated every partial overlap to 0.
        return ((double) intersectionSize(multiLabel, multiLabel2)) / unionSize(multiLabel, multiLabel2);
    }
}
