001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sequence; 018 019import org.tribuo.ImmutableFeatureMap; 020import org.tribuo.ImmutableOutputInfo; 021import org.tribuo.Prediction; 022import org.tribuo.classification.Label; 023import org.tribuo.provenance.ModelProvenance; 024import org.tribuo.sequence.SequenceExample; 025import org.tribuo.sequence.SequenceModel; 026 027import java.io.Serializable; 028import java.util.ArrayList; 029import java.util.List; 030 031/** 032 * A Sequence model which can provide confidence predictions for subsequence predictions. 033 * <p> 034 * Used to provide confidence scores on a per subsequence level. 035 * <p> 036 * The exemplar of this is providing a confidence score for each Named Entity present 037 * in a SequenceExample. 038 */ 039public abstract class ConfidencePredictingSequenceModel extends SequenceModel<Label> { 040 private static final long serialVersionUID = 1L; 041 042 /** 043 * Constructs a ConfidencePredictingSequenceModel with the supplied parameters. 044 * @param name The model name. 045 * @param description The model provenance. 046 * @param featureIDMap The feature domain. 047 * @param labelIDMap The output domain. 048 */ 049 protected ConfidencePredictingSequenceModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) { 050 super(name,description,featureIDMap,labelIDMap); 051 } 052 053 /** 054 * The scoring function for the subsequences. Provides the scores which should be assigned to each subsequence. 055 * @param example The input sequence example. 056 * @param predictions The predictions produced by this model. 057 * @param subsequences The subsequences to score. 058 * @param <SUB> The subsequence type. 059 * @return The scores for the subsequences. 060 */ 061 public abstract <SUB extends Subsequence> List<Double> scoreSubsequences(SequenceExample<Label> example, List<Prediction<Label>> predictions, List<SUB> subsequences); 062 063 /** 064 * A scoring method which multiplies together the per prediction scores. 065 * @param predictions The element level predictions. 066 * @param subsequences The subsequences denoting prediction boundaries. 067 * @param <SUB> The subsequence type. 068 * @return A list of scores for each subsequence. 069 */ 070 public static <SUB extends Subsequence> List<Double> multiplyWeights(List<Prediction<Label>> predictions, List<SUB> subsequences) { 071 List<Double> scores = new ArrayList<>(subsequences.size()); 072 for(Subsequence subsequence : subsequences) { 073 scores.add(multiplyWeights(predictions, subsequence)); 074 } 075 return scores; 076 } 077 078 private static <SUB extends Subsequence> Double multiplyWeights(List<Prediction<Label>> predictions, SUB subsequence) { 079 double counter = 1.0; 080 for (int i=subsequence.begin; i<subsequence.end; i++) { 081 counter *= predictions.get(i).getOutput().getScore(); 082 } 083 return counter; 084 } 085 086 /** 087 * A range class used to define a subsequence of a SequenceExample. 088 */ 089 public static class Subsequence implements Serializable { 090 private static final long serialVersionUID = 1L; 091 /** 092 * The subsequence start index. 093 */ 094 public final int begin; 095 /** 096 * The subsequence end index. 097 */ 098 public final int end; 099 100 /** 101 * Constructs a subsequence for the fixed range, exclusive of the end. 102 * @param begin The start element. 103 * @param end The end element. 104 */ 105 public Subsequence(int begin, int end) { 106 this.begin = begin; 107 this.end = end; 108 } 109 110 /** 111 * Returns the number of elements in this subsequence. 112 * @return The length of the subsequence. 113 */ 114 public int length() { 115 return end - begin; 116 } 117 118 @Override 119 public String toString() { 120 return "("+begin+","+end+")"; 121 } 122 } 123 124}