/*
 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.tribuo.classification.ensemble;

import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.ensemble.EnsembleCombiner;
import org.tribuo.util.onnx.ONNXInitializer;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
import org.tribuo.util.onnx.ONNXRef;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * A combiner which performs a weighted or unweighted vote across the predicted labels.
 * <p>
 * This uses the most likely prediction from each ensemble member, unlike {@link FullyWeightedVotingCombiner}
 * which uses the full distribution of predictions for each ensemble member.
 */
public final class VotingCombiner implements EnsembleCombiner<Label> {
    private static final long serialVersionUID = 1L;

    /**
     * Constructs a voting combiner.
     */
    public VotingCombiner() {}

    /**
     * Combines the predictions with an unweighted vote.
     * <p>
     * Each ensemble member contributes {@code 1 / predictions.size()} to the score
     * of the label it predicted; the label with the highest score is returned as
     * the output. The reported number of used features is the maximum over the
     * members, and the example is taken from the first prediction.
     * @param outputInfo The output domain, used to map labels to score indices.
     * @param predictions The ensemble member predictions, must be non-empty.
     * @return The combined prediction.
     * @throws IllegalArgumentException if {@code predictions} is empty.
     */
    @Override
    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) {
        // Fail fast: an empty list previously produced an opaque IndexOutOfBoundsException
        // from predictions.get(0) after computing a meaningless 1.0/0 vote weight.
        if (predictions.isEmpty()) {
            throw new IllegalArgumentException("predictions must contain at least one element.");
        }
        int numPredictions = predictions.size();
        int numUsed = 0;
        double weight = 1.0 / numPredictions;
        double[] score = new double[outputInfo.size()];
        for (Prediction<Label> p : predictions) {
            numUsed = Math.max(numUsed, p.getNumActiveFeatures());
            score[outputInfo.getID(p.getOutput())] += weight;
        }

        return buildPrediction(outputInfo, score, numUsed, predictions.get(0).getExample());
    }

    /**
     * Combines the predictions with a weighted vote.
     * <p>
     * Each ensemble member contributes its weight to the score of the label it
     * predicted, and the scores are then divided by the sum of the weights. The
     * label with the highest normalised score is returned as the output. The
     * reported number of used features is the maximum over the members, and the
     * example is taken from the first prediction.
     * @param outputInfo The output domain, used to map labels to score indices.
     * @param predictions The ensemble member predictions, must be non-empty.
     * @param weights The vote weight for each ensemble member, in the same order as {@code predictions}.
     * @return The combined prediction.
     * @throws IllegalArgumentException if the two inputs differ in length, if they are empty,
     * or if the weights do not have a finite, non-zero sum.
     */
    @Override
    public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) {
        if (predictions.size() != weights.length) {
            throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()="+predictions.size()+", weights.length="+weights.length);
        }
        if (predictions.isEmpty()) {
            throw new IllegalArgumentException("predictions must contain at least one element.");
        }
        int numUsed = 0;
        double sum = 0.0;
        double[] score = new double[outputInfo.size()];
        for (int i = 0; i < weights.length; i++) {
            Prediction<Label> p = predictions.get(i);
            numUsed = Math.max(numUsed, p.getNumActiveFeatures());
            score[outputInfo.getID(p.getOutput())] += weights[i];
            sum += weights[i];
        }

        // A zero or NaN sum turns every normalised score into NaN, so no label would
        // win the arg-max and the prediction would carry a null output; an infinite
        // sum yields NaN/zero scores that silently mis-rank the labels.
        if (sum == 0.0 || !Double.isFinite(sum)) {
            throw new IllegalArgumentException("weights must have a finite, non-zero sum. sum="+sum);
        }
        for (int i = 0; i < score.length; i++) {
            score[i] /= sum;
        }

        return buildPrediction(outputInfo, score, numUsed, predictions.get(0).getExample());
    }

    /**
     * Builds the combined prediction from the final per-label scores.
     * <p>
     * Creates one scored {@link Label} per output id (in id order) and selects
     * the first label with the strictly greatest score as the predicted output.
     * @param outputInfo The output domain, used to map score indices to label names.
     * @param score The final score for each output id.
     * @param numUsed The number of used features to report.
     * @param example The example to attach to the prediction.
     * @return The combined prediction.
     */
    private static Prediction<Label> buildPrediction(ImmutableOutputInfo<Label> outputInfo, double[] score, int numUsed, Example<Label> example) {
        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        // LinkedHashMap keeps the labels in output id order.
        Map<String,Label> predictionMap = new LinkedHashMap<>();
        for (int i = 0; i < score.length; i++) {
            String name = outputInfo.getOutput(i).getLabel();
            Label label = new Label(name,score[i]);
            predictionMap.put(name,label);
            if (label.getScore() > maxScore) {
                maxScore = label.getScore();
                maxLabel = label;
            }
        }

        return new Prediction<>(maxLabel,predictionMap,numUsed,example,true);
    }

    @Override
    public String toString() {
        return "VotingCombiner()";
    }

    @Override
    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner");
    }

    /**
     * Exports this voting combiner to ONNX.
     * <p>
     * The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
     * @param input The input tensor to combine.
     * @return the final node proto representing the voting operation.
     */
    @Override
    public ONNXNode exportCombiner(ONNXNode input) {
        // Hardmax over the output axis turns each member's scores into a one-hot vote,
        // then take the mean over the ensemble member axis, dropping that axis.
        Map<String,Object> attributes = new HashMap<>();
        attributes.put("axes",new int[]{2});
        attributes.put("keepdims",0);
        return input.apply(ONNXOperators.HARDMAX, Collections.singletonMap("axis", 1))
                .apply(ONNXOperators.REDUCE_MEAN, attributes);
    }

    /**
     * Exports this voting combiner to ONNX.
     * <p>
     * The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members].
     * @param input The input tensor to combine.
     * @param weight The combination weight node.
     * @param <T> The type of the weight reference.
     * @return the final node proto representing the voting operation.
     */
    @Override
    public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) {
        // Unsqueeze the weights to make sure they broadcast how I want them to.
        // Now the size is [1, 1, num_members].
        ONNXInitializer unsqueezeAxes = input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0, 1});
        ONNXInitializer sumAxes = input.onnxContext().array("sum_across_ensemble_axes", new long[]{2});

        ONNXNode unsqueezed = weight.apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes);

        // Hardmax!
        // Multiply the input by the weights.
        ONNXNode mulByWeights = input.apply(ONNXOperators.HARDMAX, Collections.singletonMap("axis", 1))
                .apply(ONNXOperators.MUL, unsqueezed);

        // Sum the weights
        ONNXNode weightSum = weight.apply(ONNXOperators.REDUCE_SUM);

        // Take the weighted mean over the outputs
        return mulByWeights.apply(ONNXOperators.REDUCE_SUM, sumAxes, Collections.singletonMap("keepdims", 0))
                .apply(ONNXOperators.DIV, weightSum);
    }
}