package org.apfloat.aparapi;

import com.aparapi.Kernel;
import org.apfloat.ApfloatRuntimeException;
import org.apfloat.spi.ArrayAccess;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:org/apfloat/aparapi/LongKernel.class */
/**
 * Aparapi kernel that performs modular arithmetic on a {@code long[]} buffer.
 *
 * <p>The operation executed by {@link #run()} is chosen with {@link #setOp(int)} using one of the
 * {@code public static final} operation codes below. All operands are passed in through the
 * setters before the kernel is executed. Each work item reads {@link #getGlobalId()} (or
 * {@code getGlobalId(0)}/{@code getGlobalId(1)} for {@link #TRANSPOSE}) to select the elements
 * it processes.
 *
 * <p>Instances carry mutable per-execution state. {@link #getInstance()} hands out one instance
 * per calling thread via a {@link ThreadLocal}; do not share an instance between threads.
 *
 * <p>The arrays {@code wTable}, {@code permutationTable} and {@code index} start as one-element
 * placeholders rather than {@code null}.
 * NOTE(review): presumably because the Aparapi translation needs non-null arrays — confirm.
 */
public class LongKernel extends Kernel {

    /** Forward transform of one strided sequence per work item ({@code columnTableFNT}). */
    public static final int TRANSFORM_ROWS = 1;
    /** Inverse transform of one strided sequence per work item ({@code inverseColumnTableFNT}). */
    public static final int INVERSE_TRANSFORM_ROWS = 2;
    /** In-place transpose across the diagonal of an {@code n2}-wide row-major block. */
    public static final int TRANSPOSE = 3;
    /** Applies the row permutation cycles described by {@code index}. */
    public static final int PERMUTE = 4;
    /** Multiplies each element by {@code scaleFactor * w^(row * column)}. */
    public static final int MULTIPLY_ELEMENTS = 5;
    /** Three-row column transform, twiddle multiplication after the transform. */
    public static final int TRANSFORM_COLUMNS = 6;
    /** Three-row column transform, twiddle multiplication before the transform. */
    public static final int INVERSE_TRANSFORM_COLUMNS = 7;

    // One kernel instance per thread; instances hold mutable per-execution state.
    private static final ThreadLocal<LongKernel> INSTANCES = ThreadLocal.withInitial(LongKernel::new);

    private int op;

    // Data buffer and layout.
    private long[] data;
    private int offset;
    private int length;
    private int stride;

    // Modulus and its floating-point reciprocal (used for quotient estimation).
    private long modulus;
    private double inverseModulus;

    // Transform tables.
    private long[] wTable = {0};
    private int[] permutationTable = {0};
    private int permutationTableLength;

    // Transpose / permute parameters.
    private int n2;
    private int[] index = {0};
    private int indexCount;

    // multiplyElements / transformColumns parameters.
    private int startRow;
    private int startColumn;
    private int rows;
    private int columns;
    private long w;
    private long ww;
    private long w1;
    private long w2;
    private long scaleFactor;

    private LongKernel() {
    }

    /**
     * Returns the kernel instance associated with the calling thread.
     *
     * @return this thread's {@code LongKernel}
     */
    public static LongKernel getInstance() {
        return INSTANCES.get();
    }

    /**
     * Sets the transform length used by the FNT operations.
     * Call this before {@link #setArrayAccess(ArrayAccess)}, which derives the stride from it.
     *
     * @param length number of elements per transformed sequence
     */
    public void setLength(int length) {
        this.length = length;
    }

    /**
     * Binds the kernel to the array behind {@code arrayAccess}.
     *
     * <p>If a non-zero length was set, the stride is set to
     * {@code arrayAccess.getLength() / length}.
     *
     * @param arrayAccess access to the backing {@code long[]} and its offset/length
     * @throws ApfloatRuntimeException if the data cannot be obtained from {@code arrayAccess}
     */
    public void setArrayAccess(ArrayAccess arrayAccess) throws ApfloatRuntimeException {
        this.data = arrayAccess.getLongData();
        this.offset = arrayAccess.getOffset();
        if (this.length != 0) {
            this.stride = arrayAccess.getLength() / this.length;
        }
    }

    /**
     * Sets the twiddle-factor table read by the FNT operations.
     *
     * @param wTable table indexed by the butterfly loops
     */
    public void setWTable(long[] wTable) {
        this.wTable = wTable;
    }

    /**
     * Sets the table of index pairs swapped by the scramble step.
     *
     * @param permutationTable pairs {@code (a, b)} stored consecutively, or {@code null} for none
     */
    public void setPermutationTable(int[] permutationTable) {
        this.permutationTable = permutationTable == null ? new int[1] : permutationTable;
        this.permutationTableLength = permutationTable == null ? 0 : permutationTable.length;
    }

    /**
     * Sets the modulus and caches its reciprocal for {@link #modMultiply(long, long)}.
     *
     * @param modulus the modulus
     */
    public void setModulus(long modulus) {
        this.inverseModulus = 1.0d / modulus;
        this.modulus = modulus;
    }

    /**
     * Returns the modulus last set by {@link #setModulus(long)}.
     *
     * @return the modulus
     */
    public long getModulus() {
        return this.modulus;
    }

    /** @param n2 row width used by {@link #TRANSPOSE} and {@link #PERMUTE} */
    public void setN2(int n2) {
        this.n2 = n2;
    }

    /** @param index permutation cycles for {@link #PERMUTE}, each terminated by a 0 entry */
    public void setIndex(int[] index) {
        this.index = index;
    }

    /** @param indexCount number of leading {@code index} entries that {@link #PERMUTE} scans */
    public void setIndexCount(int indexCount) {
        this.indexCount = indexCount;
    }

    /** @param startRow first row index used by {@link #MULTIPLY_ELEMENTS} */
    public void setStartRow(int startRow) {
        this.startRow = startRow;
    }

    /** @param startColumn first column index used for the twiddle exponents */
    public void setStartColumn(int startColumn) {
        this.startColumn = startColumn;
    }

    /** @param rows number of rows processed by {@link #MULTIPLY_ELEMENTS} */
    public void setRows(int rows) {
        this.rows = rows;
    }

    /** @param columns row width (distance between rows) for the column operations */
    public void setColumns(int columns) {
        this.columns = columns;
    }

    /** @param w base raised to per-element exponents */
    public void setW(long w) {
        this.w = w;
    }

    /** @param scaleFactor factor applied to every element by {@link #MULTIPLY_ELEMENTS} */
    public void setScaleFactor(long scaleFactor) {
        this.scaleFactor = scaleFactor;
    }

    /** @param op one of the operation codes declared on this class */
    public void setOp(int op) {
        this.op = op;
    }

    /** @param ww base for the third-row twiddle factor in the column transforms */
    public void setWw(long ww) {
        this.ww = ww;
    }

    /** @param w1 constant multiplying the row-1 + row-2 sum in the column transforms */
    public void setW1(long w1) {
        this.w1 = w1;
    }

    /** @param w2 constant multiplying the row-1 - row-2 difference in the column transforms */
    public void setW2(long w2) {
        this.w2 = w2;
    }

    /** Kernel entry point: dispatches to the operation selected by {@link #setOp(int)}. */
    public void run() {
        if (this.op == TRANSFORM_ROWS) {
            columnTableFNT();
        } else if (this.op == INVERSE_TRANSFORM_ROWS) {
            inverseColumnTableFNT();
        } else if (this.op == TRANSPOSE) {
            transpose();
        } else if (this.op == PERMUTE) {
            permute();
        } else if (this.op == MULTIPLY_ELEMENTS) {
            multiplyElements();
        } else if (this.op == TRANSFORM_COLUMNS || this.op == INVERSE_TRANSFORM_COLUMNS) {
            transformColumns();
        }
    }

    /**
     * Forward table-based transform of the {@code length} elements at
     * {@code offset + globalId + k * stride}, followed by the optional scramble.
     */
    private void columnTableFNT() {
        long[] data = this.data;
        int base = this.offset + getGlobalId();
        int stride = this.stride;
        int nn = this.length;

        if (nn < 2) {
            return;
        }

        int end = base + nn * stride;
        int r = 1;
        for (int mmax = nn >> 1; mmax > 0; mmax >>= 1) {
            int istep = mmax << 1;

            // m == 0 butterflies skip the table multiplication
            // (presumably wTable[0] == 1 — TODO confirm).
            for (int i = base; i < end; i += istep * stride) {
                int j = i + mmax * stride;
                long a = data[i];
                long b = data[j];
                data[i] = modAdd(a, b);
                data[j] = modSubtract(a, b);
            }

            int t = r;
            for (int m = 1; m < mmax; m++) {
                for (int i = base + m * stride; i < end; i += istep * stride) {
                    int j = i + mmax * stride;
                    long a = data[i];
                    long b = data[j];
                    data[i] = modAdd(a, b);
                    data[j] = modMultiply(this.wTable[t], modSubtract(a, b));
                }
                t += r;
            }
            r <<= 1;
        }

        if (this.permutationTableLength > 0) {
            columnScramble(base);
        }
    }

    /**
     * Inverse of {@link #columnTableFNT()}: optional scramble first, then butterflies with
     * doubling span over the same strided elements.
     */
    private void inverseColumnTableFNT() {
        long[] data = this.data;
        int base = this.offset + getGlobalId();
        int stride = this.stride;
        int nn = this.length;

        if (nn < 2) {
            return;
        }

        if (this.permutationTableLength > 0) {
            columnScramble(base);
        }

        int end = base + nn * stride;
        int r = nn;
        for (int mmax = 1; mmax < nn; mmax <<= 1) {
            int istep = mmax << 1;
            r >>= 1;

            // m == 0 butterflies skip the table multiplication.
            for (int i = base; i < end; i += istep * stride) {
                int j = i + mmax * stride;
                long b = data[j];
                data[j] = modSubtract(data[i], b);
                data[i] = modAdd(data[i], b);
            }

            int t = r;
            for (int m = 1; m < mmax; m++) {
                for (int i = base + m * stride; i < end; i += istep * stride) {
                    int j = i + mmax * stride;
                    long b = modMultiply(this.wTable[t], data[j]);
                    data[j] = modSubtract(data[i], b);
                    data[i] = modAdd(data[i], b);
                }
                t += r;
            }
        }
    }

    /**
     * Swaps {@code data[base + a * stride]} and {@code data[base + b * stride]} for each
     * consecutive pair {@code (a, b)} in the permutation table.
     * Assumes the table length is even.
     */
    private void columnScramble(int base) {
        for (int p = 0; p < this.permutationTableLength; p += 2) {
            int i = base + this.permutationTable[p] * this.stride;
            int j = base + this.permutationTable[p + 1] * this.stride;
            long tmp = this.data[i];
            this.data[i] = this.data[j];
            this.data[j] = tmp;
        }
    }

    /**
     * Returns {@code a * b mod modulus}.
     *
     * <p>The {@code long} product {@code a * b} and {@code modulus * q} both wrap modulo 2^64.
     * Their difference is still exact whenever the true remainder fits in a {@code long}.
     * The quotient estimate {@code q} must therefore come from the real product, which is why
     * it is computed in {@code double}. A second estimate corrects the residual error, and the
     * result is finally folded into {@code [0, modulus)}.
     * NOTE(review): assumes {@code 0 <= a, b < modulus} — confirm against callers.
     */
    private long modMultiply(long a, long b) {
        long r = a * b - this.modulus * (long) ((double) a * (double) b * this.inverseModulus);
        r -= this.modulus * (int) ((double) r * this.inverseModulus);
        r = r >= this.modulus ? r - this.modulus : r;
        return r < 0 ? r + this.modulus : r;
    }

    /** Returns {@code a + b mod modulus} for operands already in {@code [0, modulus)}. */
    private long modAdd(long a, long b) {
        long r = a + b;
        return r >= this.modulus ? r - this.modulus : r;
    }

    /** Returns {@code a - b mod modulus} for operands already in {@code [0, modulus)}. */
    private long modSubtract(long a, long b) {
        long r = a - b;
        return r < 0 ? r + this.modulus : r;
    }

    /**
     * Work item {@code (x, y)} with {@code x < y} swaps the element at row {@code y}, column
     * {@code x} with the one at row {@code x}, column {@code y} of an {@code n2}-wide row-major
     * block starting at {@code offset}.
     */
    private void transpose() {
        int x = getGlobalId(0);
        int y = getGlobalId(1);
        if (x < y) {
            int i = this.offset + y * this.n2 + x;
            int j = this.offset + x * this.n2 + y;
            long tmp = this.data[i];
            this.data[i] = this.data[j];
            this.data[j] = tmp;
        }
    }

    /**
     * Rotates whole-row entries of one column ({@code globalId}) along each cycle in
     * {@code index}.
     *
     * <p>Each cycle is a start row followed by further rows and terminated by a 0 entry. Each
     * row receives the value of the next row in the cycle, and the last row receives the saved
     * start value.
     */
    private void permute() {
        int column = getGlobalId();
        int i = 0;
        while (i < this.indexCount) {
            int o = this.index[i];
            long tmp = this.data[this.offset + this.n2 * o + column];
            i++;
            while (this.index[i] != 0) {
                int m = this.index[i];
                this.data[this.offset + this.n2 * o + column] = this.data[this.offset + this.n2 * m + column];
                o = m;
                i++;
            }
            this.data[this.offset + this.n2 * o + column] = tmp;
            i++; // skip the cycle's 0 terminator
        }
    }

    /**
     * Multiplies the element at local row {@code k}, column {@code globalId} by
     * {@code scaleFactor * w^((startRow + k) * (startColumn + globalId))} for
     * {@code k = 0 .. rows-1}. Row {@code k} is at {@code offset + globalId + k * columns}.
     */
    private void multiplyElements() {
        long[] data = this.data;
        int column = getGlobalId();
        int position = this.offset + column;
        // Widen before adding so the exponent cannot overflow int.
        long exponent = (long) this.startColumn + column;
        long rowFactor = modPow(this.w, this.startRow);
        long columnFactor = modPow(this.w, exponent);
        long factor = modMultiply(this.scaleFactor, modPow(rowFactor, exponent));
        for (int row = 0; row < this.rows; row++) {
            data[position] = modMultiply(data[position], factor);
            position += this.columns;
            factor = modMultiply(factor, columnFactor);
        }
    }

    /**
     * Returns {@code a^n mod modulus} by square-and-multiply.
     * A negative {@code n} is replaced by {@code modulus - 1 + n}. An exponent that is 0 (given
     * or after that adjustment) yields 1 instead of spinning in the trailing-zero loop.
     */
    private long modPow(long a, long n) {
        long exponent = n < 0 ? this.modulus - 1 + n : n;
        if (exponent == 0) {
            return 1L;
        }

        // Strip trailing zero bits: square the base for each.
        while ((exponent & 1) == 0) {
            a = modMultiply(a, a);
            exponent >>= 1;
        }

        long result = a;
        while ((exponent >>= 1) > 0) {
            a = modMultiply(a, a);
            if ((exponent & 1) != 0) {
                result = modMultiply(result, a);
            }
        }
        return result;
    }

    /**
     * Three-point transform of column {@code globalId} across the rows at {@code offset},
     * {@code offset + columns} and {@code offset + 2 * columns}.
     *
     * <p>Rows 1 and 2 are multiplied by {@code w^c} and {@code ww^c}
     * ({@code c = startColumn + globalId}). For {@link #INVERSE_TRANSFORM_COLUMNS} this happens
     * before the transform; for {@link #TRANSFORM_COLUMNS} it happens after.
     */
    private void transformColumns() {
        int column = getGlobalId();
        long exponent = (long) this.startColumn + column;
        long factor1 = modPow(this.w, exponent);
        long factor2 = modPow(this.ww, exponent);

        int p0 = this.offset + column;
        int p1 = this.offset + this.columns + column;
        int p2 = this.offset + 2 * this.columns + column;

        long x0 = this.data[p0];
        long x1 = this.data[p1];
        long x2 = this.data[p2];

        if (this.op == INVERSE_TRANSFORM_COLUMNS) {
            x1 = modMultiply(x1, factor1);
            x2 = modMultiply(x2, factor2);
        }

        long sum = modAdd(x1, x2);
        long diff = modMultiply(modSubtract(x1, x2), this.w2);
        x0 = modAdd(x0, sum);
        long t = modAdd(modMultiply(sum, this.w1), x0);
        x1 = modAdd(t, diff);
        x2 = modSubtract(t, diff);

        if (this.op == TRANSFORM_COLUMNS) {
            x1 = modMultiply(x1, factor1);
            x2 = modMultiply(x2, factor2);
        }

        this.data[p0] = x0;
        this.data[p1] = x1;
        this.data[p2] = x2;
    }
}
