package com.o19s.es.ltr.feature.store;

import com.o19s.es.ltr.LtrQueryContext;
import com.o19s.es.ltr.feature.Feature;
import com.o19s.es.ltr.feature.FeatureSet;
import com.o19s.es.ltr.query.FeatureVectorWeight;
import com.o19s.es.ltr.ranker.LtrRanker;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.elasticsearch.common.lucene.search.function.LeafScoreFunction;
import org.elasticsearch.common.lucene.search.function.ScriptScoreFunction;
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.script.ScoreScript;
import org.elasticsearch.script.Script;

/* loaded from: input_file:com/o19s/es/ltr/feature/store/ScriptFeature.class */
public class ScriptFeature implements Feature {
    public static final String TEMPLATE_LANGUAGE = "script_feature";
    public static final String FEATURE_VECTOR = "feature_vector";
    public static final String EXTRA_SCRIPT_PARAMS = "extra_script_params";
    private final String name;
    private final Script script;
    private final Collection<String> queryParams;
    private final Map<String, Object> baseScriptParams;
    private final Map<String, String> extraScriptParams;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:com/o19s/es/ltr/feature/store/ScriptFeature$LtrScript.class */
    public static class LtrScript extends Query {
        private final ScriptScoreFunction scoreFunction;
        private final FeatureSupplier featureSupplier;

        /* loaded from: input_file:com/o19s/es/ltr/feature/store/ScriptFeature$LtrScript$LtrScriptWeight.class */
        class LtrScriptWeight extends FeatureVectorWeight {
            protected LtrScriptWeight(Query query) {
                super(query);
            }

            @Override // com.o19s.es.ltr.query.FeatureVectorWeight
            public Explanation explain(LeafReaderContext leafReaderContext, LtrRanker.FeatureVector featureVector, int i) throws IOException {
                LtrScript.this.featureSupplier.set(() -> {
                    return featureVector;
                });
                Scorer scorer = scorer(leafReaderContext, () -> {
                    return featureVector;
                });
                int advance = scorer.iterator().advance(i);
                return advance == i ? Explanation.match(scorer.score(), "weight(" + getQuery() + " in doc " + advance + ")", new Explanation[0]) : Explanation.noMatch("no matching term", new Explanation[0]);
            }

            @Override // com.o19s.es.ltr.query.FeatureVectorWeight
            public Scorer scorer(LeafReaderContext leafReaderContext, Supplier<LtrRanker.FeatureVector> supplier) throws IOException {
                LtrScript.this.featureSupplier.set(supplier);
                final LeafScoreFunction leafScoreFunction = LtrScript.this.scoreFunction.getLeafScoreFunction(leafReaderContext);
                final DocIdSetIterator all = DocIdSetIterator.all(leafReaderContext.reader().maxDoc());
                return new Scorer(this) { // from class: com.o19s.es.ltr.feature.store.ScriptFeature.LtrScript.LtrScriptWeight.1
                    public int docID() {
                        return all.docID();
                    }

                    public float score() throws IOException {
                        return (float) leafScoreFunction.score(all.docID(), 0.0f);
                    }

                    public DocIdSetIterator iterator() {
                        return all;
                    }
                };
            }

            public void extractTerms(Set<Term> set) {
            }

            public boolean isCacheable(LeafReaderContext leafReaderContext) {
                return false;
            }
        }

        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj == null || getClass() != obj.getClass()) {
                return false;
            }
            LtrScript ltrScript = (LtrScript) obj;
            return Objects.equals(this.scoreFunction, ltrScript.scoreFunction) && Objects.equals(this.featureSupplier, ltrScript.featureSupplier);
        }

        public int hashCode() {
            return Objects.hash(this.scoreFunction, this.featureSupplier);
        }

        LtrScript(ScriptScoreFunction scriptScoreFunction, FeatureSupplier featureSupplier) {
            this.scoreFunction = scriptScoreFunction;
            this.featureSupplier = featureSupplier;
        }

        public String toString(String str) {
            return "LtrScript:" + str;
        }

        public Weight createWeight(IndexSearcher indexSearcher, boolean z, float f) throws IOException {
            return new LtrScriptWeight(this);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v25, types: [java.util.Map] */
    public ScriptFeature(String str, Script script, Collection<String> collection) {
        this.name = (String) Objects.requireNonNull(str);
        this.script = (Script) Objects.requireNonNull(script);
        this.queryParams = collection;
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        for (Map.Entry entry : script.getParams().entrySet()) {
            if (((String) entry.getKey()).equals(EXTRA_SCRIPT_PARAMS)) {
                hashMap2 = (Map) entry.getValue();
            } else {
                hashMap.put(String.valueOf(entry.getKey()), entry.getValue());
            }
        }
        this.baseScriptParams = hashMap;
        this.extraScriptParams = hashMap2;
    }

    public static ScriptFeature compile(StoredFeature storedFeature) {
        try {
            return new ScriptFeature(storedFeature.name(), Script.parse(XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, storedFeature.template()), "native"), storedFeature.queryParams());
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override // com.o19s.es.ltr.feature.Feature
    public String name() {
        return this.name;
    }

    @Override // com.o19s.es.ltr.feature.Feature
    public Query doToQuery(LtrQueryContext ltrQueryContext, FeatureSet featureSet, Map<String, Object> map) {
        List list = (List) this.queryParams.stream().filter(str -> {
            return !map.containsKey(str);
        }).collect(Collectors.toList());
        if (!list.isEmpty()) {
            throw new IllegalArgumentException("Missing required param(s): [" + ((String) list.stream().collect(Collectors.joining(","))) + "]");
        }
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        for (String str2 : this.queryParams) {
            if (map.containsKey(str2)) {
                if (this.extraScriptParams.containsKey(str2)) {
                    hashMap2.put(this.extraScriptParams.get(str2), map.get(str2));
                } else {
                    hashMap.put(str2, map.get(str2));
                }
            }
        }
        HashMap hashMap3 = new HashMap();
        FeatureSupplier featureSupplier = new FeatureSupplier(featureSet);
        hashMap3.putAll(this.baseScriptParams);
        hashMap3.putAll(hashMap);
        hashMap3.putAll(hashMap2);
        hashMap3.put(FEATURE_VECTOR, featureSupplier);
        return new LtrScript(new ScriptScoreFunction(this.script, ((ScoreScript.Factory) ltrQueryContext.getQueryShardContext().getScriptService().compile(this.script, ScoreScript.CONTEXT)).newFactory(hashMap3, ltrQueryContext.getQueryShardContext().lookup())), featureSupplier);
    }
}
