package com.github.bentorfs.ai.algorithms.ml.associationrules.apriori;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/github/bentorfs/ai/algorithms/ml/associationrules/apriori/AprioriAlgorithm.class */
/**
 * Apriori frequent-itemset mining and association-rule generation.
 *
 * <p>Frequent itemsets are found level by level: every distinct item seeds a 1-itemset, each
 * level's candidates are kept when their support (fraction of transactions containing all of
 * their items) lies within {@code [minSupport, maxSupport]}, and the surviving k-itemsets are
 * joined into (k+1)-itemset candidates until no candidates remain or {@code maxItemSetSize} is
 * reached. Rules are then derived from pairs of frequent itemsets whose confidence reaches
 * {@code minConfidence}.
 */
public class AprioriAlgorithm {
    protected Logger logger = LoggerFactory.getLogger(getClass());

    private final double minSupport;
    private final double maxSupport;
    private final double minConfidence;
    private final int maxItemSetSize;

    /**
     * Creates a configured miner.
     *
     * @param maxItemSetSize largest number of items (inclusive) an itemset may contain
     * @param minSupport lowest support (inclusive) an itemset must have to be kept
     * @param maxSupport highest support (inclusive) an itemset may have to be kept
     * @param minConfidence lowest confidence (inclusive) a rule must have to be reported
     */
    public AprioriAlgorithm(int maxItemSetSize, double minSupport, double maxSupport, double minConfidence) {
        this.maxItemSetSize = maxItemSetSize;
        this.minSupport = minSupport;
        this.maxSupport = maxSupport;
        this.minConfidence = minConfidence;
    }

    /**
     * Mines all frequent itemsets from the given transactions.
     *
     * <p>Singletons are always evaluated; larger itemsets are generated up to and including
     * {@code maxItemSetSize} items. Every returned itemset has had its support stored via
     * {@code ItemSet.setSupport}.
     *
     * @param collection the transactions to mine
     * @return the frequent itemsets of every size, ordered by increasing itemset size
     */
    public List<ItemSet> getFrequentItemSets(Collection<Transaction> collection) {
        List<ItemSet> frequentItemSets = new ArrayList<>();
        Set<Item> allItems = getAllItems(collection);
        logger.info("Generating frequent itemsets from {} transactions and {} different items",
                collection.size(), allItems.size());
        List<ItemSet> candidates = getItemSetsOfSize1(allItems);
        int size = 1;
        while (!candidates.isEmpty()) {
            List<ItemSet> frequentOfSize = getFrequentItemSets(candidates, collection);
            frequentItemSets.addAll(frequentOfSize);
            size++;
            // maxItemSetSize is inclusive: itemsets of exactly that size are still generated and
            // counted, and no candidates beyond it are built.
            if (size > maxItemSetSize) {
                break;
            }
            candidates = getCandidateItemSets(frequentOfSize, size);
        }
        logger.info("Done generating frequent itemsets. {} frequent itemsets found", frequentItemSets.size());
        return frequentItemSets;
    }

    /**
     * Keeps the candidates whose support lies within {@code [minSupport, maxSupport]}.
     *
     * @param list candidate itemsets; each kept candidate has its support set as a side effect
     * @param collection the transactions support is measured against
     * @return the kept candidates, in input order
     */
    private List<ItemSet> getFrequentItemSets(List<ItemSet> list, Collection<Transaction> collection) {
        logger.info("Selecting frequent itemsets from {} candidates", list.size());
        List<ItemSet> frequent = new ArrayList<>();
        for (ItemSet candidate : list) {
            int matches = 0;
            for (Transaction transaction : collection) {
                if (transaction.getItems().containsAll(candidate.getItems())) {
                    matches++;
                }
            }
            double support = (double) matches / collection.size();
            if (support >= minSupport && support <= maxSupport) {
                candidate.setSupport(support);
                frequent.add(candidate);
            }
        }
        logger.info("Found {} frequent itemsets", frequent.size());
        return frequent;
    }

    /**
     * Joins frequent itemsets of size {@code size - 1} into candidates of size {@code size}.
     *
     * <p>Two itemsets are joined when they share every item except their last (in the
     * {@code TreeSet} ordering) and their last items are compatible according to
     * {@code Item.isCompatibleWith}. The input itemsets are not modified.
     *
     * @param list frequent itemsets of size {@code size - 1}
     * @param size the size of the candidates to generate (used for logging)
     * @return the generated candidates
     */
    private List<ItemSet> getCandidateItemSets(List<ItemSet> list, int size) {
        logger.info("Generating candidate frequent itemsets of size {} from itemsets of size {}", size, size - 1);
        List<ItemSet> candidates = new ArrayList<>();
        for (int a = 0; a < list.size(); a++) {
            TreeSet<Item> first = list.get(a).getItems();
            Item firstLast = first.last();
            // Read-only view of every item except the last; avoids mutating the itemset's own set.
            Set<Item> firstPrefix = first.headSet(firstLast, false);
            for (int b = a + 1; b < list.size(); b++) {
                TreeSet<Item> second = list.get(b).getItems();
                Item secondLast = second.last();
                if (firstPrefix.equals(second.headSet(secondLast, false)) && firstLast.isCompatibleWith(secondLast)) {
                    ItemSet candidate = new ItemSet();
                    // shared prefix + first's last item + second's last item
                    candidate.getItems().addAll(first);
                    candidate.getItems().add(secondLast);
                    candidates.add(candidate);
                }
            }
        }
        logger.info("Found {} candidate frequent itemsets of size {}", candidates.size(), size);
        return candidates;
    }

    /**
     * Wraps every item in its own single-element itemset.
     *
     * @param set the distinct items
     * @return one new itemset per item
     */
    private List<ItemSet> getItemSetsOfSize1(Set<Item> set) {
        List<ItemSet> singletons = new ArrayList<>(set.size());
        for (Item item : set) {
            ItemSet singleton = new ItemSet();
            singleton.getItems().add(item);
            singletons.add(singleton);
        }
        return singletons;
    }

    /**
     * Collects the distinct items that occur in any transaction.
     *
     * @param collection the transactions to scan
     * @return the union of all transactions' items
     */
    private Set<Item> getAllItems(Collection<Transaction> collection) {
        Set<Item> allItems = new HashSet<>();
        for (Transaction transaction : collection) {
            allItems.addAll(transaction.getItems());
        }
        return allItems;
    }

    /**
     * Derives association rules from pairs of itemsets in {@code list}.
     *
     * <p>For every pair where {@code subset}'s items are a proper subset of {@code superset}'s, the
     * confidence {@code support(superset) / support(subset)} is computed; when it is at least
     * {@code minConfidence}, a rule {@code subset -> (superset \ subset)} is emitted. Only
     * antecedents that themselves appear in {@code list} are considered.
     *
     * @param list itemsets whose support has already been set (e.g. by getFrequentItemSets)
     * @return the rules meeting the confidence threshold
     */
    public List<AssociationRule> getAssociationRules(List<ItemSet> list) {
        List<AssociationRule> rules = new ArrayList<>();
        for (ItemSet superset : list) {
            for (ItemSet subset : list) {
                // The strict size comparison also rules out pairing an itemset with itself.
                if (superset.getItems().size() > subset.getItems().size()
                        && superset.getItems().containsAll(subset.getItems())) {
                    double confidence = superset.getSupport() / subset.getSupport();
                    if (confidence >= minConfidence) {
                        Set<Item> antecedent = new HashSet<>(subset.getItems());
                        Set<Item> consequent = new HashSet<>(superset.getItems());
                        consequent.removeAll(subset.getItems());
                        // NOTE(review): the support passed here is the antecedent's support, not
                        // the support of the whole rule (superset); confirm against AssociationRule.
                        rules.add(new AssociationRule(antecedent, consequent, subset.getSupport(), confidence));
                    }
                }
            }
        }
        return rules;
    }
}
