001/*
002 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.ensemble;
018
019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
021import org.tribuo.Example;
022import org.tribuo.ImmutableOutputInfo;
023import org.tribuo.Prediction;
024import org.tribuo.classification.Label;
025import org.tribuo.ensemble.EnsembleCombiner;
026import org.tribuo.util.onnx.ONNXInitializer;
027import org.tribuo.util.onnx.ONNXNode;
028import org.tribuo.util.onnx.ONNXOperators;
029import org.tribuo.util.onnx.ONNXRef;
030
031import java.util.Collections;
032import java.util.HashMap;
033import java.util.LinkedHashMap;
034import java.util.List;
035import java.util.Map;
036
037/**
038 * A combiner which performs a weighted or unweighted vote across the predicted labels.
039 * <p>
040 * This uses the most likely prediction from each ensemble member, unlike {@link FullyWeightedVotingCombiner}
041 * which uses the full distribution of predictions for each ensemble member.
042 */
043public final class VotingCombiner implements EnsembleCombiner<Label> {
044    private static final long serialVersionUID = 1L;
045
046    /**
047     * Constructs a voting combiner.
048     */
049    public VotingCombiner() {}
050
051    @Override
052    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) {
053        int numPredictions = predictions.size();
054        int numUsed = 0;
055        double weight = 1.0 / numPredictions;
056        double[] score = new double[outputInfo.size()];
057        for (Prediction<Label> p : predictions) {
058            if (numUsed < p.getNumActiveFeatures()) {
059                numUsed = p.getNumActiveFeatures();
060            }
061            score[outputInfo.getID(p.getOutput())] += weight;
062        }
063
064        double maxScore = Double.NEGATIVE_INFINITY;
065        Label maxLabel = null;
066        Map<String,Label> predictionMap = new LinkedHashMap<>();
067        for (int i = 0; i < score.length; i++) {
068            String name = outputInfo.getOutput(i).getLabel();
069            Label label = new Label(name,score[i]);
070            predictionMap.put(name,label);
071            if (label.getScore() > maxScore) {
072                maxScore = label.getScore();
073                maxLabel = label;
074            }
075        }
076
077        Example<Label> example = predictions.get(0).getExample();
078
079        return new Prediction<>(maxLabel,predictionMap,numUsed,example,true);
080    }
081
082    @Override
083    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) {
084        if (predictions.size() != weights.length) {
085            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()="+predictions.size()+", weights.length="+weights.length);
086        }
087        int numUsed = 0;
088        double sum = 0.0;
089        double[] score = new double[outputInfo.size()];
090        for (int i = 0; i < weights.length; i++) {
091            Prediction<Label> p = predictions.get(i);
092            if (numUsed < p.getNumActiveFeatures()) {
093                numUsed = p.getNumActiveFeatures();
094            }
095            score[outputInfo.getID(p.getOutput())] += weights[i];
096            sum += weights[i];
097        }
098
099        double maxScore = Double.NEGATIVE_INFINITY;
100        Label maxLabel = null;
101        Map<String,Label> predictionMap = new LinkedHashMap<>();
102        for (int i = 0; i < score.length; i++) {
103            String name = outputInfo.getOutput(i).getLabel();
104            Label label = new Label(name,score[i]/sum);
105            predictionMap.put(name,label);
106            if (label.getScore() > maxScore) {
107                maxScore = label.getScore();
108                maxLabel = label;
109            }
110        }
111
112        Example<Label> example = predictions.get(0).getExample();
113
114        return new Prediction<>(maxLabel,predictionMap,numUsed,example,true);
115    }
116
117    @Override
118    public String toString() {
119        return "VotingCombiner()";
120    }
121
122    @Override
123    public ConfiguredObjectProvenance getProvenance() {
124        return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner");
125    }
126
127    /**
128     * Exports this voting combiner to ONNX.
129     * <p>
130     * The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
131     * @param input The input tensor to combine.
132     * @return the final node proto representing the voting operation.
133     */
134    @Override
135    public ONNXNode exportCombiner(ONNXNode input) {
136        // Hardmax!
137        // Take the mean over the maxed predictions
138        Map<String,Object> attributes = new HashMap<>();
139        attributes.put("axes",new int[]{2});
140        attributes.put("keepdims",0);
141        return input.apply(ONNXOperators.HARDMAX, Collections.singletonMap("axis", 1))
142                .apply(ONNXOperators.REDUCE_MEAN, attributes);
143    }
144
145    /**
146     * Exports this voting combiner to ONNX
147     * <p>
148     * The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
149     * @param input The input tensor to combine.
150     * @param weight The combination weight node.
151     * @return the final node proto representing the voting operation.
152     */
153    @Override
154    public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
155        // Unsqueeze the weights to make sure they broadcast how I want them too.
156        // Now the size is [1, 1, num_members].
157        ONNXInitializer unsqueezeAxes = input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0, 1});
158        ONNXInitializer sumAxes = input.onnxContext().array("sum_across_ensemble_axes", new long[]{2});
159
160        ONNXNode unsqueezed = weight.apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes);
161
162        // Hardmax!
163        // Multiply the input by the weights.
164        ONNXNode mulByWeights = input.apply(ONNXOperators.HARDMAX, Collections.singletonMap("axis", 1))
165                .apply(ONNXOperators.MUL, unsqueezed);
166
167        // Sum the weights
168        ONNXNode weightSum = weight.apply(ONNXOperators.REDUCE_SUM);
169
170        // Take the weighted mean over the outputs
171        return mulByWeights.apply(ONNXOperators.REDUCE_SUM, sumAxes, Collections.singletonMap("keepdims", 0))
172                .apply(ONNXOperators.DIV, weightSum);
173    }
174}