package org.gradoop.flink.algorithms.fsm.dimspan;

import org.apache.flink.api.common.functions.GroupCombineFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.aggregation.AggregationFunction;
import org.apache.flink.api.java.aggregation.SumAggregationFunction;
import org.apache.flink.api.java.operators.IterativeDataSet;
import org.apache.flink.api.java.operators.MapOperator;
import org.apache.flink.api.java.operators.SingleInputUdfOperator;
import org.gradoop.flink.algorithms.fsm.dimspan.comparison.AlphabeticalLabelComparator;
import org.gradoop.flink.algorithms.fsm.dimspan.comparison.InverseProportionalLabelComparator;
import org.gradoop.flink.algorithms.fsm.dimspan.comparison.LabelComparator;
import org.gradoop.flink.algorithms.fsm.dimspan.comparison.ProportionalLabelComparator;
import org.gradoop.flink.algorithms.fsm.dimspan.config.DIMSpanConfig;
import org.gradoop.flink.algorithms.fsm.dimspan.config.DIMSpanConstants;
import org.gradoop.flink.algorithms.fsm.dimspan.config.DataflowStep;
import org.gradoop.flink.algorithms.fsm.dimspan.config.DictionaryType;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.conversion.DFSCodeToEPGMGraphTransaction;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.mining.CompressPattern;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.mining.CreateCollector;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.mining.ExpandFrequentPatterns;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.mining.Frequent;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.mining.GrowFrequentPatterns;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.mining.InitSingleEdgePatternEmbeddingsMap;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.mining.IsFrequentPatternCollector;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.mining.NotObsolete;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.mining.ReportSupportedPatterns;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.mining.VerifyPattern;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.preprocessing.AggregateMultipleFunctions;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.preprocessing.CreateDictionary;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.preprocessing.EncodeAndPruneEdges;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.preprocessing.EncodeAndPruneVertices;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.preprocessing.MinFrequency;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.preprocessing.NotEmpty;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.preprocessing.ReportEdgeLabels;
import org.gradoop.flink.algorithms.fsm.dimspan.functions.preprocessing.ReportVertexLabels;
import org.gradoop.flink.algorithms.fsm.dimspan.gspan.DirectedGSpanLogic;
import org.gradoop.flink.algorithms.fsm.dimspan.gspan.GSpanLogic;
import org.gradoop.flink.algorithms.fsm.dimspan.gspan.UndirectedGSpanLogic;
import org.gradoop.flink.algorithms.fsm.dimspan.tuples.GraphWithPatternEmbeddingsMap;
import org.gradoop.flink.algorithms.fsm.dimspan.tuples.LabeledGraphIntString;
import org.gradoop.flink.algorithms.fsm.dimspan.tuples.LabeledGraphStringString;
import org.gradoop.flink.model.impl.layouts.transactional.tuples.GraphTransaction;
import org.gradoop.flink.model.impl.operators.count.Count;
import org.gradoop.flink.model.impl.tuples.WithCount;

/**
 * Frequent subgraph mining operator (DIMSpan) built on the Flink DataSet API.
 *
 * <p>{@link #execute} runs three stages:
 * <ol>
 *   <li>Preprocessing: counts the input graphs, derives a minimum frequency from that count,
 *       and dictionary-encodes vertex and edge labels. Unless the dictionary type is
 *       {@link DictionaryType#RANDOM}, only labels passing the frequency filter enter the
 *       dictionaries.</li>
 *   <li>Mining: a Flink bulk iteration of at most {@code MAX_ITERATIONS} supersteps that
 *       reports supported patterns, keeps the frequent ones and grows the search space.</li>
 *   <li>Postprocessing: decodes the mined DFS codes into {@link GraphTransaction}s using the
 *       label dictionaries.</li>
 * </ol>
 *
 * <p>Each call to {@code execute} reassigns the dataflow handles stored in this instance's
 * non-final fields (graph count, minimum frequency, dictionaries).
 */
public class DIMSpan {

    /** Maximum number of supersteps of the pattern-growth bulk iteration. */
    private static final int MAX_ITERATIONS = 100;

    /** Mining configuration (directedness, dictionary type, placement of dataflow steps). */
    protected final DIMSpanConfig fsmConfig;

    /** Number of input graphs; assigned during preprocessing. */
    protected DataSet<Long> graphCount;

    /** Minimum frequency derived from {@link #graphCount}; assigned during preprocessing. */
    protected DataSet<Long> minFrequency;

    /** gSpan logic, directed or undirected depending on the configuration. */
    protected final GSpanLogic gSpan;

    /** Vertex label dictionary; assigned while encoding vertices. */
    private DataSet<String[]> vertexDictionary;

    /** Edge label dictionary; assigned while encoding edges. */
    private DataSet<String[]> edgeDictionary;

    /** Label ordering used when the dictionaries are created. */
    private final LabelComparator comparator;

    /**
     * Creates a DIMSpan operator.
     *
     * @param fsmConfig mining configuration
     */
    public DIMSpan(DIMSpanConfig fsmConfig) {
        this.fsmConfig = fsmConfig;
        this.gSpan = fsmConfig.isDirected()
            ? new DirectedGSpanLogic(fsmConfig)
            : new UndirectedGSpanLogic(fsmConfig);
        this.comparator = createComparator(fsmConfig.getDictionaryType());
    }

    /**
     * Selects the label comparator for a dictionary type.
     *
     * @param dictionaryType configured dictionary type
     * @return proportional or inverse-proportional comparator for the matching types,
     *         alphabetical comparator otherwise
     */
    private static LabelComparator createComparator(DictionaryType dictionaryType) {
        if (dictionaryType == DictionaryType.PROPORTIONAL) {
            return new ProportionalLabelComparator();
        }
        if (dictionaryType == DictionaryType.INVERSE_PROPORTIONAL) {
            return new InverseProportionalLabelComparator();
        }
        return new AlphabeticalLabelComparator();
    }

    /**
     * Mines frequent patterns from labeled graphs.
     *
     * @param graphs input graphs with string vertex and edge labels
     * @return mined patterns decoded into graph transactions
     */
    public DataSet<GraphTransaction> execute(DataSet<LabeledGraphStringString> graphs) {
        DataSet<int[]> encodedGraphs = preProcess(graphs);
        DataSet<WithCount<int[]>> patterns = mine(encodedGraphs);
        return postProcess(patterns);
    }

    /**
     * Counts graphs, derives the minimum frequency and encodes labels.
     *
     * @param graphs input graphs
     * @return encoded graphs that pass the {@link NotEmpty} filter
     */
    private DataSet<int[]> preProcess(DataSet<LabeledGraphStringString> graphs) {
        // must be assigned before encoding: getFrequentLabels broadcasts minFrequency
        graphCount = Count.count(graphs);
        minFrequency = graphCount.map(new MinFrequency(fsmConfig));

        return encodeEdges(encodeVertices(graphs))
            .filter(new NotEmpty());
    }

    /**
     * Runs the pattern-growth bulk iteration on encoded graphs.
     *
     * <p>A single extra element created by {@link CreateCollector} is unioned into the search
     * space. After the iteration only that element is kept ({@link IsFrequentPatternCollector})
     * and expanded into the result patterns. Presumably it accumulates the frequent patterns of
     * every superstep; confirm in {@code GrowFrequentPatterns}.
     *
     * <p>The iteration ends when a superstep produces no frequent patterns, or after
     * {@code MAX_ITERATIONS} supersteps.
     *
     * @param graphs dictionary-encoded graphs
     * @return patterns with their counts
     */
    protected DataSet<WithCount<int[]>> mine(DataSet<int[]> graphs) {
        DataSet<GraphWithPatternEmbeddingsMap> searchSpace = graphs
            .map(new InitSingleEdgePatternEmbeddingsMap(gSpan, fsmConfig));

        DataSet<GraphWithPatternEmbeddingsMap> collector = graphs
            .getExecutionEnvironment()
            .fromElements(true)
            .map(new CreateCollector());

        IterativeDataSet<GraphWithPatternEmbeddingsMap> iterative = searchSpace
            .union(collector)
            .iterate(MAX_ITERATIONS);

        DataSet<WithCount<int[]>> frequentPatterns =
            getFrequentPatterns(iterative.flatMap(new ReportSupportedPatterns()));

        DataSet<GraphWithPatternEmbeddingsMap> grown = iterative
            .map(new GrowFrequentPatterns(gSpan, fsmConfig))
            .withBroadcastSet(frequentPatterns, "fp")
            .filter(new NotObsolete());

        // frequentPatterns is also the termination criterion: an empty set ends the iteration
        return iterative
            .closeWith(grown, frequentPatterns)
            .filter(new IsFrequentPatternCollector())
            .flatMap(new ExpandFrequentPatterns());
    }

    /**
     * Decodes mined DFS codes into graph transactions.
     *
     * <p>Both dictionaries and the graph count are broadcast to the decoder.
     *
     * @param patterns mined patterns with counts
     * @return decoded graph transactions
     */
    private DataSet<GraphTransaction> postProcess(DataSet<WithCount<int[]>> patterns) {
        return patterns
            .map(new DFSCodeToEPGMGraphTransaction(fsmConfig))
            .withBroadcastSet(vertexDictionary, DIMSpanConstants.VERTEX_DICTIONARY)
            .withBroadcastSet(edgeDictionary, DIMSpanConstants.EDGE_DICTIONARY)
            .withBroadcastSet(graphCount, DIMSpanConstants.GRAPH_COUNT);
    }

    /**
     * Builds the vertex label dictionary and encodes/prunes vertices with it.
     *
     * @param graphs graphs with string labels
     * @return graphs with integer vertex labels and still string edge labels
     */
    private DataSet<LabeledGraphIntString> encodeVertices(
        DataSet<LabeledGraphStringString> graphs) {

        vertexDictionary = getFrequentLabels(graphs.flatMap(new ReportVertexLabels()))
            .reduceGroup(new CreateDictionary(comparator));

        return graphs
            .map(new EncodeAndPruneVertices())
            .withBroadcastSet(vertexDictionary, DIMSpanConstants.VERTEX_DICTIONARY);
    }

    /**
     * Builds the edge label dictionary and encodes/prunes edges with it.
     *
     * @param graphs graphs with already encoded vertex labels
     * @return fully encoded graphs
     */
    private DataSet<int[]> encodeEdges(DataSet<LabeledGraphIntString> graphs) {
        edgeDictionary = getFrequentLabels(graphs.flatMap(new ReportEdgeLabels()))
            .reduceGroup(new CreateDictionary(comparator));

        return graphs
            .map(new EncodeAndPruneEdges(fsmConfig))
            .withBroadcastSet(edgeDictionary, DIMSpanConstants.EDGE_DICTIONARY);
    }

    /**
     * Selects the labels that enter a dictionary.
     *
     * <p>For {@link DictionaryType#RANDOM} the reported labels are only deduplicated.
     * Otherwise counts are summed per label and filtered against the broadcast minimum
     * frequency.
     *
     * @param labels reported labels with counts
     * @return labels that enter the dictionary
     */
    private DataSet<WithCount<String>> getFrequentLabels(DataSet<WithCount<String>> labels) {
        if (fsmConfig.getDictionaryType() == DictionaryType.RANDOM) {
            return labels.distinct();
        }
        return labels
            .groupBy(0)
            .sum(1)
            .filter(new Frequent())
            .withBroadcastSet(minFrequency, DIMSpanConstants.MIN_FREQUENCY);
    }

    /**
     * Aggregates pattern reports of one superstep and keeps the frequent patterns.
     *
     * <p>Reports are first pre-summed per partition with a combiner, then summed globally and
     * filtered by the minimum frequency. Pattern verification and compression run after the
     * combine step or after the frequency filter, depending on the configuration.
     *
     * <p>NOTE(review): only the COMBINE and FILTER placements are honored here. If
     * {@link DataflowStep} defines further placements (e.g. a reduce step), selecting them
     * silently skips verification/compression. Confirm against the enum.
     *
     * @param patterns reported patterns with counts
     * @return frequent patterns with their summed counts
     */
    private DataSet<WithCount<int[]>> getFrequentPatterns(DataSet<WithCount<int[]>> patterns) {
        DataSet<WithCount<int[]>> combined = patterns
            .groupBy(0)
            .combineGroup(sumPartition());

        if (fsmConfig.getPatternVerificationInStep() == DataflowStep.COMBINE) {
            combined = combined.filter(new VerifyPattern(gSpan, fsmConfig));
        }
        if (fsmConfig.getPatternCompressionInStep() == DataflowStep.COMBINE) {
            combined = combined.map(new CompressPattern());
        }

        DataSet<WithCount<int[]>> frequent = combined
            .groupBy(0)
            .sum(1)
            .filter(new Frequent())
            .withBroadcastSet(minFrequency, DIMSpanConstants.MIN_FREQUENCY);

        if (fsmConfig.getPatternVerificationInStep() == DataflowStep.FILTER) {
            frequent = frequent.filter(new VerifyPattern(gSpan, fsmConfig));
        }
        if (fsmConfig.getPatternCompressionInStep() == DataflowStep.FILTER) {
            frequent = frequent.map(new CompressPattern());
        }
        return frequent;
    }

    /**
     * Creates the combiner that sums field 1 (the count) of grouped pattern reports.
     *
     * @return partition-local sum combiner
     */
    private GroupCombineFunction<WithCount<int[]>, WithCount<int[]>> sumPartition() {
        AggregationFunction[] sum = new AggregationFunction[] {
            new SumAggregationFunction.SumAggregationFunctionFactory()
                .createAggregationFunction(Long.class)
        };
        int[] countField = new int[] {1};
        return new AggregateMultipleFunctions(sum, countField);
    }

    /**
     * Returns the operator name.
     *
     * @return simple name of the runtime class
     */
    public String getName() {
        return getClass().getSimpleName();
    }
}
