package io.trino.operator.aggregation.state;

import com.google.common.base.Preconditions;
import io.trino.spi.function.AccumulatorState;

/**
 * Aggregation state for one-pass computation of the first four central moments of a sample.
 *
 * <p>Fields, as maintained by {@link #update(double)} and {@link #merge(CentralMomentsState)}:
 * <ul>
 *   <li>{@code count} - number of values accumulated so far</li>
 *   <li>{@code m1} - running mean of the accumulated values</li>
 *   <li>{@code m2}, {@code m3}, {@code m4} - sums of the 2nd, 3rd and 4th powers of the
 *       deviations of the accumulated values from the mean</li>
 * </ul>
 *
 * <p>All arithmetic involving counts is carried out in {@code double}. Products such as
 * {@code count * count} would overflow a {@code long} once a count exceeds roughly 3.04e9.
 */
public interface CentralMomentsState extends AccumulatorState {
    long getCount();

    void setCount(long value);

    double getM1();

    void setM1(double value);

    double getM2();

    void setM2(double value);

    double getM3();

    void setM3(double value);

    double getM4();

    void setM4(double value);

    /**
     * Adds a single value to this state, incrementally updating the count, the mean and the
     * M2-M4 sums (Welford-style single-value update).
     *
     * @param value the value to accumulate
     */
    default void update(double value) {
        long previousCount = getCount();
        long count = previousCount + 1;
        double m1 = getM1();
        double m2 = getM2();
        double m3 = getM3();

        double delta = value - m1;
        double deltaN = delta / count;
        double deltaN2 = deltaN * deltaN;
        double dm2 = delta * deltaN * previousCount;
        // Square the count in double: count * count evaluated as long overflows past ~3.04e9.
        double n = count;

        setCount(count);
        setM1(m1 + deltaN);
        setM2(m2 + dm2);
        // M3 and M4 updates read the *old* m2/m3 captured above, so they must not be re-read here.
        setM3(m3 + dm2 * deltaN * (count - 2) - 3.0 * deltaN * m2);
        setM4(getM4() + dm2 * deltaN2 * (n * n - 3.0 * n + 3.0) + 6.0 * deltaN2 * m2 - 4.0 * deltaN * m3);
    }

    /**
     * Folds another partial state into this one using the pairwise combination formulas for
     * central moments. This method only reads from {@code otherState}.
     *
     * @param otherState the state to combine into this one; a state with count 0 is a no-op
     * @throws IllegalArgumentException if {@code otherState} reports a negative count
     */
    default void merge(CentralMomentsState otherState) {
        long countA = getCount();
        long countB = otherState.getCount();
        Preconditions.checkArgument(countB >= 0, "count is negative");
        if (countB == 0) {
            return;
        }

        double m1a = getM1();
        double m2a = getM2();
        double m3a = getM3();
        double m1b = otherState.getM1();
        double m2b = otherState.getM2();
        double m3b = otherState.getM3();

        // Promote counts to double before multiplying: na * na, na * nb and nb * nb
        // overflow a long once either count exceeds ~3.04e9.
        double na = countA;
        double nb = countB;
        double n = na + nb;
        double delta = m1b - m1a;
        double delta2 = delta * delta;

        // Store the exact long sum rather than round-tripping it through double,
        // which would lose precision above 2^53.
        setCount(countA + countB);
        setM1((na * m1a + nb * m1b) / n);
        setM2(m2a + m2b + delta2 * na * nb / n);
        setM3(m3a + m3b
                + delta2 * delta * na * nb * (na - nb) / (n * n)
                + 3.0 * delta * (na * m2b - nb * m2a) / n);
        setM4(getM4() + otherState.getM4()
                + delta2 * delta2 * na * nb * (na * na - na * nb + nb * nb) / (n * n * n)
                + 6.0 * delta2 * (na * na * m2b + nb * nb * m2a) / (n * n)
                + 4.0 * delta * (na * m3b - nb * m3a) / n);
    }
}
