001/* 002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved. 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016 017package org.tribuo.classification.sequence.viterbi; 018 019import com.oracle.labs.mlrg.olcut.config.Config; 020import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance; 021import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl; 022import org.tribuo.Feature; 023import org.tribuo.classification.Label; 024 025import java.util.ArrayList; 026import java.util.Collections; 027import java.util.List; 028 029/** 030 * A label feature extractor that produces several kinds of label-based features. 031 * <p> 032 * The options are: the most recent output, the least recent output, recent bigrams, recent trigrams, recent 4-grams. 033 */ 034public class DefaultFeatureExtractor implements LabelFeatureExtractor { 035 036 private static final long serialVersionUID = 1L; 037 038 /** 039 * indicates the position of the first (most recent) outcome to include. For example, the 040 * default value of 1 means that if the outcomes produced so far by the classifier were [A, B, 041 * C, D], then the first outcome to be used as a feature would be D since it is the most recent. 042 */ 043 @Config(mandatory = true, description = "Position of the most recent outcome to include.") 044 private int mostRecentOutcome; 045 046 /** 047 * indicates the position of the last (least recent) outcome to include. For example, the 048 * default value of 3 means that if the outcomes produced so far by the classifier were [A, B, 049 * C, D], then the last outcome to be used as a feature would be B since and is considered the 050 * least recent. 051 */ 052 @Config(mandatory = true, description = "Position of the least recent output to include.") 053 private int leastRecentOutcome; 054 055 /** 056 * when true indicates that bigrams of outcomes should be included as features 057 */ 058 @Config(mandatory = true, description = "Use bigrams of the labels as features.") 059 private boolean useBigram; 060 061 /** 062 * indicates that trigrams of outcomes should be included as features 063 */ 064 @Config(mandatory = true, description = "Use trigrams of the labels as features.") 065 private boolean useTrigram; 066 067 /** 068 * indicates that 4-grams of outcomes should be included as features 069 */ 070 @Config(mandatory = true, description = "Use 4-grams of the labels as features.") 071 private boolean use4gram; 072 073 /** 074 * Constructs a default feature extractor for bigrams and trigrams using the past 3 outcomes. 075 */ 076 public DefaultFeatureExtractor() { 077 this(1, 3, true, true, false); 078 } 079 080 /** 081 * Constructs a default feature extractor using the supplied parameters. 082 * @param mostRecentOutcome The most recent outcome to include as a feature. 083 * @param leastRecentOutcome The least recent outcome to include as a feature. 084 * @param useBigram Use bigrams of the outcomes. 085 * @param useTrigram Use trigrams of the outcomes. 086 * @param use4gram Use 4-grams of the outcomes. 087 */ 088 public DefaultFeatureExtractor(int mostRecentOutcome, int leastRecentOutcome, boolean useBigram, boolean useTrigram, boolean use4gram) { 089 this.mostRecentOutcome = mostRecentOutcome; 090 this.leastRecentOutcome = leastRecentOutcome; 091 this.useBigram = useBigram; 092 this.useTrigram = useTrigram; 093 this.use4gram = use4gram; 094 } 095 096 @Override 097 public String toString() { 098 return "DefaultFeatureExtractor(mostRecent=" + mostRecentOutcome + ",leastRecent=" + leastRecentOutcome + ",useBigram=" + useBigram + ",useTrigram=" + useTrigram + ",use4gram=" + use4gram + ")"; 099 } 100 101 @Override 102 public List<Feature> extractFeatures(List<Label> previousOutcomes, double value) { 103 if (previousOutcomes == null || previousOutcomes.size() == 0) { 104 return Collections.emptyList(); 105 } 106 107 List<Feature> features = new ArrayList<>(); 108 109 for (int i = mostRecentOutcome; i <= leastRecentOutcome; i++) { 110 int index = previousOutcomes.size() - i; 111 if (index >= 0) { 112 Feature feature = new Feature("PreviousOutcome_L" + i + "_" + previousOutcomes.get(index).getLabel(), value); 113 features.add(feature); 114 } 115 } 116 117 if (useBigram && previousOutcomes.size() >= 2) { 118 int size = previousOutcomes.size(); 119 String featureValue = previousOutcomes.get(size - 1).getLabel() + "_" + previousOutcomes.get(size - 2).getLabel(); 120 Feature feature = new Feature("PreviousOutcomes_L1_2gram_L2R_" + featureValue, value); 121 features.add(feature); 122 } 123 124 if (useTrigram && previousOutcomes.size() >= 3) { 125 int size = previousOutcomes.size(); 126 String featureValue = previousOutcomes.get(size - 1).getLabel() + "_" + previousOutcomes.get(size - 2).getLabel() + "_" 127 + previousOutcomes.get(size - 3).getLabel(); 128 Feature feature = new Feature("PreviousOutcomes_L1_3gram_L2R_" + featureValue, value); 129 features.add(feature); 130 } 131 132 if (use4gram && previousOutcomes.size() >= 4) { 133 int size = previousOutcomes.size(); 134 String featureValue = previousOutcomes.get(size - 1).getLabel() + "_" + previousOutcomes.get(size - 2).getLabel() + "_" 135 + previousOutcomes.get(size - 3).getLabel() + "_" + previousOutcomes.get(size - 4).getLabel(); 136 Feature feature = new Feature("PreviousOutcomes_L1_4gram_L2R_" + featureValue, value); 137 features.add(feature); 138 } 139 140 return features; 141 } 142 143 @Override 144 public ConfiguredObjectProvenance getProvenance() { 145 return new ConfiguredObjectProvenanceImpl(this, "LabelFeatureExtractor"); 146 } 147}