package io.datarouter.batchsizeoptimizer;

import io.datarouter.batchsizeoptimizer.math.PolynomialRegressionOptimumFinder;
import io.datarouter.batchsizeoptimizer.math.PolynomialRegressionOptimumFinderPoint;
import io.datarouter.batchsizeoptimizer.storage.optimizedbatch.DatarouterOpOptimizedBatchSizeDao;
import io.datarouter.batchsizeoptimizer.storage.optimizedbatch.OpOptimizedBatchSize;
import io.datarouter.batchsizeoptimizer.storage.optimizedbatch.OpOptimizedBatchSizeKey;
import io.datarouter.batchsizeoptimizer.storage.performancerecord.DatarouterOpPerformanceRecordDao;
import io.datarouter.batchsizeoptimizer.storage.performancerecord.OpPerformanceRecord;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import javax.inject.Inject;
import javax.inject.Singleton;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Recommends batch sizes for named operations and periodically re-tunes them from recorded performance.
 *
 * <p>Callers ask {@link #getRecommendedBatchSize(String, int)} for a size and report what happened through
 * {@link #recordBatchSizeAndTime(String, int, long, long)}. {@link #computeAndSaveOptimalBatchSizeForAllOps()}
 * aggregates the stored records per op, feeds (batch size, mean rows-per-second) points to a
 * {@link PolynomialRegressionOptimumFinder}, and persists a new optimal batch size together with a "curiosity"
 * factor: the fraction of the optimum by which recommendations are randomly spread so nearby sizes keep being
 * explored.
 */
@Singleton
public class BatchSizeOptimizer {

    private static final Logger logger = LoggerFactory.getLogger(BatchSizeOptimizer.class);

    public static final int DEFAULT_BATCH_SIZE = 1000;
    public static final double DEFAULT_CURIOSITY = 0.1d;

    private static final int MAX_BATCH_SIZE = 10000;
    private static final int MIN_BATCH_SIZE = 1;
    /** Relative step used to move the optimum away from a fitted minimum (10%). */
    private static final double STEP = 0.1d;
    /** Ops whose newest performance record is older than this are not re-tuned. */
    private static final long STALE_AFTER_SECONDS = 30L;
    /** Minimum number of distinct batch sizes needed before a regression is attempted. */
    private static final int MIN_DISTINCT_BATCH_SIZES = 3;
    /** Curiosity added each time the recomputed optimum equals the currently stored one. */
    private static final double CURIOSITY_INCREMENT = 0.01d;

    private final DatarouterOpPerformanceRecordDao opPerformanceRecordDao;
    private final DatarouterOpOptimizedBatchSizeDao opOptimizedBatchSizeDao;
    private final Map<OpOptimizedBatchSizeKey, CachedOpOptimizedBatchSize> cachedOpOptimizedBatchSize =
            new ConcurrentHashMap<>();
    private final Random random = new Random();

    /**
     * Running throughput average for a single batch size, weighted by the number of full batches each record
     * represents.
     */
    public static class NodePerformanceStats {

        private long count = 0;
        private long speedSum = 0;

        /**
         * Adds one performance record. Its weight is {@code rowCount / batchSize} (integer division), so a record
         * covering fewer rows than one full batch contributes nothing. Records with a non-positive batch size are
         * ignored rather than throwing {@link ArithmeticException} and aborting the whole aggregation.
         *
         * @param opPerformanceRecord the record to include
         */
        public void addRecord(OpPerformanceRecord opPerformanceRecord) {
            int batchSize = opPerformanceRecord.getBatchSize().intValue();
            if (batchSize <= 0) {
                return;
            }
            long weight = opPerformanceRecord.getRowCount().longValue() / batchSize;
            this.count += weight;
            this.speedSum += weight * opPerformanceRecord.getRowsPerSeconds().longValue();
        }

        /**
         * @return the weighted mean rows-per-second, truncated to an int, or 0 when nothing has been weighted yet
         */
        public int getMean() {
            if (this.count == 0) {
                return 0;
            }
            return (int) (this.speedSum / this.count);
        }
    }

    @Inject
    public BatchSizeOptimizer(DatarouterOpPerformanceRecordDao opPerformanceRecordDao,
            DatarouterOpOptimizedBatchSizeDao opOptimizedBatchSizeDao) {
        this.opPerformanceRecordDao = opPerformanceRecordDao;
        this.opOptimizedBatchSizeDao = opOptimizedBatchSizeDao;
    }

    /**
     * Recommends a batch size for {@code opName} with no caller-imposed upper bound.
     *
     * @param opName the operation name
     * @return a recommended batch size
     */
    public int getRecommendedBatchSize(String opName) {
        return getRecommendedBatchSize(opName, Integer.MAX_VALUE);
    }

    /**
     * Recommends a batch size for {@code opName}, never exceeding {@code maxBatchSize}.
     *
     * <p>Let {@code delta = max(MIN_BATCH_SIZE, (int) (optimal * curiosity))}. If {@code maxBatchSize} is at or
     * below {@code optimal - delta} it is returned as-is. Otherwise a uniformly random size in
     * {@code [optimal - delta, optimal + delta]} is chosen, raised to at least {@code MIN_BATCH_SIZE} and then capped
     * at {@code maxBatchSize}.
     *
     * @param opName the operation name
     * @param maxBatchSize the largest batch size the caller can accept
     * @return a recommended batch size
     */
    public int getRecommendedBatchSize(String opName, int maxBatchSize) {
        // Read optimum and curiosity from one snapshot so a concurrent cache refresh cannot mix two versions
        OpOptimizedBatchSize optimized = getOpOptimizedBatchSizeFromCache(opName);
        int optimalBatchSize = optimized.getBatchSize().intValue();
        double curiosity = optimized.getCuriosity().doubleValue();
        int maxDelta = Math.max(MIN_BATCH_SIZE, (int) (optimalBatchSize * curiosity));
        if (maxBatchSize <= optimalBatchSize - maxDelta) {
            return maxBatchSize;
        }
        int delta = this.random.nextInt(2 * maxDelta + 1) - maxDelta;
        // Cap at maxBatchSize: the random window can extend above the caller's limit
        return Math.min(maxBatchSize, Math.max(MIN_BATCH_SIZE, optimalBatchSize + delta));
    }

    /**
     * Stores one performance sample for {@code opName}.
     *
     * @param opName the operation name
     * @param batchSize the batch size that was used
     * @param rowCount presumably the number of rows processed — confirm against the {@code OpPerformanceRecord}
     *        constructor
     * @param timeMs presumably the elapsed time in milliseconds — confirm against the
     *        {@code OpPerformanceRecord} constructor
     */
    public void recordBatchSizeAndTime(String opName, int batchSize, long rowCount, long timeMs) {
        this.opPerformanceRecordDao.put(new OpPerformanceRecord(opName, Integer.valueOf(batchSize),
                Long.valueOf(rowCount), Long.valueOf(timeMs)));
    }

    /**
     * Scans all stored performance records, recomputes the optimal batch size and curiosity for every op that
     * qualifies, logs each result and saves them all in one {@code putMulti} call.
     *
     * <p>Records are grouped by consecutive op name, so this relies on the scan returning each op's records
     * contiguously (presumably the key sorts by op name first — confirm).
     */
    public void computeAndSaveOptimalBatchSizeForAllOps() {
        List<OpOptimizedBatchSize> toSave = new ArrayList<>();
        SortedMap<Integer, NodePerformanceStats> statsByBatchSize = new TreeMap<>();
        String currentOpName = null;
        OpPerformanceRecord lastRecordOfCurrentOp = null;
        for (OpPerformanceRecord perfRecord : this.opPerformanceRecordDao.scan().iterable()) {
            String opName = perfRecord.getKey().getOpName();
            if (!opName.equals(currentOpName)) {
                if (currentOpName != null) {
                    // Finalize the previous op using ITS last record (not the first record of the new op)
                    computeOptimalBatchSizeAndCuriosityForOp(currentOpName, statsByBatchSize, lastRecordOfCurrentOp)
                            .ifPresent(toSave::add);
                    statsByBatchSize = new TreeMap<>();
                }
                currentOpName = opName;
            }
            statsByBatchSize.computeIfAbsent(perfRecord.getBatchSize(), batchSize -> new NodePerformanceStats())
                    .addRecord(perfRecord);
            lastRecordOfCurrentOp = perfRecord;
        }
        if (currentOpName != null) {
            computeOptimalBatchSizeAndCuriosityForOp(currentOpName, statsByBatchSize, lastRecordOfCurrentOp)
                    .ifPresent(toSave::add);
        }
        for (OpOptimizedBatchSize optimized : toSave) {
            logger.info("saving opName={} batchSize={} curiosity={}", optimized.getKey().getOpName(),
                    optimized.getBatchSize(), optimized.getCuriosity());
        }
        this.opOptimizedBatchSizeDao.putMulti(toSave);
    }

    private double getCuriosity(String opName) {
        return getOpOptimizedBatchSizeFromCache(opName).getCuriosity().doubleValue();
    }

    private int getOptimalBatchSize(String opName) {
        return getOpOptimizedBatchSizeFromCache(opName).getBatchSize().intValue();
    }

    private OpOptimizedBatchSize getOpOptimizedBatchSizeFromCache(String opName) {
        // Cast retained: the declared return type of CachedOpOptimizedBatchSize.get() is not visible here
        return (OpOptimizedBatchSize) this.cachedOpOptimizedBatchSize.computeIfAbsent(
                new OpOptimizedBatchSizeKey(opName),
                key -> new CachedOpOptimizedBatchSize(this.opOptimizedBatchSizeDao, key)).get();
    }

    /**
     * Computes the new optimum and curiosity for one op.
     *
     * @param opName the op being finalized
     * @param statsByBatchSize aggregated throughput per batch size for this op
     * @param lastRecord the last record scanned for this op, used for the staleness check
     * @return empty when the op's last record is older than {@link #STALE_AFTER_SECONDS} or fewer than
     *         {@link #MIN_DISTINCT_BATCH_SIZES} batch sizes were observed; otherwise the new settings. Curiosity
     *         grows by {@link #CURIOSITY_INCREMENT} when the optimum is unchanged and resets to
     *         {@link #DEFAULT_CURIOSITY} when it moved.
     */
    private Optional<OpOptimizedBatchSize> computeOptimalBatchSizeAndCuriosityForOp(String opName,
            SortedMap<Integer, NodePerformanceStats> statsByBatchSize, OpPerformanceRecord lastRecord) {
        Instant lastRecordTime = Instant.ofEpochMilli(lastRecord.getKey().getTimestamp().longValue());
        boolean stale = lastRecordTime.plusSeconds(STALE_AFTER_SECONDS).isBefore(Instant.now());
        if (stale || statsByBatchSize.size() < MIN_DISTINCT_BATCH_SIZES) {
            return Optional.empty();
        }
        int newOptimum = computeOptimalBatchSizeForOp(opName, statsByBatchSize);
        double curiosity = DEFAULT_CURIOSITY;
        if (newOptimum == getOptimalBatchSize(opName)) {
            curiosity = getCuriosity(opName) + CURIOSITY_INCREMENT;
        }
        return Optional.of(new OpOptimizedBatchSize(opName, Integer.valueOf(newOptimum), Double.valueOf(curiosity)));
    }

    /**
     * Fits the (batch size, mean throughput) points. If the fitted optimum is a maximum, it is used directly.
     * Otherwise the current optimum is moved by {@link #STEP} away from the fitted minimum. Every result is
     * clamped to [{@link #MIN_BATCH_SIZE}, {@link #MAX_BATCH_SIZE}].
     */
    private int computeOptimalBatchSizeForOp(String opName, SortedMap<Integer, NodePerformanceStats> statsByBatchSize) {
        List<PolynomialRegressionOptimumFinderPoint> points = statsByBatchSize.entrySet().stream()
                .map(entry -> new PolynomialRegressionOptimumFinderPoint(entry.getKey(),
                        Integer.valueOf(entry.getValue().getMean())))
                .collect(Collectors.toList());
        PolynomialRegressionOptimumFinder finder = new PolynomialRegressionOptimumFinder(points);
        int fittedOptimum = (int) Math.ceil(finder.getOptimumAbscissa());
        if (finder.optimumIsMaximum()) {
            return limitBatchIfNeeded(fittedOptimum);
        }
        int current = getOptimalBatchSize(opName);
        int next;
        if (fittedOptimum < current) {
            // Grow by at least one: for current < 10, (int) (current * 1.1) truncates back to current and would stall
            next = Math.max(current + 1, (int) (current * (1 + STEP)));
        } else {
            next = (int) (current * (1 - STEP));
        }
        return limitBatchIfNeeded(next);
    }

    /** Clamps {@code batchSize} to [{@link #MIN_BATCH_SIZE}, {@link #MAX_BATCH_SIZE}]. */
    private static int limitBatchIfNeeded(int batchSize) {
        return Math.min(MAX_BATCH_SIZE, Math.max(batchSize, MIN_BATCH_SIZE));
    }
}
