package ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling;

import ai.libs.jaicore.ml.core.dataset.schema.DatasetPropertyComputer;
import ai.libs.jaicore.ml.core.filter.sampling.inmemory.stratified.sampling.DiscretizationHelper;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import org.api4.java.ai.ml.core.dataset.IDataset;
import org.api4.java.ai.ml.core.dataset.IInstance;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledDataset;
import org.api4.java.ai.ml.core.dataset.supervised.ILabeledInstance;
import org.api4.java.common.control.ILoggingCustomizable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/jaicore/ml/core/filter/sampling/inmemory/stratified/sampling/AttributeBasedStratiAmountSelectorAndAssigner.class */
/**
 * Strati amount selector and strati assigner that creates one stratum per combination of values of a fixed set of
 * attributes.
 *
 * <p>
 * Attributes listed in the discretization policies are mapped to discrete categories by a {@link DiscretizationHelper}
 * before being used as part of a stratum key. If no policies are supplied, default policies are computed during
 * initialization using the configured discretization strategy and number of categories.
 *
 * <p>
 * An attribute index equal to the dataset's number of attributes refers to the label column; this is only permitted
 * for {@link ILabeledDataset}s.
 */
public class AttributeBasedStratiAmountSelectorAndAssigner implements IStratiAmountSelector, IStratiAssigner, ILoggingCustomizable {
    private static final DiscretizationHelper.DiscretizationStrategy DEFAULT_DISCRETIZATION_STRATEGY = DiscretizationHelper.DiscretizationStrategy.EQUAL_SIZE;
    private static final int DEFAULT_DISCRETIZATION_CATEGORY_AMOUNT = 5;

    private Logger logger = LoggerFactory.getLogger(AttributeBasedStratiAmountSelectorAndAssigner.class);
    private final DiscretizationHelper discretizationHelper = new DiscretizationHelper();

    /** Indices of the attributes whose value combination defines a stratum (numAttributes = label column). */
    private List<Integer> attributeIndices;
    /** Maps each attribute value combination (in {@link #attributeIndices} order) to its stratum id. */
    private Map<List<Object>, Integer> stratumIDs;
    private int numCPUs = 1;
    private IDataset<?> dataset;
    private int numAttributes;
    /** Attribute index -> discretization policy; may be null until {@link #init(IDataset, int)} computes defaults. */
    private Map<Integer, AttributeDiscretizationPolicy> discretizationPolicies;
    // Initialized here so that every constructor (including the (list, map) one) gets usable defaults.
    private DiscretizationHelper.DiscretizationStrategy discretizationStrategy = DEFAULT_DISCRETIZATION_STRATEGY;
    private int numberOfCategories = DEFAULT_DISCRETIZATION_CATEGORY_AMOUNT;
    private boolean initialized;

    /**
     * Creates an assigner without attribute indices. Such an instance cannot be initialized; {@link #init(IDataset, int)}
     * will throw an {@link IllegalStateException}.
     */
    public AttributeBasedStratiAmountSelectorAndAssigner() {
        // fields carry their defaults from the initializers above
    }

    /**
     * Creates an assigner over the given attributes using the default discretization strategy and category amount.
     *
     * @param attributeIndices
     *            indices of the attributes that define the strati; must be non-null and non-empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(final List<Integer> attributeIndices) {
        this(attributeIndices, null);
    }

    /**
     * Creates an assigner over the given attributes whose default discretization policies (computed at initialization)
     * use the given strategy and number of categories.
     *
     * @param attributeIndices
     *            indices of the attributes that define the strati; must be non-null and non-empty
     * @param discretizationStrategy
     *            strategy used when computing default discretization policies
     * @param numberOfCategories
     *            number of categories used when computing default discretization policies
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(final List<Integer> attributeIndices, final DiscretizationHelper.DiscretizationStrategy discretizationStrategy, final int numberOfCategories) {
        this(attributeIndices, null);
        this.discretizationStrategy = discretizationStrategy;
        this.numberOfCategories = numberOfCategories;
    }

    /**
     * Creates an assigner over the given attributes with explicit discretization policies.
     *
     * @param attributeIndices
     *            indices of the attributes that define the strati; must be non-null and non-empty
     * @param discretizationPolicies
     *            attribute index to discretization policy; if null, default policies are computed at initialization
     * @throws IllegalArgumentException
     *             if {@code attributeIndices} is null or empty
     */
    public AttributeBasedStratiAmountSelectorAndAssigner(final List<Integer> attributeIndices, final Map<Integer, AttributeDiscretizationPolicy> discretizationPolicies) {
        if (attributeIndices == null || attributeIndices.isEmpty()) {
            throw new IllegalArgumentException("No attribute indices are provided!");
        }
        this.attributeIndices = attributeIndices;
        this.discretizationPolicies = discretizationPolicies;
        this.logger.info("Created assigner. Attributes to be discretized: {}", discretizationPolicies == null ? "none" : discretizationPolicies.keySet());
    }

    /**
     * Returns the number of strati, initializing this object with the given dataset on first use.
     *
     * @param dataset
     *            the dataset to stratify
     * @return the number of attribute value combinations, i.e. the number of strati
     * @throws IllegalArgumentException
     *             if a different dataset was used before
     */
    @Override
    public int selectStratiAmount(final IDataset<?> dataset) {
        this.logger.debug("Selecting number of strati for dataset with {} items.", dataset.size());
        if (this.dataset == null) {
            this.init(dataset, -1);
        } else if (!this.dataset.equals(dataset)) {
            throw new IllegalArgumentException("Can only select strati amount for a dataset provided before.");
        }
        return this.stratumIDs.size();
    }

    /**
     * Discretizes the collected attribute values in place according to the discretization policies, computing default
     * policies first if none were provided.
     *
     * @param attributeValues
     *            attribute index to the set of observed values; modified by the discretization helper
     */
    private void discretizeAttributeValues(final Map<Integer, Set<Object>> attributeValues) {
        if (this.discretizationPolicies == null) {
            this.logger.info("No discretization policies provided. Computing defaults.");
            this.discretizationPolicies = this.discretizationHelper.createDefaultDiscretizationPolicies(this.dataset, this.attributeIndices, attributeValues, this.discretizationStrategy, this.numberOfCategories);
        }
        if (!this.discretizationPolicies.isEmpty()) {
            if (this.logger.isInfoEnabled()) {
                this.logger.info("Discretizing numeric attributes using policies: {}", this.discretizationPolicies);
            }
            this.discretizationHelper.discretizeAttributeValues(this.discretizationPolicies, attributeValues);
        }
        this.logger.debug("discretizeAttributeValues(): leave");
    }

    /**
     * Sets the number of CPU cores used when collecting attribute values.
     *
     * @param numCPUs
     *            number of cores; must be at least 1
     * @throws IllegalArgumentException
     *             if {@code numCPUs < 1}
     */
    public void setNumCPUs(final int numCPUs) {
        if (numCPUs < 1) {
            throw new IllegalArgumentException("Number of CPU cores must be positive");
        }
        this.numCPUs = numCPUs;
    }

    public int getNumCPUs() {
        return this.numCPUs;
    }

    /**
     * Initializes this assigner with the given dataset; equivalent to {@code init(dataset, -1)}.
     *
     * @param dataset
     *            the dataset to stratify
     */
    public void init(final IDataset<?> dataset) {
        this.init(dataset, -1);
    }

    /**
     * Collects the (possibly discretized) values of the configured attributes and assigns a stratum id to every
     * combination of them. Subsequent calls after a successful initialization are ignored with a warning.
     *
     * @param dataset
     *            the dataset to stratify
     * @param stratiAmount
     *            ignored; the number of strati is determined by the attribute value combinations
     * @throws IllegalArgumentException
     *             if the dataset is null or an attribute index is out of range for it
     * @throws IllegalStateException
     *             if this instance was created without attribute indices
     */
    @Override
    public void init(final IDataset<?> dataset, final int stratiAmount) {
        this.logger.debug("init(): enter");
        if (this.initialized) {
            this.logger.warn("Ignoring further initialization.");
            return;
        }
        if (dataset == null) {
            throw new IllegalArgumentException("Cannot set dataset to NULL");
        }
        if (this.attributeIndices == null) {
            throw new IllegalStateException("No attribute indices are provided!");
        }
        // Validate before touching any state so that a rejected dataset leaves this object untouched.
        validateAttributeIndices(this.attributeIndices, dataset);
        this.dataset = dataset;
        this.numAttributes = dataset.getNumAttributes();

        Map<Integer, Set<Object>> attributeValues = DatasetPropertyComputer.computeAttributeValues(dataset, this.attributeIndices, this.numCPUs);
        this.discretizeAttributeValues(attributeValues);
        this.stratumIDs = this.enumerateStrati(attributeValues);
        this.logger.info("Initialized strati assigner with {} strati.", this.stratumIDs.size());
        this.initialized = true;
    }

    /** Checks that every attribute index is a valid attribute index or, for labeled datasets, the label index. */
    private static void validateAttributeIndices(final List<Integer> attributeIndices, final IDataset<?> dataset) {
        int numAttributes = dataset.getNumAttributes();
        boolean labeled = dataset instanceof ILabeledDataset;
        for (int attributeIndex : attributeIndices) {
            if (attributeIndex < 0) {
                throw new IllegalArgumentException("Attribute index for stratified splits must not be negative!");
            }
            if (attributeIndex > numAttributes) {
                throw new IllegalArgumentException("Attribute index for stratified splits must not exceed number of attributes!");
            }
            if (attributeIndex == numAttributes && !labeled) {
                throw new IllegalArgumentException("Attribute index for stratified splits must only equal the number of attributes if the dataset is labeled, because then the label column id is the number of attributes!");
            }
        }
    }

    /**
     * Assigns consecutive stratum ids (starting at 0) to every element of the cartesian product of the attribute value
     * sets. The product dimensions follow the order of {@link #attributeIndices} so that the keys match those built in
     * {@link #assignToStrati(IInstance)}, independent of the iteration order of {@code attributeValues}.
     */
    private Map<List<Object>, Integer> enumerateStrati(final Map<Integer, Set<Object>> attributeValues) {
        List<Set<Object>> valueDomains = new ArrayList<>(this.attributeIndices.size());
        for (int attributeIndex : this.attributeIndices) {
            Set<Object> domain = attributeValues.get(attributeIndex);
            if (domain == null) {
                throw new IllegalStateException("No values have been computed for attribute " + attributeIndex + ".");
            }
            valueDomains.add(domain);
        }
        Set<List<Object>> cartesianProduct = Sets.cartesianProduct(valueDomains);
        this.logger.info("There are {} elements in the cartesian product of the attribute values", cartesianProduct.size());
        this.logger.info("Assigning stratum numbers to elements in the cartesian product..");
        Map<List<Object>, Integer> ids = new HashMap<>();
        int nextStratumID = 0;
        for (List<Object> combination : cartesianProduct) {
            ids.put(combination, nextStratumID++);
        }
        return ids;
    }

    /**
     * Determines the stratum of the given instance from its (possibly discretized) values of the configured attributes.
     *
     * @param datapoint
     *            the instance to assign
     * @return the stratum id
     * @throws IllegalStateException
     *             if this assigner has not been initialized
     * @throws IllegalArgumentException
     *             if the instance has no label but the label is a stratification attribute, or if its value combination
     *             was not observed during initialization
     */
    @Override
    public int assignToStrati(final IInstance datapoint) {
        if (!this.initialized) {
            throw new IllegalStateException("Assigner has not been initialized yet.");
        }
        List<Object> key = new ArrayList<>(this.attributeIndices.size());
        for (int attributeIndex : this.attributeIndices) {
            key.add(this.getStratificationValue(datapoint, attributeIndex));
        }
        Integer stratumID = this.stratumIDs.get(key);
        if (stratumID == null) {
            throw new IllegalArgumentException("Cannot assign data point " + datapoint + " to any stratum, because its attribute value combination " + key + " has not been observed during initialization.");
        }
        this.logger.debug("Attribute values are: {}. Corresponding stratum is: {}", key, stratumID);
        return stratumID;
    }

    /** Returns the value of the given attribute as used in stratum keys, i.e. discretized if a policy exists for it. */
    private Object getStratificationValue(final IInstance datapoint, final int attributeIndex) {
        Object rawValue = this.getRawValue(datapoint, attributeIndex);
        if (this.toBeDiscretized(attributeIndex)) {
            // Number instead of Double so that integer-valued numeric attributes are accepted too.
            return this.discretizationHelper.discretize(((Number) rawValue).doubleValue(), this.discretizationPolicies.get(attributeIndex));
        }
        return rawValue;
    }

    /** Returns the non-null raw value of the given attribute, reading the label when the index is the label index. */
    private Object getRawValue(final IInstance datapoint, final int attributeIndex) {
        if (attributeIndex == this.numAttributes) {
            Object label = ((ILabeledInstance) datapoint).getLabel();
            if (label == null) {
                throw new IllegalArgumentException("Cannot assign data point " + datapoint + " to any stratum, because it has no label.");
            }
            return label;
        }
        return Objects.requireNonNull(datapoint.getAttributeValue(attributeIndex));
    }

    private boolean toBeDiscretized(final int attributeIndex) {
        return this.discretizationPolicies.containsKey(attributeIndex);
    }

    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Switches this object's logger to the given name; the discretization helper logs under {@code name + ".discretizer"}.
     *
     * @param name
     *            the new logger name
     */
    public void setLoggerName(final String name) {
        this.logger = LoggerFactory.getLogger(name);
        this.discretizationHelper.setLoggerName(name + ".discretizer");
    }
}
