001/* 002 * Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.ensemble; 018 019import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 020import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 021import org.tribuo.Example; 022import org.tribuo.ImmutableOutputInfo; 023import org.tribuo.Prediction; 024import org.tribuo.classification.Label; 025import org.tribuo.ensemble.EnsembleCombiner; 026import org.tribuo.util.onnx.ONNXInitializer; 027import org.tribuo.util.onnx.ONNXNode; 028import org.tribuo.util.onnx.ONNXOperators; 029import org.tribuo.util.onnx.ONNXRef; 030 031import java.util.Collections; 032import java.util.HashMap; 033import java.util.LinkedHashMap; 034import java.util.List; 035import java.util.Map; 036 037/** 038 * A combiner which performs a weighted or unweighted vote across the predicted labels. 039 * <p> 040 * This uses the full distribution of predictions from each ensemble member, unlike {@link VotingCombiner} 041 * which uses the most likely prediction for each ensemble member. 042 */ 043public final class FullyWeightedVotingCombiner implements EnsembleCombiner<Label> { 044 private static final long serialVersionUID = 1L; 045 046 /** 047 * Constructs a weighted voting combiner. 048 */ 049 public FullyWeightedVotingCombiner() {} 050 051 @Override 052 public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions) { 053 int numPredictions = predictions.size(); 054 int numUsed = 0; 055 double weight = 1.0 / numPredictions; 056 double sum = 0.0; 057 double[] score = new double[outputInfo.size()]; 058 for (Prediction<Label> p : predictions) { 059 if (numUsed < p.getNumActiveFeatures()) { 060 numUsed = p.getNumActiveFeatures(); 061 } 062 for (Label e : p.getOutputScores().values()) { 063 double curScore = weight * e.getScore(); 064 sum += curScore; 065 score[outputInfo.getID(e)] += curScore; 066 } 067 } 068 069 double maxScore = Double.NEGATIVE_INFINITY; 070 Label maxLabel = null; 071 Map<String,Label> predictionMap = new LinkedHashMap<>(); 072 for (int i = 0; i < score.length; i++) { 073 String name = outputInfo.getOutput(i).getLabel(); 074 Label label = new Label(name,score[i]/sum); 075 predictionMap.put(name,label); 076 if (label.getScore() > maxScore) { 077 maxScore = label.getScore(); 078 maxLabel = label; 079 } 080 } 081 082 Example<Label> example = predictions.get(0).getExample(); 083 084 return new Prediction<>(maxLabel,predictionMap,numUsed,example,true); 085 } 086 087 @Override 088 public Prediction<Label> combine(ImmutableOutputInfo<Label> outputInfo, List<Prediction<Label>> predictions, float[] weights) { 089 if (predictions.size() != weights.length) { 090 throw new IllegalArgumentException("predictions and weights must be the same length. predictions.size()="+predictions.size()+", weights.length="+weights.length); 091 } 092 int numUsed = 0; 093 double sum = 0.0; 094 double[] score = new double[outputInfo.size()]; 095 for (int i = 0; i < weights.length; i++) { 096 Prediction<Label> p = predictions.get(i); 097 if (numUsed < p.getNumActiveFeatures()) { 098 numUsed = p.getNumActiveFeatures(); 099 } 100 for (Label e : p.getOutputScores().values()) { 101 double curScore = weights[i] * e.getScore(); 102 sum += curScore; 103 score[outputInfo.getID(e)] += curScore; 104 } 105 } 106 107 double maxScore = Double.NEGATIVE_INFINITY; 108 Label maxLabel = null; 109 Map<String,Label> predictionMap = new LinkedHashMap<>(); 110 for (int i = 0; i < score.length; i++) { 111 String name = outputInfo.getOutput(i).getLabel(); 112 Label label = new Label(name,score[i]/sum); 113 predictionMap.put(name,label); 114 if (label.getScore() > maxScore) { 115 maxScore = label.getScore(); 116 maxLabel = label; 117 } 118 } 119 120 Example<Label> example = predictions.get(0).getExample(); 121 122 return new Prediction<>(maxLabel,predictionMap,numUsed,example,true); 123 } 124 125 @Override 126 public String toString() { 127 return "FullyWeightedVotingCombiner()"; 128 } 129 130 @Override 131 public ConfiguredObjectProvenance getProvenance() { 132 return new ConfiguredObjectProvenanceImpl(this,"EnsembleCombiner"); 133 } 134 135 /** 136 * Exports this voting combiner to ONNX. 137 * <p> 138 * The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members]. 139 * @param input the node to be ensembled according to this implementation. 140 * @return The leaf node of the voting operation. 141 */ 142 @Override 143 public ONNXNode exportCombiner(ONNXNode input) { 144 // Take the mean over the maxed predictions 145 Map<String,Object> attributes = new HashMap<>(); 146 attributes.put("axes",new int[]{2}); 147 attributes.put("keepdims",0); 148 return input.apply(ONNXOperators.REDUCE_MEAN, attributes); 149 } 150 151 /** 152 * Exports this voting combiner to ONNX. 153 * <p> 154 * The input should be a 3-tensor [batch_size, num_outputs, num_ensemble_members]. 155 * @param input the node to be ensembled according to this implementation. 156 * @param weight The node of weights for ensembling. 157 * @return The leaf node of the voting operation. 158 */ 159 @Override 160 public <T extends ONNXRef<?>> ONNXNode exportCombiner(ONNXNode input, T weight) { 161 // Unsqueeze the weights to make sure they broadcast how I want them too. 162 // Now the size is [1, 1, num_members]. 163 ONNXInitializer unsqueezeAxes = input.onnxContext().array("unsqueeze_ensemble_output", new long[]{0, 1}); 164 165 ONNXNode unsqueezed = weight.apply(ONNXOperators.UNSQUEEZE, unsqueezeAxes); 166 167 // Multiply the input by the weights. 168 ONNXNode mulByWeights = input.apply(ONNXOperators.MUL, unsqueezed); 169 170 // Sum the weights 171 ONNXNode weightSum = weight.apply(ONNXOperators.REDUCE_SUM); 172 173 // Take the weighted mean over the outputs 174 ONNXInitializer sumAxes = input.onnxContext().array("sum_across_ensemble_axes", new long[]{2}); 175 return mulByWeights.apply(ONNXOperators.REDUCE_SUM, sumAxes, Collections.singletonMap("keepdims", 0)) 176 .apply(ONNXOperators.DIV, weightSum); 177 } 178}