package org.vanilladb.core.query.algebra.materialize;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;
import org.vanilladb.core.query.algebra.Plan;
import org.vanilladb.core.query.algebra.ReduceRecordsPlan;
import org.vanilladb.core.query.algebra.Scan;
import org.vanilladb.core.sql.Constant;
import org.vanilladb.core.sql.ConstantRange;
import org.vanilladb.core.sql.DoubleConstant;
import org.vanilladb.core.sql.Schema;
import org.vanilladb.core.sql.aggfn.AggregationFn;
import org.vanilladb.core.sql.aggfn.AvgFn;
import org.vanilladb.core.sql.aggfn.CountFn;
import org.vanilladb.core.sql.aggfn.DistinctCountFn;
import org.vanilladb.core.sql.aggfn.MaxFn;
import org.vanilladb.core.sql.aggfn.MinFn;
import org.vanilladb.core.sql.aggfn.SumFn;
import org.vanilladb.core.storage.metadata.statistics.Bucket;
import org.vanilladb.core.storage.metadata.statistics.Histogram;
import org.vanilladb.core.storage.tx.Transaction;

/* loaded from: input_file:org/vanilladb/core/query/algebra/materialize/GroupByPlan.class */
/**
 * Plan for a GROUP BY operator.
 *
 * <p>When grouping fields are given, the input is wrapped in a {@link SortPlan} on those fields;
 * otherwise the input is used as-is (the whole input forms a single group). Opening the plan yields
 * a {@link GroupByScan} that evaluates the aggregation functions. The plan also estimates an output
 * {@link Histogram} so that plans built on top of it can be costed.
 */
public class GroupByPlan extends ReduceRecordsPlan {
    /** The plan feeding the group-by scan: the input itself, or the input sorted on the grouping fields. */
    private final Plan sp;
    private final Set<String> groupFlds;
    /** The aggregation functions evaluated per group; may be {@code null} when there are none. */
    private final Set<AggregationFn> aggFns;
    /** Output schema: the grouping fields plus one field per aggregation function. */
    private final Schema schema = new Schema();
    private final Histogram hist;

    /**
     * Estimates the histogram of a group-by output.
     *
     * <p>The number of groups is estimated as the product, over the grouping fields, of each
     * field's total distinct-value count, capped by the number of input records. With no grouping
     * fields the empty product gives exactly one group. Buckets of the grouping fields are scaled by
     * {@code numGroups / inputRecords}; scaled buckets holding less than one record are dropped.
     * Each aggregation function contributes one bucket for its output field. The result is passed
     * through {@code syncHistogram} before being returned.
     *
     * @param histogram the histogram of the input records
     * @param groupFlds the grouping fields
     * @param aggFns the aggregation functions, or {@code null} if there are none
     * @return the estimated output histogram; an empty histogram over the input's fields if the
     *     input is estimated to hold fewer than one record
     * @throws UnsupportedOperationException if an aggregation function is not exactly one of
     *     {@code SumFn}, {@code AvgFn}, {@code CountFn}, {@code DistinctCountFn}, {@code MinFn},
     *     {@code MaxFn}
     */
    public static Histogram groupByHistogram(Histogram histogram, Set<String> groupFlds,
            Set<AggregationFn> aggFns) {
        double inputRecords = histogram.recordsOutput();
        if (Double.compare(inputRecords, 1.0d) < 0) {
            return new Histogram(histogram.fields());
        }

        // Upper bound on the number of groups: every combination of grouping values may occur.
        double distinctProduct = 1.0d;
        for (String fld : groupFlds) {
            distinctProduct *= totalDistinctValues(histogram.buckets(fld));
        }
        double numGroups = Math.min(distinctProduct, inputRecords);
        double reduction = numGroups / inputRecords;

        Histogram gbHist = new Histogram(groupFlds);
        for (String fld : groupFlds) {
            for (Bucket bucket : histogram.buckets(fld)) {
                double scaledFreq = bucket.frequency() * reduction;
                // Skip buckets estimated to contribute less than one output record.
                if (Double.compare(scaledFreq, 1.0d) >= 0) {
                    gbHist.addBucket(fld, new Bucket(bucket.valueRange(), scaledFreq,
                            bucket.distinctValues(), bucket.valuePercentiles()));
                }
            }
        }

        if (aggFns != null) {
            for (AggregationFn aggFn : aggFns) {
                gbHist.addBucket(aggFn.fieldName(), aggregateBucket(aggFn, histogram, numGroups));
            }
        }
        return syncHistogram(gbHist);
    }

    /**
     * Builds the single output bucket for one aggregation function.
     *
     * <p>Dispatch uses exact class equality (not {@code instanceof}), so subclasses of the
     * supported functions are rejected.
     */
    private static Bucket aggregateBucket(AggregationFn aggFn, Histogram input, double numGroups) {
        String argFld = aggFn.argumentFieldName();
        Class<?> fnClass = aggFn.getClass();
        if (fnClass.equals(SumFn.class)) {
            return sumBucket(input.buckets(argFld), numGroups);
        } else if (fnClass.equals(AvgFn.class)) {
            return avgBucket(input.buckets(argFld), numGroups);
        } else if (fnClass.equals(CountFn.class)) {
            return countBucket(input.buckets(argFld), numGroups);
        } else if (fnClass.equals(DistinctCountFn.class)) {
            return distinctCountBucket(input.buckets(argFld), numGroups);
        } else if (fnClass.equals(MinFn.class) || fnClass.equals(MaxFn.class)) {
            return extremeValueBucket(input.buckets(argFld), numGroups);
        }
        throw new UnsupportedOperationException(
                "unsupported aggregation function: " + fnClass.getName());
    }

    /**
     * Estimates the bucket of a SUM output field.
     *
     * <p>The largest possible group holds {@code totalFreq - numGroups + 1} records (every other
     * group holding just one). The upper bound is therefore the sum of that many of the largest
     * argument values, approximating each bucket's values by its high bound.
     *
     * <p>NOTE(review): the lower bound is the smallest argument value. That is only a valid bound
     * for the sum when argument values are non-negative.
     */
    private static Bucket sumBucket(Collection<Bucket> buckets, double numGroups) {
        Constant low = minLow(buckets);
        // Buckets sharing a high bound are merged so that none of their frequency is lost.
        TreeMap<Constant, Double> freqByHigh = new TreeMap<>();
        for (Bucket bucket : buckets) {
            freqByHigh.merge(bucket.valueRange().high(), bucket.frequency(), Double::sum);
        }

        double maxGroupSize = totalFrequency(buckets) - numGroups + 1.0d;
        Constant maxSum = new DoubleConstant(0.0d); // sums start at zero
        double taken = 0.0d;
        for (Map.Entry<Constant, Double> entry : freqByHigh.descendingMap().entrySet()) {
            double share = Math.min(entry.getValue(), maxGroupSize - taken);
            maxSum = maxSum.add(entry.getKey().mul(new DoubleConstant(share)));
            taken += share;
            if (Double.compare(taken, maxGroupSize) >= 0) {
                break;
            }
        }
        return new Bucket(ConstantRange.newInstance(low, true, maxSum, true), numGroups, numGroups);
    }

    /** Estimates the bucket of an AVG output field: a group's average lies within the argument's span. */
    private static Bucket avgBucket(Collection<Bucket> buckets, double numGroups) {
        ConstantRange span = ConstantRange.newInstance(minLow(buckets), true, maxHigh(buckets), true);
        return new Bucket(span, numGroups, numGroups);
    }

    /**
     * Estimates the bucket of a COUNT output field. Each group counts at least one record and at
     * most {@code totalFreq - numGroups + 1} records.
     */
    private static Bucket countBucket(Collection<Bucket> buckets, double numGroups) {
        double maxCount = totalFrequency(buckets) - numGroups + 1.0d;
        ConstantRange range = ConstantRange.newInstance(new DoubleConstant(1.0d), true,
                new DoubleConstant(maxCount), true);
        return new Bucket(range, numGroups, numGroups);
    }

    /**
     * Estimates the bucket of a COUNT(DISTINCT) output field. It has the same bounds as COUNT, but
     * the upper bound is further capped by the argument's total distinct-value count.
     */
    private static Bucket distinctCountBucket(Collection<Bucket> buckets, double numGroups) {
        double maxCount = Math.min(totalFrequency(buckets) - numGroups + 1.0d,
                totalDistinctValues(buckets));
        ConstantRange range = ConstantRange.newInstance(new DoubleConstant(1.0d), true,
                new DoubleConstant(maxCount), true);
        return new Bucket(range, numGroups, numGroups);
    }

    /**
     * Estimates the bucket of a MIN or MAX output field. The result of each group is one of the
     * argument values, so the range is the argument's span, and the distinct count is at most the
     * smaller of the group count and the argument's distinct count.
     */
    private static Bucket extremeValueBucket(Collection<Bucket> buckets, double numGroups) {
        ConstantRange span = ConstantRange.newInstance(minLow(buckets), true, maxHigh(buckets), true);
        return new Bucket(span, numGroups, Math.min(numGroups, totalDistinctValues(buckets)));
    }

    /** Returns the smallest low bound among the buckets, or {@code null} if there are none. */
    private static Constant minLow(Collection<Bucket> buckets) {
        Constant min = null;
        for (Bucket bucket : buckets) {
            Constant low = bucket.valueRange().low();
            if (min == null || low.compareTo(min) < 0) {
                min = low;
            }
        }
        return min;
    }

    /** Returns the largest high bound among the buckets, or {@code null} if there are none. */
    private static Constant maxHigh(Collection<Bucket> buckets) {
        Constant max = null;
        for (Bucket bucket : buckets) {
            Constant high = bucket.valueRange().high();
            if (max == null || high.compareTo(max) > 0) {
                max = high;
            }
        }
        return max;
    }

    /** Returns the sum of the buckets' frequencies (plain sequential summation). */
    private static double totalFrequency(Collection<Bucket> buckets) {
        double total = 0.0d;
        for (Bucket bucket : buckets) {
            total += bucket.frequency();
        }
        return total;
    }

    /** Returns the sum of the buckets' distinct-value counts (plain sequential summation). */
    private static double totalDistinctValues(Collection<Bucket> buckets) {
        double total = 0.0d;
        for (Bucket bucket : buckets) {
            total += bucket.distinctValues();
        }
        return total;
    }

    /**
     * Creates a group-by plan.
     *
     * @param plan the input plan
     * @param groupFlds the grouping fields; when empty, the whole input forms one group and is not sorted
     * @param aggFns the aggregation functions, or {@code null} if there are none
     * @param tx the transaction used by the sort plan
     */
    public GroupByPlan(Plan plan, Set<String> groupFlds, Set<AggregationFn> aggFns, Transaction tx) {
        this.groupFlds = groupFlds;
        this.aggFns = aggFns;
        Schema inputSchema = plan.schema();
        if (groupFlds.isEmpty()) {
            this.sp = plan;
        } else {
            for (String fld : groupFlds) {
                this.schema.add(fld, inputSchema);
            }
            // Presumably sorted so that GroupByScan sees each group's records contiguously.
            this.sp = new SortPlan(plan, new ArrayList<>(groupFlds), tx);
        }
        if (aggFns != null) {
            for (AggregationFn aggFn : aggFns) {
                this.schema.addField(aggFn.fieldName(), aggFn.isArgumentTypeDependent()
                        ? inputSchema.type(aggFn.argumentFieldName())
                        : aggFn.fieldType());
            }
        }
        this.hist = groupByHistogram(plan.histogram(), groupFlds, aggFns);
    }

    /** Opens a {@link GroupByScan} over the (possibly sorted) input. */
    @Override // org.vanilladb.core.query.algebra.Plan
    public Scan open() {
        return new GroupByScan(this.sp.open(), this.groupFlds, this.aggFns);
    }

    /** Returns the blocks accessed by the underlying (possibly sorted) input plan. */
    @Override // org.vanilladb.core.query.algebra.Plan
    public long blocksAccessed() {
        return this.sp.blocksAccessed();
    }

    @Override // org.vanilladb.core.query.algebra.Plan
    public Schema schema() {
        return this.schema;
    }

    @Override // org.vanilladb.core.query.algebra.Plan
    public Histogram histogram() {
        return this.hist;
    }

    /** Returns the estimated number of output records (one per group), truncated to a long. */
    @Override // org.vanilladb.core.query.algebra.Plan
    public long recordsOutput() {
        return (long) this.hist.recordsOutput();
    }

    /** Renders this plan node followed by its child plan, indented by one tab per line. */
    @Override
    public String toString() {
        String[] childLines = this.sp.toString().split("\n");
        StringBuilder sb = new StringBuilder();
        sb.append("->GroupByPlan: (#blks=").append(blocksAccessed())
                .append(", #recs=").append(recordsOutput()).append(")\n");
        for (String line : childLines) {
            sb.append("\t").append(line).append("\n");
        }
        return sb.toString();
    }
}
