001/*
002 * Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 *     http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016
017package org.tribuo.regression.rtree.impl;
018
019import com.oracle.labs.mlrg.olcut.util.Pair;
020import org.tribuo.common.tree.impl.IntArrayContainer;
021
022import java.util.ArrayList;
023import java.util.Collections;
024import java.util.HashMap;
025import java.util.Iterator;
026import java.util.List;
027import java.util.Map;
028
029/**
030 * An inverted feature, which stores a reference to all the values of this feature.
031 * <p>
032 * Can be split into two values based on an example index list.
033 */
034public class TreeFeature implements Iterable<InvertedFeature> {
035
036    private final int id;
037
038    private final List<InvertedFeature> feature;
039
040    private final Map<Double,InvertedFeature> valueMap;
041
042    private boolean sorted = true;
043
044    /**
045     * Constructs an inverted feature with the specified feature id.
046     * @param id The feature id.
047     */
048    public TreeFeature(int id) {
049        this.id = id;
050        this.feature = new ArrayList<>();
051        this.valueMap = new HashMap<>();
052    }
053
054    /**
055     * This constructor doesn't make a valueMap, and is only used when all data has been observed.
056     * So it will throw NullPointerException if you call observeValue();
057     * @param id The id number for this feature.
058     * @param data The data.
059     */
060    private TreeFeature(int id, List<InvertedFeature> data) {
061        this.id = id;
062        this.feature = data;
063        this.valueMap = null;
064    }
065
066    /**
067     * Constructor used by {@link TreeFeature#deepCopy}.
068     * @param id The id number for this feature.
069     * @param data The data.
070     * @param valueMap The value map.
071     * @param sorted Is this data sorted.
072     */
073    private TreeFeature(int id, List<InvertedFeature> data, Map<Double,InvertedFeature> valueMap, boolean sorted) {
074        this.id = id;
075        this.feature = data;
076        this.valueMap = valueMap;
077        this.sorted = sorted;
078    }
079
080    @Override
081    public Iterator<InvertedFeature> iterator() {
082        return feature.iterator();
083    }
084
085    /**
086     * Gets the inverted feature values for this feature.
087     * @return The list of feature values.
088     */
089    public List<InvertedFeature> getFeature() {
090        return feature;
091    }
092
093    /**
094     * Observes a value for this feature.
095     * @param value The value observed.
096     * @param exampleID The example id number.
097     */
098    public void observeValue(double value, int exampleID) {
099        Double dValue = value;
100        InvertedFeature f = valueMap.get(dValue);
101        if (f == null) {
102            f = new InvertedFeature(value,exampleID);
103            valueMap.put(dValue,f);
104            feature.add(f);
105            // feature list is no longer guaranteed to be sorted
106            sorted = false;
107        } else {
108            // Update currently known feature
109            f.add(exampleID);
110        }
111    }
112
113    /**
114     * Sort the list using InvertedFeature's natural ordering. Must be done after all elements are inserted.
115     */
116    public void sort() {
117        feature.sort(null);
118        sorted = true;
119    }
120
121    /**
122     * Fixes the size of each {@link InvertedFeature}'s inner arrays.
123     */
124    public void fixSize() {
125        feature.forEach(InvertedFeature::fixSize);
126    }
127
128    /**
129     * Splits this tree feature into two.
130     *
131     * @param leftIndices The indices to go in the left branch.
132     * @param rightIndices The indices to go in the right branch.
133     * @param firstBuffer A buffer for temporary work.
134     * @param secondBuffer A buffer for temporary work.
135     * @return A pair of TreeFeatures, the first element is the left branch, the second the right.
136     */
137    public Pair<TreeFeature,TreeFeature> split(int[] leftIndices, int[] rightIndices, IntArrayContainer firstBuffer, IntArrayContainer secondBuffer) {
138        if (!sorted) {
139            throw new IllegalStateException("TreeFeature must be sorted before split is called");
140        }
141
142        List<InvertedFeature> leftFeatures;
143        List<InvertedFeature> rightFeatures;
144        if (feature.size() == 1) {
145            double value = feature.get(0).value;
146            leftFeatures = Collections.singletonList(new InvertedFeature(value,leftIndices));
147            rightFeatures = Collections.singletonList(new InvertedFeature(value,rightIndices));
148        } else {
149            leftFeatures = new ArrayList<>();
150            rightFeatures = new ArrayList<>();
151            firstBuffer.fill(leftIndices);
152            for (InvertedFeature f : feature) {
153                // Check if we've exhausted all the left side indices
154                if (firstBuffer.size > 0) {
155                    Pair<InvertedFeature, InvertedFeature> split = f.split(firstBuffer, secondBuffer);
156                    IntArrayContainer tmp = secondBuffer;
157                    secondBuffer = firstBuffer;
158                    firstBuffer = tmp;
159                    InvertedFeature left = split.getA();
160                    InvertedFeature right = split.getB();
161                    if (left != null) {
162                        leftFeatures.add(left);
163                    }
164                    if (right != null) {
165                        rightFeatures.add(right);
166                    }
167                } else {
168                    rightFeatures.add(f);
169                }
170            }
171        }
172
173        return new Pair<>(new TreeFeature(id,leftFeatures),new TreeFeature(id,rightFeatures));
174
175    }
176
177    public String toString() {
178        return "TreeFeature(id="+id+",values="+feature.toString()+")";
179    }
180
181    /**
182     * Returns a deep copy of this tree feature.
183     * @return A deep copy.
184     */
185    public TreeFeature deepCopy() {
186        Map<Double,InvertedFeature> newValueMap;
187        List<InvertedFeature> newFeature = new ArrayList<>();
188        if (valueMap != null) {
189            newValueMap = new HashMap<>();
190            for (Map.Entry<Double,InvertedFeature> e : valueMap.entrySet()) {
191                InvertedFeature featureCopy = e.getValue().deepCopy();
192                newValueMap.put(e.getKey(),featureCopy);
193                newFeature.add(featureCopy);
194                newFeature.sort(null);
195            }
196        } else {
197            newValueMap = null;
198            for (InvertedFeature f : feature) {
199                newFeature.add(f.deepCopy());
200            }
201        }
202        return new TreeFeature(id,newFeature,newValueMap,true);
203    }
204}