package org.tribuo.ensemble;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.provenance.EnsembleModelProvenance;
import org.tribuo.util.Util;

/* loaded from: input_file:org/tribuo/ensemble/WeightedEnsembleModel.class */
/**
 * An ensemble model that holds one {@code float} weight per member and delegates the
 * merging of member predictions to an {@link EnsembleCombiner}.
 *
 * <p>The weight array is defensively copied on construction. The same weights are used to
 * scale each member's excuse scores in {@link #getExcuse(Example)}.
 *
 * @param <T> the output type produced by the ensemble and its members
 */
public final class WeightedEnsembleModel<T extends Output<T>> extends EnsembleModel<T> {
    private static final long serialVersionUID = 1;

    /** Per-member weights; index {@code i} is applied to the {@code i}-th member model. */
    protected final float[] weights;

    /** Strategy used to merge the member predictions into one prediction. */
    protected final EnsembleCombiner<T> combiner;

    /**
     * Builds an ensemble whose weight vector comes from
     * {@code Util.generateUniformVector(list.size(), 1.0f / list.size())}.
     *
     * @param str the model name, forwarded to the superclass
     * @param ensembleModelProvenance the provenance, forwarded to the superclass
     * @param immutableFeatureMap the feature domain, forwarded to the superclass
     * @param immutableOutputInfo the output domain, forwarded to the superclass
     * @param list the member models
     * @param ensembleCombiner the combiner used at prediction time
     */
    public WeightedEnsembleModel(String str, EnsembleModelProvenance ensembleModelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, List<Model<T>> list, EnsembleCombiner<T> ensembleCombiner) {
        this(str, ensembleModelProvenance, immutableFeatureMap, immutableOutputInfo, list, ensembleCombiner, Util.generateUniformVector(list.size(), 1.0f / list.size()));
    }

    /**
     * Builds an ensemble with explicit per-member weights.
     *
     * @param str the model name, forwarded to the superclass
     * @param ensembleModelProvenance the provenance, forwarded to the superclass
     * @param immutableFeatureMap the feature domain, forwarded to the superclass
     * @param immutableOutputInfo the output domain, forwarded to the superclass
     * @param list the member models
     * @param ensembleCombiner the combiner used at prediction time
     * @param fArr the member weights; copied, so later changes by the caller are not seen.
     *             Presumably expected to have one entry per member model, which is not
     *             checked here.
     */
    public WeightedEnsembleModel(String str, EnsembleModelProvenance ensembleModelProvenance, ImmutableFeatureMap immutableFeatureMap, ImmutableOutputInfo<T> immutableOutputInfo, List<Model<T>> list, EnsembleCombiner<T> ensembleCombiner, float[] fArr) {
        super(str, ensembleModelProvenance, immutableFeatureMap, immutableOutputInfo, list);
        this.weights = Arrays.copyOf(fArr, fArr.length);
        this.combiner = ensembleCombiner;
    }

    /**
     * Asks every member model for a prediction, in member order, and hands the collected
     * predictions together with the weights to the combiner.
     *
     * @param example the example to predict
     * @return the combiner's merged prediction
     */
    @Override // org.tribuo.Model
    public Prediction<T> predict(Example<T> example) {
        List<Prediction<T>> predictions = new ArrayList<>();
        for (Model<T> model : this.models) {
            predictions.add(model.predict(example));
        }
        return this.combiner.combine(this.outputIDInfo, predictions, this.weights);
    }

    /**
     * Builds an ensemble-level excuse by summing each member's excuse scores, scaled by
     * that member's weight, per score key and per name within that key.
     *
     * @param example the example to explain
     * @return an {@link EnsembleExcuse} holding the aggregated scores (sorted in descending
     *         score order within each key) and the member excuses, or an empty optional
     *         when no member contributed any score entries
     */
    @Override // org.tribuo.ensemble.EnsembleModel, org.tribuo.Model
    public Optional<Excuse<T>> getExcuse(Example<T> example) {
        Map<String, Map<String, Double>> weightedScores = new HashMap<>();
        Prediction<T> prediction = predict(example);
        List<Excuse<T>> memberExcuses = new ArrayList<>();
        for (int i = 0; i < this.models.size(); i++) {
            Optional<Excuse<T>> memberExcuse = this.models.get(i).getExcuse(example);
            if (!memberExcuse.isPresent()) {
                continue;
            }
            Excuse<T> excuse = memberExcuse.get();
            memberExcuses.add(excuse);
            float weight = this.weights[i];
            for (Map.Entry<String, List<Pair<String, Double>>> scoreEntry : excuse.getScores().entrySet()) {
                Map<String, Double> totals = weightedScores.computeIfAbsent(scoreEntry.getKey(), k -> new HashMap<>());
                for (Pair<String, Double> score : scoreEntry.getValue()) {
                    totals.merge(score.getA(), score.getB() * weight, Double::sum);
                }
            }
        }
        // Members that returned an excuse with no score entries do not count as a contribution.
        if (weightedScores.isEmpty()) {
            return Optional.empty();
        }
        Map<String, List<Pair<String, Double>>> sortedScores = new HashMap<>();
        for (Map.Entry<String, Map<String, Double>> keyEntry : weightedScores.entrySet()) {
            List<Pair<String, Double>> ranked = new ArrayList<>();
            for (Map.Entry<String, Double> total : keyEntry.getValue().entrySet()) {
                ranked.add(new Pair<>(total.getKey(), total.getValue()));
            }
            // Largest aggregated score first.
            ranked.sort((a, b) -> b.getB().compareTo(a.getB()));
            sortedScores.put(keyEntry.getKey(), ranked);
        }
        return Optional.of(new EnsembleExcuse<>(example, prediction, sortedScores, memberExcuses));
    }

    /**
     * Creates a new ensemble with the given name, provenance and members, reusing this
     * instance's feature map, output info, combiner and weights.
     *
     * <p>The weights are carried over whenever the supplied member list has one model per
     * stored weight. If the sizes differ, the copy is built by the constructor that
     * generates its own weight vector, since the stored weights cannot be mapped onto the
     * new members.
     *
     * @param str the name of the new model
     * @param ensembleModelProvenance the provenance of the new model
     * @param list the member models of the new model
     * @return the new ensemble
     */
    @Override // org.tribuo.ensemble.EnsembleModel
    protected EnsembleModel<T> copy(String str, EnsembleModelProvenance ensembleModelProvenance, List<Model<T>> list) {
        if (list.size() == this.weights.length) {
            // Preserve the learned/assigned weights so the copy predicts like the original.
            return new WeightedEnsembleModel<>(str, ensembleModelProvenance, this.featureIDMap, this.outputIDInfo, list, this.combiner, this.weights);
        }
        return new WeightedEnsembleModel<>(str, ensembleModelProvenance, this.featureIDMap, this.outputIDInfo, list, this.combiner);
    }
}
