/*
 * Decompiled with CFR 0.152.
 */
package org.carrot2.clustering.kmeans;

import com.carrotsearch.hppc.IntArrayList;
import com.carrotsearch.hppc.IntIntHashMap;
import com.carrotsearch.hppc.cursors.IntCursor;
import com.carrotsearch.hppc.cursors.IntIntCursor;
import com.carrotsearch.hppc.sorting.IndirectComparator;
import com.carrotsearch.hppc.sorting.IndirectSort;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.carrot2.attrs.AttrBoolean;
import org.carrot2.attrs.AttrComposite;
import org.carrot2.attrs.AttrInteger;
import org.carrot2.attrs.AttrObject;
import org.carrot2.attrs.AttrString;
import org.carrot2.clustering.Cluster;
import org.carrot2.clustering.ClusteringAlgorithm;
import org.carrot2.clustering.Document;
import org.carrot2.clustering.SharedInfrastructure;
import org.carrot2.language.LanguageComponents;
import org.carrot2.language.LexicalData;
import org.carrot2.language.Stemmer;
import org.carrot2.language.Tokenizer;
import org.carrot2.math.mahout.function.Functions;
import org.carrot2.math.mahout.matrix.DoubleMatrix1D;
import org.carrot2.math.mahout.matrix.DoubleMatrix2D;
import org.carrot2.math.mahout.matrix.impl.DenseDoubleMatrix1D;
import org.carrot2.math.mahout.matrix.impl.DenseDoubleMatrix2D;
import org.carrot2.text.preprocessing.BasicPreprocessingPipeline;
import org.carrot2.text.preprocessing.LabelFormatter;
import org.carrot2.text.preprocessing.PreprocessingContext;
import org.carrot2.text.vsm.ReducedVectorSpaceModelContext;
import org.carrot2.text.vsm.TermDocumentMatrixBuilder;
import org.carrot2.text.vsm.TermDocumentMatrixReducer;
import org.carrot2.text.vsm.VectorSpaceModelContext;

/**
 * Bisecting k-means document clustering.
 *
 * <p>All documents are first split into {@code partitionCount} groups with a few rounds of
 * k-means. The largest group that is still big enough is then repeatedly split again. This stops
 * once {@code clusterCount} groups exist or no group can be split further. Each resulting group
 * with more than one document becomes a {@link Cluster}. Its labels are the single-word terms
 * with the highest summed weight in that group's documents.
 */
public class BisectingKMeansClusteringAlgorithm
extends AttrComposite
implements ClusteringAlgorithm {
    /** Language components this algorithm reads; exposed read-only because the set is shared. */
    private static final Set<Class<?>> REQUIRED_LANGUAGE_COMPONENTS = Collections.unmodifiableSet(
            new HashSet<Class<?>>(Arrays.asList(Stemmer.class, Tokenizer.class, LexicalData.class, LabelFormatter.class)));

    /**
     * Word-type flags that exclude a word from being a clustering feature.
     * NOTE(review): presumably {@code Tokenizer.TF_COMMON_WORD | Tokenizer.TF_QUERY_WORD |
     * Tokenizer.TT_NUMERIC} (0x1000 | 0x2000 | 0x0002) -- confirm and switch to the named constants.
     */
    private static final int EXCLUDED_WORD_FLAGS = 0x3002;

    public static final String NAME = "Bisecting K-Means";

    /** Upper bound on the number of groups produced by bisection; also drives the reduced rank (2x). */
    public final AttrInteger clusterCount;
    /** Maximum number of k-means iterations per single split. */
    public final AttrInteger maxIterations;
    /** Number of partitions each split tries to produce. */
    public final AttrInteger partitionCount;
    /** Maximum number of labels attached to each cluster. */
    public final AttrInteger labelCount;
    /** Query hint forwarded to the preprocessing pipeline. */
    public final AttrString queryHint;
    /** Whether to cluster in a reduced space (only applied when 2 * clusterCount < document count). */
    public final AttrBoolean useDimensionalityReduction;
    public TermDocumentMatrixBuilder matrixBuilder;
    public TermDocumentMatrixReducer matrixReducer;
    public BasicPreprocessingPipeline preprocessing;

    /** Orders groups from the largest to the smallest. */
    private static final Comparator<IntArrayList> BY_SIZE_DESCENDING =
            (o1, o2) -> Integer.compare(o2.size(), o1.size());

    public BisectingKMeansClusteringAlgorithm() {
        this.clusterCount = this.attributes.register("clusterCount", AttrInteger.builder().label("Cluster count").min(2).defaultValue(25));
        this.maxIterations = this.attributes.register("maxIterations", AttrInteger.builder().label("Maximum iterations").min(1).defaultValue(15));
        this.partitionCount = this.attributes.register("partitionCount", AttrInteger.builder().label("Partition count").min(2).max(10).defaultValue(2));
        this.labelCount = this.attributes.register("labelCount", AttrInteger.builder().label("Label count").min(1).max(10).defaultValue(3));
        this.queryHint = this.attributes.register("queryHint", SharedInfrastructure.queryHintAttribute());
        this.useDimensionalityReduction = this.attributes.register("useDimensionalityReduction", AttrBoolean.builder().label("Use dimensionality reduction").defaultValue(true));
        this.attributes.register("matrixBuilder", AttrObject.builder(TermDocumentMatrixBuilder.class)
                .label("Term-document matrix builder")
                .getset(() -> this.matrixBuilder, v -> this.matrixBuilder = v)
                .defaultValue(TermDocumentMatrixBuilder::new));
        this.attributes.register("matrixReducer", AttrObject.builder(TermDocumentMatrixReducer.class)
                .label("Term-document matrix reducer")
                .getset(() -> this.matrixReducer, v -> this.matrixReducer = v)
                .defaultValue(TermDocumentMatrixReducer::new));
        this.attributes.register("preprocessing", AttrObject.builder(BasicPreprocessingPipeline.class)
                .label("Input preprocessing components")
                .getset(() -> this.preprocessing, v -> this.preprocessing = v)
                .defaultValue(BasicPreprocessingPipeline::new));
    }

    @Override
    public Set<Class<?>> requiredLanguageComponents() {
        return REQUIRED_LANGUAGE_COMPONENTS;
    }

    /**
     * Clusters the given documents.
     *
     * @param docStream documents to cluster; consumed once
     * @param languageComponents language resources (stemmer, tokenizer, lexical data, label formatter)
     * @return clusters of at least two documents, ordered by
     *     {@link SharedInfrastructure#reorderByDescendingSizeAndLabel}
     */
    @Override
    public <T extends Document> List<Cluster<T>> cluster(Stream<? extends T> docStream, LanguageComponents languageComponents) {
        final List<T> documents = docStream.collect(Collectors.toList());
        final PreprocessingContext preprocessingContext =
                this.preprocessing.preprocess(documents.stream(), this.queryHint.get(), languageComponents);

        // Single words only (no phrases) serve as features. Publishing them as "labels" lets the
        // shared term-document matrix builder be reused unchanged.
        preprocessingContext.allLabels.featureIndex = selectFeatureWordIndices(preprocessingContext);
        preprocessingContext.allLabels.firstPhraseIndex = -1;

        final List<Cluster<T>> clusters = new ArrayList<>();
        if (preprocessingContext.hasLabels()) {
            final VectorSpaceModelContext vsmContext = new VectorSpaceModelContext(preprocessingContext);
            final ReducedVectorSpaceModelContext reducedVsmContext = new ReducedVectorSpaceModelContext(vsmContext);
            this.matrixBuilder.buildTermDocumentMatrix(vsmContext);
            this.matrixBuilder.buildTermPhraseMatrix(vsmContext);

            // Invert stem -> matrix row so label extraction can go from a row back to a word.
            final IntIntHashMap rowToStemIndex = new IntIntHashMap();
            for (IntIntCursor c : vsmContext.stemToRowIndex) {
                rowToStemIndex.put(c.value, c.key);
            }

            final int reducedDimensions = this.clusterCount.get() * 2;
            final DoubleMatrix2D tdMatrix;
            if (this.useDimensionalityReduction.get() && reducedDimensions < preprocessingContext.documentCount) {
                this.matrixReducer.reduce(reducedVsmContext, reducedDimensions);
                tdMatrix = reducedVsmContext.coefficientMatrix.viewDice();
            } else {
                tdMatrix = vsmContext.termDocumentMatrix;
            }

            final List<IntArrayList> rawClusters = bisect(tdMatrix);

            final LabelFormatter labelFormatter = languageComponents.get(LabelFormatter.class);
            for (IntArrayList rawCluster : rawClusters) {
                // Singleton groups are not reported as clusters.
                if (rawCluster.size() <= 1) {
                    continue;
                }
                final Cluster<T> cluster = new Cluster<>();
                // Labels always come from the unreduced term-document matrix, whose rows are real terms.
                getLabels(cluster, rawCluster, vsmContext.termDocumentMatrix, rowToStemIndex,
                        preprocessingContext.allStems.mostFrequentOriginalWordIndex,
                        preprocessingContext.allWords.image, labelFormatter);
                for (IntCursor doc : rawCluster) {
                    cluster.addDocument(documents.get(doc.value));
                }
                clusters.add(cluster);
            }
        }
        return SharedInfrastructure.reorderByDescendingSizeAndLabel(clusters);
    }

    /**
     * Returns the most frequent original word index of every stem whose word type carries none of
     * the {@link #EXCLUDED_WORD_FLAGS}.
     */
    private static int[] selectFeatureWordIndices(PreprocessingContext context) {
        final int[] stemsMfow = context.allStems.mostFrequentOriginalWordIndex;
        final short[] wordsType = context.allWords.type;
        final IntArrayList featureIndices = new IntArrayList(stemsMfow.length);
        for (int wordIndex : stemsMfow) {
            if ((wordsType[wordIndex] & EXCLUDED_WORD_FLAGS) == 0) {
                featureIndices.add(wordIndex);
            }
        }
        return featureIndices.toArray();
    }

    /**
     * Runs the bisecting loop over all columns (documents) of {@code tdMatrix}.
     *
     * <p>The first group to be split is always the largest splittable one. A group is splittable
     * if it has more than {@code 2 * partitionCount} members and a split yields more than one
     * non-empty part.
     *
     * @return groups of column indices, largest first
     */
    private List<IntArrayList> bisect(DoubleMatrix2D tdMatrix) {
        final int partitions = this.partitionCount.get();
        final int iterations = this.maxIterations.get();
        final int maxClusters = this.clusterCount.get();

        final IntArrayList allColumns = new IntArrayList(tdMatrix.columns());
        for (int c = 0; c < tdMatrix.columns(); ++c) {
            allColumns.add(c);
        }

        final List<IntArrayList> rawClusters = new ArrayList<>(this.split(partitions, tdMatrix, allColumns, iterations));
        rawClusters.sort(BY_SIZE_DESCENDING);

        int candidate = 0;
        while (rawClusters.size() < maxClusters
                && candidate < rawClusters.size()
                && rawClusters.get(candidate).size() > partitions * 2) {
            final List<IntArrayList> parts = this.split(partitions, tdMatrix, rawClusters.get(candidate), iterations);
            if (parts.size() > 1) {
                rawClusters.remove(candidate);
                rawClusters.addAll(parts);
                rawClusters.sort(BY_SIZE_DESCENDING);
                candidate = 0;
            } else {
                // This group does not split; try the next largest one.
                ++candidate;
            }
        }
        return rawClusters;
    }

    /**
     * Adds up to {@code labelCount} single-word labels to {@code cluster}.
     *
     * <p>The labels are the terms with the highest summed weight over the cluster's documents, in
     * descending weight order (ties broken by lower row index). Terms with a non-positive weight
     * are skipped, except for the first label. This keeps terms that occur in none of the
     * cluster's documents from flooding the label list.
     *
     * @param cluster cluster to receive the labels
     * @param documents column indices of the cluster's documents in {@code termDocumentMatrix}
     * @param termDocumentMatrix matrix whose rows map to stems via {@code rowToStemIndex}
     * @param rowToStemIndex matrix row to stem index
     * @param mostFrequentOriginalWordIndex stem index to word index
     * @param wordImage word index to word characters
     * @param labelFormatter formats the selected word into a label string
     */
    private void getLabels(Cluster<?> cluster, IntArrayList documents, DoubleMatrix2D termDocumentMatrix, IntIntHashMap rowToStemIndex, int[] mostFrequentOriginalWordIndex, char[][] wordImage, LabelFormatter labelFormatter) {
        final int rows = termDocumentMatrix.rows();
        final DoubleMatrix1D centroid = new DenseDoubleMatrix1D(rows);
        for (IntCursor d : documents) {
            centroid.assign(termDocumentMatrix.viewColumn(d.value), Functions.PLUS);
        }

        // Stable sort, descending by weight: equal weights keep ascending row order.
        final int[] byDescendingWeight = IndirectSort.mergesort(0, rows,
                (a, b) -> Double.compare(centroid.getQuick(b), centroid.getQuick(a)));

        final int limit = Math.min(this.labelCount.get(), byDescendingWeight.length);
        for (int rank = 0; rank < limit; ++rank) {
            final int row = byDescendingWeight[rank];
            if (rank > 0 && !(centroid.getQuick(row) > 0)) {
                break;
            }
            final int stemIndex = rowToStemIndex.get(row);
            cluster.addLabel(labelFormatter.format(
                    new char[][]{wordImage[mostFrequentOriginalWordIndex[stemIndex]]}, new boolean[]{false}));
        }
    }

    /**
     * Splits the given columns of {@code input} with a k-means variant.
     *
     * <p>Similarity is the dot product of a centroid and a column. The initial assignment is
     * round-robin, the initial centroids are the first {@code k} selected columns, and iteration
     * stops early when the assignment no longer changes.
     *
     * <p>Here {@code k} is {@code min(partitions, columns.size())} (at least 1), so inputs with
     * fewer columns than requested partitions do not fail.
     *
     * @param partitions requested number of partitions
     * @param input matrix whose columns are the items being clustered
     * @param columns column indices of {@code input} to split; not modified
     * @param iterations maximum number of k-means iterations
     * @return the non-empty partitions, expressed as column indices of {@code input}
     */
    private List<IntArrayList> split(int partitions, DoubleMatrix2D input, IntArrayList columns, int iterations) {
        if (columns.isEmpty()) {
            return new ArrayList<>();
        }

        final DoubleMatrix2D selected = input.viewSelection(null, columns.toArray()).copy();
        final int rows = selected.rows();
        final int items = selected.columns();
        // Centroids are seeded from the first k columns, so k cannot exceed the column count.
        final int k = Math.max(1, Math.min(partitions, items));

        List<IntArrayList> result = emptyPartitions(k, items);
        for (int i = 0; i < items; ++i) {
            result.get(i % k).add(i);
        }

        final DoubleMatrix2D centroids = new DenseDoubleMatrix2D(rows, k).assign(selected.viewPart(0, 0, rows, k));
        final DoubleMatrix2D similarities = new DenseDoubleMatrix2D(k, items);

        for (int it = 0; it < iterations; ++it) {
            // Recompute every centroid as the mean of its members. An empty partition yields
            // 0/0 = NaN; nearestCentroid() never picks such a centroid.
            for (int p = 0; p < k; ++p) {
                final IntArrayList members = result.get(p);
                for (int r = 0; r < rows; ++r) {
                    double sum = 0.0;
                    for (int j = 0; j < members.size(); ++j) {
                        sum += selected.getQuick(r, members.get(j));
                    }
                    centroids.setQuick(r, p, sum / (double) members.size());
                }
            }

            final List<IntArrayList> previousResult = result;
            result = emptyPartitions(k, items);

            // similarities = centroids^T * selected
            centroids.zMult(selected, similarities, 1.0, 0.0, true, false);
            for (int c = 0; c < items; ++c) {
                result.get(nearestCentroid(similarities, c)).add(c);
            }

            if (result.equals(previousResult)) {
                break;
            }
        }

        // Drop empty partitions and translate selected-column indices back to input columns.
        final List<IntArrayList> nonEmpty = new ArrayList<>(result.size());
        for (IntArrayList partition : result) {
            if (partition.isEmpty()) {
                continue;
            }
            for (int j = 0; j < partition.size(); ++j) {
                partition.set(j, columns.get(partition.get(j)));
            }
            nonEmpty.add(partition);
        }
        return nonEmpty;
    }

    /** Creates {@code count} empty partitions, each pre-sized for {@code expectedSize} members. */
    private static List<IntArrayList> emptyPartitions(int count, int expectedSize) {
        final List<IntArrayList> partitions = new ArrayList<>(count);
        for (int i = 0; i < count; ++i) {
            partitions.add(new IntArrayList(expectedSize));
        }
        return partitions;
    }

    /**
     * Returns the row with the largest non-NaN similarity in {@code column}.
     *
     * <p>On ties, the first such row wins. If every value is NaN, row 0 is returned. NaN rows
     * come from centroids of empty partitions and must not be selected.
     */
    private static int nearestCentroid(DoubleMatrix2D similarities, int column) {
        int best = -1;
        double bestValue = 0.0;
        for (int r = 0; r < similarities.rows(); ++r) {
            final double value = similarities.getQuick(r, column);
            if (Double.isNaN(value)) {
                continue;
            }
            if (best < 0 || value > bestValue) {
                best = r;
                bestValue = value;
            }
        }
        return best < 0 ? 0 : best;
    }
}

