001/*
002 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.ensemble;
018
019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
021import org.tribuo.Example;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.Prediction;
024import org.tribuo.classification.Label;
025import org.tribuo.ensemble.EnsembleCombiner;
026import org.tribuo.util.onnx.ONNXInitializer;
027import org.tribuo.util.onnx.ONNXNode;
028import org.tribuo.util.onnx.ONNXOperators;
029import org.tribuo.util.onnx.ONNXRef;
030
031import java.util.Collections;
032import java.util.HashMap;
033import java.util.LinkedHashMap;
034import java.util.List;
035import java.util.Map;
036
037/**
038 * A combiner which performs a weighted or unweighted vote across the predicted labels.
039 * <p>
040 * This uses the full distribution of predictions from each ensemble member, unlike {@link VotingCombiner}
041 * which uses the most likely prediction for each ensemble member.
042 */
043public final class FullyWeightedVotingCombiner implements EnsembleCombiner<Label> {
044    private static final long serialVersionUID = 1L;
045
046    /**
047     * Constructs a weighted voting combiner.
048     */
049    public FullyWeightedVotingCombiner() {}
050
051    @Override
052    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) {
053        int numPredictions = predictions.size();
054        int numUsed = 0;
055        double weight = 1.0 / numPredictions;
056        double sum = 0.0;
057        double[] score = new double[outputInfo.size()];
058        for (Prediction<Label> p : predictions) {
059            if (numUsed < p.getNumActiveFeatures()) {
060                numUsed = p.getNumActiveFeatures();
061            }
062            for (Label e : p.getOutputScores().values()) {
063                double curScore = weight * e.getScore();
064                sum += curScore;
065                score[outputInfo.getID(e)] += curScore;
066            }
067        }
068
069        double maxScore = Double.NEGATIVE_INFINITY;
070        Label maxLabel = null;
071        Map<String,Label> predictionMap = new LinkedHashMap<>();
072        for (int i = 0; i < score.length; i++) {
073            String name = outputInfo.getOutput(i).getLabel();
074            Label label = new Label(name,score[i]/sum);
075            predictionMap.put(name,label);
076            if (label.getScore() > maxScore) {
077                maxScore = label.getScore();
078                maxLabel = label;
079            }
080        }
081
082        Example<Label> example = predictions.get(0).getExample();
083
084        return new Prediction<>(maxLabel,predictionMap,numUsed,example,true);
085    }
086
087    @Override
088    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) {
089        if (predictions.size() != weights.length) {
090            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()="+predictions.size()+", weights.length="+weights.length);
091        }
092        int numUsed = 0;
093        double sum = 0.0;
094        double[] score = new double[outputInfo.size()];
095        for (int i = 0; i < weights.length; i++) {
096            Prediction<Label> p = predictions.get(i);
097            if (numUsed < p.getNumActiveFeatures()) {
098                numUsed = p.getNumActiveFeatures();
099            }
100            for (Label e : p.getOutputScores().values()) {
101                double curScore = weights[i] * e.getScore();
102                sum += curScore;
103                score[outputInfo.getID(e)] += curScore;
104            }
105        }
106
107        double maxScore = Double.NEGATIVE_INFINITY;
108        Label maxLabel = null;
109        Map<String,Label> predictionMap = new LinkedHashMap<>();
110        for (int i = 0; i < score.length; i++) {
111            String name = outputInfo.getOutput(i).getLabel();
112            Label label = new Label(name,score[i]/sum);
113            predictionMap.put(name,label);
114            if (label.getScore() > maxScore) {
115                maxScore = label.getScore();
116                maxLabel = label;
117            }
118        }
119
120        Example<Label> example = predictions.get(0).getExample();
121
122        return new Prediction<>(maxLabel,predictionMap,numUsed,example,true);
123    }
124
125    @Override
126    public String toString() {
127        return "FullyWeightedVotingCombiner()";
128    }
129
130    @Override
131    public ConfiguredObjectProvenance getProvenance() {
132        return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner");
133    }
134
135    /**
136     * Exports this voting combiner to ONNX.
137     * <p>
138     * The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
139     * @param input the node to be ensembled according to this implementation.
140     * @return The leaf node of the voting operation.
141     */
142    @Override
143    public ONNXNode exportCombiner(ONNXNode input) {
144        // Take the mean over the maxed predictions
145        Map<String,Object> attributes = new HashMap<>();
146        attributes.put("axes",new int[]{2});
147        attributes.put("keepdims",0);
148        return input.apply(ONNXOperators.REDUCE_MEAN, attributes);
149    }
150
151    /**
152     * Exports this voting combiner to ONNX.
153     * <p>
154     * The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
155     * @param input the node to be ensembled according to this implementation.
156     * @param weight The node of weights for ensembling.
157     * @return The leaf node of the voting operation.
158     */
159    @Override
160    public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
161        // Unsqueeze the weights to make sure they broadcast how I want them too.
162        // Now the size is [1, 1, num_members].
163        ONNXInitializer unsqueezeAxes = input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0, 1});
164
165        ONNXNode unsqueezed = weight.apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes);
166
167        // Multiply the input by the weights.
168        ONNXNode mulByWeights = input.apply(ONNXOperators.MUL, unsqueezed);
169
170        // Sum the weights
171        ONNXNode weightSum = weight.apply(ONNXOperators.REDUCE_SUM);
172
173        // Take the weighted mean over the outputs
174        ONNXInitializer sumAxes = input.onnxContext().array("sum_across_ensemble_axes", new long[]{2});
175        return mulByWeights.apply(ONNXOperators.REDUCE_SUM, sumAxes, Collections.singletonMap("keepdims", 0))
176                .apply(ONNXOperators.DIV, weightSum);
177    }
178}