package org.tribuo.multilabel.baseline;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.multilabel.MultiLabel;
import org.tribuo.provenance.ModelProvenance;

/* loaded from: input_file:org/tribuo/multilabel/baseline/IndependentMultiLabelModel.class */
/**
 * A multi-label model assembled from a list of independent single-label models.
 * <p>
 * {@code models} and {@code labels} are parallel lists: the model at index {@code k}
 * is paired with the label at index {@code k} (see {@link #getTopFeatures(int)}).
 * At prediction time every model is queried and each output whose label string is not
 * {@code MultiLabel.NEGATIVE_LABEL_STRING} is added to the predicted label set.
 * <p>
 * NOTE(review): this file is decompiler output; some access modifiers and one method
 * name differ from the original source, as flagged on the affected members.
 */
public class IndependentMultiLabelModel extends Model<MultiLabel> {
    private static final long serialVersionUID = 1;

    /** The per-label models, index-aligned with {@link #labels}. */
    private final List<Model<Label>> models;

    /** The label associated with each entry of {@link #models}. */
    private final List<Label> labels;

    /**
     * Builds the model from parallel lists of labels and per-label models.
     * <p>
     * The superclass is initialised with the name {@code "binary-relevance"} and with the
     * {@code generatesProbabilities()} value of the first model in {@code list2}.
     * Both lists are stored by reference, not copied.
     * <p>
     * NOTE(review): the decompiler widened this constructor from package-private to public.
     *
     * @param list the labels, one per model
     * @param list2 the per-label models; must contain at least one element because the
     *              first element is read unconditionally
     * @param modelProvenance the provenance handed to the superclass
     * @param immutableFeatureMap the feature map handed to the superclass
     * @param immutableOutputInfo the output information handed to the superclass
     * @throws IndexOutOfBoundsException if {@code list2} is empty
     */
    public IndependentMultiLabelModel(List<Label> list, List<Model<Label>> list2, ModelProvenance modelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<MultiLabel> immutableOutputInfo) {
        super("binary-relevance", modelProvenance, immutableFeatureMap, immutableOutputInfo, list2.get(0).generatesProbabilities());
        this.labels = list;
        this.models = list2;
    }

    /**
     * Predicts the label set for an example by querying every per-label model.
     *
     * @param example the example to label
     * @return a prediction holding the collected labels, the largest active-feature count
     *         reported by any per-label model, and the supplied example
     */
    public Prediction<MultiLabel> predict(Example<MultiLabel> example) {
        Set<Label> predictedLabels = new HashSet<>();
        // One wrapper is shared by all per-label models; its second argument is left null.
        BinaryExample binaryExample = new BinaryExample(example, null);
        int maxActiveFeatures = 0;
        for (Model<Label> model : this.models) {
            Prediction<Label> prediction = model.predict(binaryExample);
            maxActiveFeatures = Math.max(maxActiveFeatures, prediction.getNumActiveFeatures());
            Label output = prediction.getOutput();
            // Outputs carrying the negative label string mean "label absent" and are dropped.
            if (!output.getLabel().equals(MultiLabel.NEGATIVE_LABEL_STRING)) {
                predictedLabels.add(output);
            }
        }
        return new Prediction<>(new MultiLabel(predictedLabels), maxActiveFeatures, example);
    }

    /**
     * Collects the top features reported by each per-label model, keyed by label name.
     * <p>
     * For each model: a {@code null} map is skipped; a single-entry map is read under the
     * key {@code "ALL_OUTPUTS"}; any other map is read under the model's label name.
     * Labels for which the lookup yields no list are omitted from the result, which avoids
     * the {@code NullPointerException} that {@link Map#merge} raises for a null value.
     *
     * @param i the number of features requested from each per-label model
     * @return a map from label name to that label's feature list
     */
    public Map<String, List<Pair<String, Double>>> getTopFeatures(int i) {
        Map<String, List<Pair<String, Double>>> featuresByLabel = new HashMap<>();
        for (int idx = 0; idx < this.models.size(); idx++) {
            Model<Label> model = this.models.get(idx);
            String label = this.labels.get(idx).getLabel();
            Map<String, List<Pair<String, Double>>> modelFeatures = model.getTopFeatures(i);
            if (modelFeatures == null) {
                continue;
            }
            boolean singleEntry = modelFeatures.size() == 1;
            List<Pair<String, Double>> features = modelFeatures.get(singleEntry ? "ALL_OUTPUTS" : label);
            if (features == null) {
                // Nothing reported for this label (e.g. an empty map); skip it.
                continue;
            }
            if (singleEntry) {
                featuresByLabel.put(label, features);
            } else {
                // If the label was already present, append to the existing list in place.
                featuresByLabel.merge(label, features, (existing, extra) -> {
                    existing.addAll(extra);
                    return existing;
                });
            }
        }
        return featuresByLabel;
    }

    /**
     * Excuses are not produced by this model.
     *
     * @param example the example to explain (unused)
     * @return always {@link Optional#empty()}
     */
    public Optional<Excuse<MultiLabel>> getExcuse(Example<MultiLabel> example) {
        return Optional.empty();
    }

    /**
     * Creates a new model whose per-label models are copies of this model's.
     * <p>
     * The label list, feature map and output information are passed to the new instance
     * by reference.
     * <p>
     * NOTE(review): the decompiler renamed this method from {@code copy} (merged with a
     * bridge method) and widened it from protected to public; the generated name is kept
     * so the interface of this file is unchanged.
     *
     * @param str a name argument; not used by this implementation
     * @param modelProvenance the provenance for the new model
     * @return the new model
     */
    public IndependentMultiLabelModel m17copy(String str, ModelProvenance modelProvenance) {
        List<Model<Label>> copiedModels = new ArrayList<>(this.models.size());
        for (Model<Label> model : this.models) {
            copiedModels.add(model.copy());
        }
        return new IndependentMultiLabelModel(this.labels, copiedModels, modelProvenance, this.featureIDMap, this.outputIDInfo);
    }
}
