package sklearn2pmml.ensemble;

import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.Predicate;
import org.dmg.pmml.mining.MiningModel;
import org.dmg.pmml.mining.Segment;
import org.dmg.pmml.mining.Segmentation;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.MultiLabel;
import org.jpmml.converter.ScalarLabel;
import org.jpmml.converter.Schema;
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.python.DataFrameScope;
import org.jpmml.python.PredicateTranslator;
import org.jpmml.python.TupleUtil;
import sklearn.Estimator;
import sklearn.EstimatorUtil;
import sklearn.HasClasses;
import sklearn.HasEstimatorEnsemble;

/* loaded from: input_file:sklearn2pmml/ensemble/EstimatorChain.class */
/**
 * Converter for an estimator chain.
 *
 * <p>The chain is an ordered list of {@code (name, estimator, predicate)} step tuples read from
 * the {@code "steps"} attribute. It is encoded as a PMML {@link MiningModel} whose segmentation
 * uses {@code MULTI_MODEL_CHAIN} when the {@code "multioutput"} attribute is true and
 * {@code MODEL_CHAIN} otherwise.
 */
public class EstimatorChain extends Estimator implements HasClasses, HasEstimatorEnsemble<Estimator> {

    public EstimatorChain(String str, String str2) {
        super(str, str2);
    }

    /**
     * Returns the mining function shared by the step estimators.
     *
     * <p>The decision is delegated to {@code EstimatorUtil.getMiningFunction} over all step
     * estimators.
     */
    @Override // sklearn.Estimator
    public MiningFunction getMiningFunction() {
        return EstimatorUtil.getMiningFunction(getEstimators());
    }

    /**
     * Returns the number of outputs.
     *
     * @return one output per step when {@code multioutput} is true, otherwise {@code 1}
     */
    @Override // sklearn.Estimator, sklearn.HasNumberOfOutputs
    public int getNumberOfOutputs() {
        if (getMultioutput().booleanValue()) {
            return getEstimators().size();
        }
        return 1;
    }

    @Override // sklearn.Estimator
    public boolean isSupervised() {
        return true;
    }

    /**
     * Returns the class labels of the chain.
     *
     * <p>With a single step, that step's classes are returned. With several steps, a single class
     * list is returned if every step reports an equal list. Otherwise the per-step class lists are
     * returned, in step order.
     *
     * @throws IllegalArgumentException if there are no steps
     */
    @Override // sklearn.HasClasses
    public List<?> getClasses() {
        List<? extends Estimator> estimators = getEstimators();
        if (estimators.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more estimators");
        }
        if (estimators.size() == 1) {
            return EstimatorUtil.getClasses(estimators.get(0));
        }

        List<List<?>> stepClasses = new ArrayList<>(estimators.size());
        for (Estimator estimator : estimators) {
            stepClasses.add(EstimatorUtil.getClasses(estimator));
        }

        // distinct() relies on List.equals, so identical class lists collapse into one entry.
        List<List<?>> distinctClasses = stepClasses.stream().distinct().collect(Collectors.toList());
        return distinctClasses.size() == 1 ? distinctClasses.get(0) : stepClasses;
    }

    /**
     * Encodes the chain as a segmented {@link MiningModel}, with one segment per step.
     *
     * <p>Step {@code i} is encoded against the current schema, relabelled with the {@code i}-th
     * label of the (possibly synthesised) multi-label. A step whose estimator is a {@code Link}
     * replaces the working schema with the result of {@code augmentSchema}, so later steps are
     * encoded against the augmented schema. Segment predicates are always translated against the
     * input features captured before the loop, exposed under the data-frame name {@code "X"}.
     *
     * <p>The method name is a decompiler rename of {@code encodeModel}.
     *
     * @param schema the input schema; its label must be a {@link ScalarLabel} or a {@link MultiLabel}
     * @return the mining model wrapping every step's model
     * @throws IllegalArgumentException if there are no steps, or if the label is of another kind
     */
    @Override // sklearn.Estimator
    /* renamed from: encodeModel */
    public Model mo16encodeModel(Schema schema) {
        final Boolean multioutput = getMultioutput();
        final List<Object[]> steps = getSteps();
        if (steps.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more steps");
        }

        // getLabel() must be held as the general Label type. Holding it as a ScalarLabel does not
        // compile, and it would make the MultiLabel branch impossible.
        MultiLabel multiLabel = toMultiLabel(schema.getLabel(), steps.size());
        List<? extends Feature> features = schema.getFeatures();

        List<Estimator> estimators = new ArrayList<>(steps.size());
        List<Model> models = new ArrayList<>(steps.size());

        Segmentation.MultipleModelMethod multipleModelMethod = multioutput.booleanValue()
            ? Segmentation.MultipleModelMethod.MULTI_MODEL_CHAIN
            : Segmentation.MultipleModelMethod.MODEL_CHAIN;
        Segmentation segmentation = new Segmentation(multipleModelMethod, (List<Segment>) null);

        PredicateTranslator predicateTranslator = new PredicateTranslator(new DataFrameScope("X", features));

        for (int i = 0; i < steps.size(); i++) {
            Object[] step = steps.get(i);

            String stepName = (String) TupleUtil.extractElement(step, 0, String.class);
            Estimator estimator = (Estimator) TupleUtil.extractElement(step, 1, Estimator.class);
            String predicateExpr = (String) TupleUtil.extractElement(step, 2, String.class);

            estimators.add(estimator);

            Schema stepSchema = schema.toRelabeledSchema(multiLabel.getLabel(i));
            Predicate predicate = predicateTranslator.translatePredicate(predicateExpr);
            Model model = estimator.encode(stepSchema);
            models.add(model);

            // A Link step replaces the working schema, so later steps see the augmented version.
            if (estimator instanceof Link) {
                schema = ((Link) estimator).augmentSchema(model, stepSchema);
            }

            segmentation.addSegments(new Segment(predicate, model).setId(stepName));
        }

        return new MiningModel(EstimatorUtil.getMiningFunction(estimators), MiningModelUtil.createMiningSchema(models))
            .setSegmentation(segmentation);
    }

    /**
     * Normalises a label into a {@link MultiLabel} with one entry per chain step.
     *
     * <p>A {@link ScalarLabel} is wrapped in a read-only view that returns that same label for
     * every index. A {@link MultiLabel} is returned as-is.
     *
     * @param label the schema label
     * @param count the number of chain steps, used as the size of a synthesised view
     * @throws IllegalArgumentException if {@code label} is neither a scalar nor a multi-label
     */
    private static MultiLabel toMultiLabel(Label label, final int count) {
        if (label instanceof ScalarLabel) {
            final ScalarLabel scalarLabel = (ScalarLabel) label;
            List<ScalarLabel> scalarLabels = new AbstractList<ScalarLabel>() {
                @Override
                public int size() {
                    return count;
                }

                @Override
                public ScalarLabel get(int index) {
                    return scalarLabel;
                }
            };
            return new MultiLabel(scalarLabels);
        }
        if (label instanceof MultiLabel) {
            return (MultiLabel) label;
        }
        throw new IllegalArgumentException("Expected a scalar label or a multi-label, got " + label);
    }

    /**
     * Returns the estimator (tuple element 1) of every step, in step order.
     *
     * @throws IllegalArgumentException if there are no steps
     */
    @Override // sklearn.HasEstimatorEnsemble
    public List<? extends Estimator> getEstimators() {
        List<Object[]> steps = getSteps();
        if (steps.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more steps");
        }
        return TupleUtil.extractElementList(steps, 1, Estimator.class);
    }

    /** Returns the {@code "multioutput"} attribute. */
    public Boolean getMultioutput() {
        return getBoolean("multioutput");
    }

    /** Returns the {@code "steps"} attribute as a list of tuples. */
    public List<Object[]> getSteps() {
        return getTupleList("steps");
    }
}
