001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.classification.sequence;
018
019import org.tribuo.ImmutableFeatureMap;
020import org.tribuo.ImmutableOutputInfo;
021import org.tribuo.Prediction;
022import org.tribuo.classification.Label;
023import org.tribuo.provenance.ModelProvenance;
024import org.tribuo.sequence.SequenceExample;
025import org.tribuo.sequence.SequenceModel;
026
027import java.io.Serializable;
028import java.util.ArrayList;
029import java.util.List;
030
031/**
032 * A Sequence model which can provide confidence predictions for subsequence predictions.
033 * <p>
034 * Used to provide confidence scores on a per subsequence level.
035 * <p>
036 * The exemplar of this is providing a confidence score for each Named Entity present
037 * in a SequenceExample.
038 */
039public abstract class ConfidencePredictingSequenceModel extends SequenceModel<Label> {
040    private static final long serialVersionUID = 1L;
041
042    /**
043     * Constructs a ConfidencePredictingSequenceModel with the supplied parameters.
044     * @param name The model name.
045     * @param description The model provenance.
046     * @param featureIDMap The feature domain.
047     * @param labelIDMap The output domain.
048     */
049    protected ConfidencePredictingSequenceModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<Label> labelIDMap) {
050        super(name,description,featureIDMap,labelIDMap);
051    }
052
053    /**
054     * The scoring function for the subsequences. Provides the scores which should be assigned to each subsequence.
055     * @param example The input sequence example.
056     * @param predictions The predictions produced by this model.
057     * @param subsequences The subsequences to score.
058     * @param <SUB> The subsequence type.
059     * @return The scores for the subsequences.
060     */
061    public abstract <SUB extends Subsequence> List<Double> scoreSubsequences(SequenceExample<Label> example, List<Prediction<Label>> predictions, List<SUB> subsequences);
062
063    /**
064     * A scoring method which multiplies together the per prediction scores.
065     * @param predictions The element level predictions.
066     * @param subsequences The subsequences denoting prediction boundaries.
067     * @param <SUB> The subsequence type.
068     * @return A list of scores for each subsequence.
069     */
070    public static <SUB extends Subsequence> List<Double> multiplyWeights(List<Prediction<Label>> predictions, List<SUB> subsequences) {
071        List<Double> scores = new ArrayList<>(subsequences.size());
072        for(Subsequence subsequence : subsequences) {
073            scores.add(multiplyWeights(predictions, subsequence));
074        }
075        return scores;
076    }
077
078    private static <SUB extends Subsequence> Double multiplyWeights(List<Prediction<Label>> predictions, SUB subsequence) {
079        double counter = 1.0;
080        for (int i=subsequence.begin; i<subsequence.end; i++) {
081            counter *= predictions.get(i).getOutput().getScore();
082        }
083        return counter;
084    }
085
086    /**
087     * A range class used to define a subsequence of a SequenceExample.
088     */
089    public static class Subsequence implements Serializable {
090        private static final long serialVersionUID = 1L;
091        /**
092         * The subsequence start index.
093         */
094        public final int begin;
095        /**
096         * The subsequence end index.
097         */
098        public final int end;
099
100        /**
101         * Constructs a subsequence for the fixed range, exclusive of the end.
102         * @param begin The start element.
103         * @param end The end element.
104         */
105        public Subsequence(int begin, int end) {
106            this.begin = begin;
107            this.end = end;
108        }
109
110        /**
111         * Returns the number of elements in this subsequence.
112         * @return The length of the subsequence.
113         */
114        public int length() {
115            return end - begin;
116        }
117
118        @Override
119        public String toString() {
120            return "("+begin+","+end+")";
121        }
122    }
123
124}