package org.languagetool.rules.spelling.suggestions;

import java.io.FileNotFoundException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.SortedMap;
import java.util.stream.Collectors;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.pool2.BaseKeyedPooledObjectFactory;
import org.apache.commons.pool2.KeyedObjectPool;
import org.apache.commons.pool2.PooledObject;
import org.apache.commons.pool2.impl.DefaultPooledObject;
import org.apache.commons.pool2.impl.GenericKeyedObjectPool;
import org.jetbrains.annotations.NotNull;
import org.languagetool.AnalyzedSentence;
import org.languagetool.JLanguageTool;
import org.languagetool.Language;
import org.languagetool.languagemodel.LanguageModel;
import org.languagetool.rules.SuggestedReplacement;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/languagetool/rules/spelling/suggestions/XGBoostSuggestionsOrderer.class */
public class XGBoostSuggestionsOrderer extends SuggestionsOrdererFeatureExtractor implements SuggestionsRanker {
    private boolean modelAvailableForLanguage;
    private static final Logger logger = LoggerFactory.getLogger(XGBoostSuggestionsOrderer.class);
    private static final KeyedObjectPool<Language, Booster> modelPool = new GenericKeyedObjectPool(new BaseKeyedPooledObjectFactory<Language, Booster>() { // from class: org.languagetool.rules.spelling.suggestions.XGBoostSuggestionsOrderer.1
        public Booster create(Language language) throws Exception {
            String str = "";
            try {
                str = XGBoostSuggestionsOrderer.getModelPath(language);
                return XGBoost.loadModel(JLanguageTool.getDataBroker().getFromResourceDirAsStream(str));
            } catch (FileNotFoundException e) {
                XGBoostSuggestionsOrderer.logger.warn(String.format("Could not load suggestion ranking model at '%s'. Platform might be unsupported by the official XGBoost maven package, or model might be missing/corrupted.", str), e);
                return null;
            }
        }

        public PooledObject<Booster> wrap(Booster booster) {
            return new DefaultPooledObject(booster);
        }
    });
    private static final Map<String, Float> autoCorrectThreshold = new HashMap();
    private static final Map<String, List<Integer>> modelClasses = new HashMap();
    private static final Map<String, Integer> candidateFeatureCount = new HashMap();
    private static final Map<String, Integer> matchFeatureCount = new HashMap();
    private static boolean xgboostNotSupported = false;

    /* JADX INFO: Access modifiers changed from: private */
    @NotNull
    public static String getModelPath(Language language) {
        return "/" + language.getShortCode() + "/spelling_correction_model.bin";
    }

    public static void setAutoCorrectThresholdForLanguage(Language language, float f) {
        autoCorrectThreshold.replace(language.getShortCodeWithCountryAndVariant(), Float.valueOf(f));
    }

    public XGBoostSuggestionsOrderer(Language language, LanguageModel languageModel) {
        super(language, languageModel);
        this.modelAvailableForLanguage = false;
        String shortCodeWithCountryAndVariant = language.getShortCodeWithCountryAndVariant();
        if (xgboostNotSupported) {
            return;
        }
        if (System.getProperty("os.name").toLowerCase().startsWith("windows")) {
            xgboostNotSupported = true;
            System.err.println("Warning: At the moment, your platform (Windows) is not supported by the official XGBoost maven package; ML-based suggestion reordering is disabled.");
            return;
        }
        if (autoCorrectThreshold.containsKey(shortCodeWithCountryAndVariant) && modelClasses.containsKey(shortCodeWithCountryAndVariant) && JLanguageTool.getDataBroker().resourceExists(getModelPath(this.language))) {
            try {
                Booster booster = (Booster) modelPool.borrowObject(this.language);
                if (booster != null) {
                    modelPool.returnObject(this.language, booster);
                    this.modelAvailableForLanguage = true;
                }
            } catch (Exception e) {
                logger.warn("Could not load spelling suggestion ranking model for language " + this.language, e);
            } catch (ExceptionInInitializerError | NoClassDefFoundError | UnsatisfiedLinkError e2) {
                logger.warn("At the moment, your platform (Windows?) or architecture (32 bit?) is not supported by the official XGBoost maven package; ML-based suggestion reordering is disabled.", e2);
                xgboostNotSupported = true;
            }
        }
    }

    @Override // org.languagetool.rules.spelling.suggestions.SuggestionsOrdererFeatureExtractor
    protected void initParameters() {
        this.topN = 5;
        this.score = "noop";
        this.mistakeProb = 0.0d;
    }

    @Override // org.languagetool.rules.spelling.suggestions.SuggestionsOrdererFeatureExtractor, org.languagetool.rules.spelling.suggestions.SuggestionsOrderer
    public boolean isMlAvailable() {
        return super.isMlAvailable() && this.modelAvailableForLanguage;
    }

    @Override // org.languagetool.rules.spelling.suggestions.SuggestionsOrdererFeatureExtractor, org.languagetool.rules.spelling.suggestions.SuggestionsOrderer
    public List<SuggestedReplacement> orderSuggestions(List<String> list, String str, AnalyzedSentence analyzedSentence, int i) {
        if (!isMlAvailable()) {
            throw new IllegalStateException("Illegal call to orderSuggestions() - isMlAvailable() returned false.");
        }
        System.currentTimeMillis();
        String shortCodeWithCountryAndVariant = this.language.getShortCodeWithCountryAndVariant();
        Pair<List<SuggestedReplacement>, SortedMap<String, Float>> computeFeatures = computeFeatures(list, str, analyzedSentence, i);
        List<SuggestedReplacement> list2 = (List) computeFeatures.getLeft();
        SortedMap sortedMap = (SortedMap) computeFeatures.getRight();
        List<SortedMap> list3 = (List) list2.stream().map((v0) -> {
            return v0.getFeatures();
        }).collect(Collectors.toList());
        if (list2.isEmpty()) {
            return Collections.emptyList();
        }
        if (list2.size() != list3.size()) {
            throw new RuntimeException(String.format("Mismatch between candidates and corresponding feature list: length %d / %d", Integer.valueOf(list2.size()), Integer.valueOf(list3.size())));
        }
        int size = sortedMap.size() + (this.topN * ((SortedMap) list3.get(0)).size());
        float[] fArr = new float[size];
        int i2 = 0;
        int intValue = matchFeatureCount.getOrDefault(shortCodeWithCountryAndVariant, -1).intValue();
        int intValue2 = candidateFeatureCount.getOrDefault(shortCodeWithCountryAndVariant, -1).intValue();
        if (sortedMap.size() != intValue) {
            logger.warn(String.format("Match features '%s' do not have expected size %d.", sortedMap, Integer.valueOf(intValue)));
        }
        Iterator it = sortedMap.entrySet().iterator();
        while (it.hasNext()) {
            int i3 = i2;
            i2++;
            fArr[i3] = ((Float) ((Map.Entry) it.next()).getValue()).floatValue();
        }
        for (SortedMap sortedMap2 : list3) {
            if (sortedMap2.size() != intValue2) {
                logger.warn(String.format("Candidate features '%s' do not have expected size %d.", sortedMap2, Integer.valueOf(intValue2)));
            }
            Iterator it2 = sortedMap2.entrySet().iterator();
            while (it2.hasNext()) {
                int i4 = i2;
                i2++;
                fArr[i4] = ((Float) ((Map.Entry) it2.next()).getValue()).floatValue();
            }
        }
        List<Integer> list4 = modelClasses.get(shortCodeWithCountryAndVariant);
        try {
            try {
                System.currentTimeMillis();
                Booster booster = (Booster) modelPool.borrowObject(this.language);
                DMatrix dMatrix = new DMatrix(fArr, 1, size);
                System.currentTimeMillis();
                float[][] predict = booster.predict(dMatrix);
                if (predict.length != 1) {
                    throw new XGBoostError(String.format("XGBoost returned array with first dimension of length %d, expected 1.", Integer.valueOf(predict.length)));
                }
                float[] fArr2 = predict[0];
                if (fArr2.length != list4.size()) {
                    throw new XGBoostError(String.format("XGBoost returned array with second dimension of length %d, expected %d.", Integer.valueOf(fArr2.length), Integer.valueOf(list4.size())));
                }
                for (int i5 = 0; i5 < list2.size(); i5++) {
                    int indexOf = list4.indexOf(Integer.valueOf(i5));
                    float f = 0.0f;
                    if (indexOf != -1) {
                        f = fArr2[indexOf];
                    }
                    list2.get(i5).setConfidence(Float.valueOf(f));
                }
                if (booster != null) {
                    try {
                        modelPool.returnObject(this.language, booster);
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
                list2.sort(Collections.reverseOrder(Comparator.comparing((v0) -> {
                    return v0.getConfidence();
                })));
                return list2;
            } catch (Exception e2) {
                logger.error("Error while loading XGBoost model for spelling suggestions", e2);
                if (0 != 0) {
                    try {
                        modelPool.returnObject(this.language, (Object) null);
                    } catch (Exception e3) {
                        throw new RuntimeException(e3);
                    }
                }
                return list2;
            } catch (XGBoostError e4) {
                logger.error("Error while applying XGBoost model to spelling suggestions", e4);
                if (0 != 0) {
                    try {
                        modelPool.returnObject(this.language, (Object) null);
                    } catch (Exception e5) {
                        throw new RuntimeException(e5);
                    }
                }
                return list2;
            }
        } catch (Throwable th) {
            if (0 != 0) {
                try {
                    modelPool.returnObject(this.language, (Object) null);
                } catch (Exception e6) {
                    throw new RuntimeException(e6);
                }
            }
            throw th;
        }
    }

    @Override // org.languagetool.rules.spelling.suggestions.SuggestionsRanker
    public boolean shouldAutoCorrect(List<SuggestedReplacement> list) {
        if (list.isEmpty() || list.stream().anyMatch(suggestedReplacement -> {
            return suggestedReplacement.getConfidence() == null;
        })) {
            return false;
        }
        return list.get(0).getConfidence().floatValue() >= autoCorrectThreshold.getOrDefault(this.language.getShortCodeWithCountryAndVariant(), Float.valueOf(Float.MAX_VALUE)).floatValue();
    }

    static {
        List<Integer> asList = Arrays.asList(-1, 0, 1, 2, 3, 4);
        autoCorrectThreshold.put("en-US", Float.valueOf(0.99897194f));
        modelClasses.put("en-US", asList);
        candidateFeatureCount.put("en-US", 10);
        matchFeatureCount.put("en-US", 1);
    }
}
