package sklearn.pipeline;

import com.google.common.collect.Lists;
import java.util.List;
import net.razorvine.pickle.objects.ClassDict;
import org.dmg.pmml.Model;
import org.dmg.pmml.PMML;
import org.jpmml.python.CastFunction;
import org.jpmml.python.CastUtil;
import org.jpmml.python.ClassDictUtil;
import org.jpmml.python.TupleUtil;
import org.jpmml.sklearn.Encodable;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Composite;
import sklearn.Estimator;
import sklearn.PassThrough;
import sklearn.SkLearnFields;
import sklearn.SkLearnSteps;
import sklearn.Step;
import sklearn.StepUtil;
import sklearn.Transformer;

/**
 * Pickled scikit-learn pipeline.
 *
 * <p>The steps are stored under the {@code "steps"} key as tuples whose element at index 1 is the
 * step object. A step object of {@code null} or {@link SkLearnSteps#PASSTHROUGH} denotes a
 * pass-through step. All steps except an optional final estimator are treated as transformers.
 */
public class SkLearnPipeline extends Composite implements Encodable {

    public SkLearnPipeline() {
        this("sklearn.pipeline", "Pipeline");
    }

    public SkLearnPipeline(String module, String name) {
        super(module, name);
    }

    /**
     * Returns {@code false} when the pipeline has no steps, or when its only step is the final
     * estimator; {@code true} otherwise.
     */
    @Override
    public boolean hasTransformers() {
        List<Object[]> steps = getSteps();
        if (steps.isEmpty()) {
            return false;
        }
        // A single step that is the final estimator leaves no transformer steps.
        return !(steps.size() == 1 && hasFinalEstimator());
    }

    /**
     * Reports whether the last step is an estimator.
     *
     * <p>A null or pass-through last step is not an estimator. A nested {@link Composite} is asked
     * recursively. Known {@link Estimator}/{@link Transformer} instances are classified directly.
     * Raw {@link ClassDict} objects are classified by name/attribute heuristics. Anything else is
     * run through {@link CastUtil#deepCastTo} and checked for being an {@link Estimator}.
     */
    @Override
    public boolean hasFinalEstimator() {
        List<Object[]> steps = getSteps();
        if (steps.isEmpty()) {
            return false;
        }

        Object finalStep = TupleUtil.extractElement(steps.get(steps.size() - 1), 1);
        if (isPassThrough(finalStep)) {
            return false;
        }

        if (finalStep instanceof Composite) {
            return ((Composite) finalStep).hasFinalEstimator();
        }
        if (finalStep instanceof Estimator) {
            return true;
        }
        if (finalStep instanceof Transformer) {
            return false;
        }
        if (finalStep instanceof ClassDict) {
            ClassDict classDict = (ClassDict) finalStep;
            if (isEstimatorLike(classDict)) {
                return true;
            }
            if (isTransformerLike(classDict)) {
                return false;
            }
        }

        // NOTE(review): whether deepCastTo throws or returns a non-Estimator on failure is not
        // visible here; the result is only tested with isInstance.
        return Estimator.class.isInstance(CastUtil.deepCastTo(finalStep, Estimator.class));
    }

    /**
     * Returns every step except the final estimator (if there is one) as a {@link Transformer}.
     *
     * <p>Null and pass-through steps become {@link PassThrough#INSTANCE}. The result is a Guava
     * {@link Lists#transform} view, so conversion (and any cast failure) happens when an element
     * is accessed.
     */
    @Override
    public List<? extends Transformer> getTransformers() {
        List<Object[]> steps = getSteps();
        if (hasFinalEstimator()) {
            steps = steps.subList(0, steps.size() - 1);
        }

        CastFunction<Transformer> castFunction = new CastFunction<Transformer>(Transformer.class) {

            // Must be named "apply" so that it actually overrides CastFunction#apply;
            // otherwise pass-through steps would reach the plain cast.
            @Override
            public Transformer apply(Object obj) {
                if (isPassThrough(obj)) {
                    return PassThrough.INSTANCE;
                }
                return super.apply(obj);
            }

            @Override
            public String formatMessage(Object obj) {
                return "The object (" + ClassDictUtil.formatClass(obj) + ") is not a supported Transformer";
            }
        };

        return Lists.transform(TupleUtil.extractElementList(steps, 1), castFunction);
    }

    @Override
    public Estimator getFinalEstimator() {
        return getFinalEstimator(Estimator.class);
    }

    /**
     * Returns the last step cast to {@code cls}.
     *
     * @param cls the expected estimator type
     * @return the last step as {@code E}
     * @throws IllegalArgumentException if the pipeline has no steps or the last step is null or
     *     pass-through; presumably also raised by {@link CastFunction} when the step is not an
     *     instance of {@code cls} (not visible from here)
     */
    @Override
    public <E extends Estimator> E getFinalEstimator(Class<? extends E> cls) {
        List<Object[]> steps = getSteps();
        if (steps.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more steps, got zero steps");
        }

        Object finalStep = TupleUtil.extractElement(steps.get(steps.size() - 1), 1);
        if (isPassThrough(finalStep)) {
            throw new IllegalArgumentException("Expected an estimator as the final step, got " + finalStep);
        }

        CastFunction<E> castFunction = new CastFunction<E>(cls) {

            @Override
            public String formatMessage(Object obj) {
                return "The object (" + ClassDictUtil.formatClass(obj) + ") is not a supported Estimator";
            }
        };

        return castFunction.apply(finalStep);
    }

    /**
     * Returns the head as computed by {@link StepUtil#getHead} for the first step.
     *
     * <p>A null or pass-through first step is passed to {@link StepUtil#getHead} as {@code null}.
     *
     * @throws IllegalArgumentException if the pipeline has no steps
     */
    @Override
    public Step getHead() {
        List<Object[]> steps = getSteps();
        if (steps.isEmpty()) {
            throw new IllegalArgumentException("Expected one or more steps, got zero steps");
        }

        CastFunction<Step> castFunction = new CastFunction<Step>(Step.class) {

            // Must be named "apply" so that it actually overrides CastFunction#apply.
            @Override
            public Step apply(Object obj) {
                if (isPassThrough(obj)) {
                    return null;
                }
                return super.apply(obj);
            }

            @Override
            public String formatMessage(Object obj) {
                return "The object (" + ClassDictUtil.formatClass(obj) + ") is not a supported Transformer or Estimator";
            }
        };

        Step firstStep = castFunction.apply(TupleUtil.extractElement(steps.get(0), 1));
        return StepUtil.getHead(firstStep);
    }

    /**
     * Encodes this pipeline as PMML.
     *
     * <p>When there is a final estimator, the label is initialised before the features and the
     * estimator's {@link Model} is registered with the encoder. Otherwise the encoder is asked to
     * encode with a {@code null} model.
     */
    @Override
    public PMML encodePMML() {
        SkLearnEncoder encoder = new SkLearnEncoder();

        Estimator estimator = null;
        if (hasFinalEstimator()) {
            estimator = getFinalEstimator();
            initLabel(null, encoder);
        }
        initFeatures(null, encoder);

        if (estimator == null) {
            return encoder.encodePMML(null);
        }

        Model model = estimator.encode(encoder.createSchema());
        encoder.setModel(model);
        return encoder.encodePMML(model);
    }

    /** Returns the tuples stored under the {@code "steps"} key. */
    public List<Object[]> getSteps() {
        return getTupleList("steps");
    }

    /**
     * Replaces the value stored under the {@code "steps"} key.
     *
     * <p>Kept public for compatibility with existing callers; the decompiler reported that the
     * original bytecode declared it protected.
     *
     * @return this pipeline, for chaining
     */
    public SkLearnPipeline setSteps(List<Object[]> list) {
        put("steps", list);
        return this;
    }

    /** Returns {@code true} for a null or {@link SkLearnSteps#PASSTHROUGH} step object. */
    private static boolean isPassThrough(Object step) {
        return step == null || SkLearnSteps.PASSTHROUGH.equals(step);
    }

    /**
     * Heuristic for an un-wrapped pickled object.
     *
     * <p>Treats it as an estimator if its class name ends in "Estimator", "Classifier" or
     * "Regressor", or if it carries any of the {@code N_OUTPUTS}, {@code N_CLASSES} or
     * {@code CLASSES} fields.
     */
    private static boolean isEstimatorLike(ClassDict classDict) {
        String className = classDict.getClassName();
        return className.endsWith("Estimator")
                || className.endsWith("Classifier")
                || className.endsWith("Regressor")
                || classDict.containsKey(SkLearnFields.N_OUTPUTS)
                || classDict.containsKey(SkLearnFields.N_CLASSES)
                || classDict.containsKey(SkLearnFields.CLASSES);
    }

    /** Heuristic for an un-wrapped pickled object: its class name ends in "Transformer". */
    private static boolean isTransformerLike(ClassDict classDict) {
        return classDict.getClassName().endsWith("Transformer");
    }
}
