package io.ortis.jqbit.xmss;

import io.ortis.jqbit.HashFunction;
import io.ortis.jqbit.Utils;
import io.ortis.jqbit.wotsp.Adrs;
import io.ortis.jqbit.wotsp.WOTSpConfig;
import io.ortis.jqbit.wotsp.WOTSpRFC;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.LinkedList;

/* loaded from: input_file:io/ortis/jqbit/xmss/XMSSRFC.class */
/**
 * XMSS (single-tree) signing, verification and tree-construction primitives, structured after the
 * algorithms of RFC 8391 (XMSS_sign, XMSS_verify, XMSS_rootFromSig, treeSig, buildAuth, treeHash,
 * ltree, RAND_HASH, H and H_msg).
 *
 * <p>Private key material is passed in a compact form: the value for leaf {@code i} starts at
 * {@code offset + i * n} and is expanded into a full WOTS+ private key by
 * {@link #inflateCompactWOTSPPrivateKey}. Optionally, every hash-tree node can be kept in a flat
 * node cache laid out as described in {@link #flatTreeIndex}, so that authentication paths and the
 * root can be read instead of recomputed.
 *
 * <p>The methods mutate the {@link Adrs} instance they are given, so one {@code Adrs} must not be
 * shared between concurrent calls.
 */
public class XMSSRFC {

    /**
     * A hash-tree node value together with its height above the leaf level (leaves have height 0).
     * Used as the stack entry of {@link XMSSRFC#treeHash}.
     */
    public static class TreeNode {
        private final byte[] value;
        private final int height;

        /**
         * @param value  node value; stored by reference, not copied
         * @param height height of the node, leaves being 0
         */
        public TreeNode(byte[] value, int height) {
            this.value = value;
            this.height = height;
        }

        /** Returns the node value (the internal array, not a copy). */
        public byte[] value() {
            return this.value;
        }

        /** Returns the height of the node, leaves being 0. */
        public int getHeight() {
            return this.height;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "{value=" + Utils.toBase16(this.value) + ", height=" + this.height + "}";
        }
    }

    /**
     * Verifies an XMSS signature (RFC 8391, XMSS_verify).
     *
     * <p>Computes {@code M' = H_msg(r || root || toByte(index, n), message)}, rebuilds the tree root
     * from the signature with {@link #rootFromSig} and compares it with {@code root}.
     *
     * @param config       XMSS parameter set
     * @param message      signed message
     * @param index        leaf index the signature claims to use
     * @param r            signature randomness; its first n bytes are used
     * @param otsSignature WOTS+ part of the signature
     * @param authPath     authentication path, node k at offset {@code k * n}
     * @param root         public root; must be exactly n bytes for the comparison to succeed
     * @param publicSeed   public seed used for all keyed hashing
     * @return {@code true} if the recomputed root equals {@code root}
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static boolean xmssVerify(XMSSConfig config, byte[] message, int index, byte[] r, byte[] otsSignature,
            byte[] authPath, byte[] root, byte[] publicSeed) throws HashFunction.Instance.HashFunctionException {
        final int n = config.getWOTSPConfig().getN();

        // H_msg key = r || root || toByte(index, n)
        final byte[] hashKey = new byte[3 * n];
        System.arraycopy(r, 0, hashKey, 0, n);
        System.arraycopy(root, 0, hashKey, n, n);
        System.arraycopy(Utils.zToBytes(index, n), 0, hashKey, 2 * n, n);

        final byte[] digest = new byte[n];
        hmsg(config, hashKey, message, digest, 0);

        final byte[] candidateRoot = rootFromSig(config, digest, index, otsSignature, authPath, publicSeed, new Adrs());
        return Arrays.equals(root, candidateRoot);
    }

    /**
     * Recomputes the tree root implied by a signature (RFC 8391, XMSS_rootFromSig).
     *
     * <p>Derives the WOTS+ public key from {@code otsSignature}, compresses it into a leaf with
     * {@link #ltree}, then climbs the tree with the authentication path. At height k the current node
     * is the left child when bit k of {@code index} is 0 and the right child otherwise.
     *
     * @param config       XMSS parameter set
     * @param digest       message digest M' that was signed
     * @param index        leaf index
     * @param otsSignature WOTS+ signature
     * @param authPath     authentication path, node k at offset {@code k * n}
     * @param publicSeed   public seed
     * @param adrs         address object; overwritten by this call
     * @return a new n-byte array holding the recomputed root
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static byte[] rootFromSig(XMSSConfig config, byte[] digest, int index, byte[] otsSignature,
            byte[] authPath, byte[] publicSeed, Adrs adrs) throws HashFunction.Instance.HashFunctionException {
        final int h = config.getH();
        final int n = config.getWOTSPConfig().getN();
        final byte[] wotsPublicKey = new byte[config.getWOTSPConfig().getKeyLength()];
        final byte[] node = new byte[n];
        final byte[] parent = new byte[n];

        adrs.setType(Adrs.Type.OTS);
        adrs.setOTSAddress(index);
        WOTSpRFC.signatureToPublicKey(config.getWOTSPConfig(), digest, otsSignature, publicSeed, adrs, wotsPublicKey, 0);

        adrs.setType(Adrs.Type.LTree);
        adrs.setLTreeAddress(index);
        ltree(config, wotsPublicKey, publicSeed, adrs, node, 0);

        adrs.setType(Adrs.Type.HashTree);
        adrs.setTreeIndex(index);
        for (int k = 0; k < h; k++) {
            adrs.setTreeHeight(k);
            if ((WOTSpRFC.floorDiv(index, pow2(k)) & 1) == 0) {
                // Current node is a left child: the authentication node goes on the right.
                adrs.setTreeIndex(adrs.getTreeIndex() >> 1);
                randHash(config, node, 0, authPath, k * n, publicSeed, adrs, parent, 0);
            } else {
                // Current node is a right child: the authentication node goes on the left.
                adrs.setTreeIndex((adrs.getTreeIndex() - 1) >> 1);
                randHash(config, authPath, k * n, node, 0, publicSeed, adrs, parent, 0);
            }
            System.arraycopy(parent, 0, node, 0, n);
        }
        return node;
    }

    /**
     * Produces the variable parts of an XMSS signature (RFC 8391, XMSS_sign): the randomness
     * {@code r}, the WOTS+ signature and the authentication path. This method does not write
     * {@code index} anywhere.
     *
     * <p>{@code r = PRF(skPrf, toByte(index, 32))} and the signed digest is
     * {@code H_msg(r || root || toByte(index, n), message)}.
     *
     * @param config            XMSS parameter set
     * @param message           message to sign
     * @param index             leaf index to use (must never be reused)
     * @param compactPrivateKey compact private key material (see {@link #inflateCompactWOTSPPrivateKey})
     * @param privateKeyOffset  offset of the compact key inside {@code compactPrivateKey}
     * @param skPrf             PRF key used to derive {@code r}; used in full
     * @param root              public root (first n bytes used)
     * @param publicSeed        public seed
     * @param nodeCache         full-tree node cache, or {@code null} to recompute the authentication path
     * @param nodeCacheOffset   offset of the cache inside {@code nodeCache}
     * @param rOut              receives the n bytes of {@code r}
     * @param rOutOffset        offset into {@code rOut}
     * @param otsSigOut         receives the WOTS+ signature
     * @param otsSigOutOffset   offset into {@code otsSigOut}
     * @param authOut           receives h authentication nodes of n bytes each
     * @param authOutOffset     offset into {@code authOut}
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static void xmssSign(XMSSConfig config, byte[] message, int index, byte[] compactPrivateKey,
            int privateKeyOffset, byte[] skPrf, byte[] root, byte[] publicSeed, byte[] nodeCache, int nodeCacheOffset,
            byte[] rOut, int rOutOffset, byte[] otsSigOut, int otsSigOutOffset, byte[] authOut, int authOutOffset)
            throws HashFunction.Instance.HashFunctionException {
        final int n = config.getWOTSPConfig().getN();

        // H_msg key = r || root || toByte(index, n); r is produced directly into the first n bytes.
        final byte[] hashKey = new byte[3 * n];
        WOTSpRFC.prf(config.getWOTSPConfig(), skPrf, 0, skPrf.length, Utils.zToBytes(index, 32), 0, 32, hashKey, 0);
        System.arraycopy(root, 0, hashKey, n, n);
        System.arraycopy(Utils.zToBytes(index, n), 0, hashKey, 2 * n, n);

        final byte[] digest = new byte[n];
        hmsg(config, hashKey, message, digest, 0);
        System.arraycopy(hashKey, 0, rOut, rOutOffset, n);

        treeSig(config, digest, index, compactPrivateKey, privateKeyOffset, publicSeed, new Adrs(), nodeCache,
                nodeCacheOffset, otsSigOut, otsSigOutOffset, authOut, authOutOffset);
    }

    /**
     * Signs a digest with the one-time key of leaf {@code index} and produces its authentication path
     * (RFC 8391, treeSig).
     *
     * <p>If {@code nodeCache} is {@code null} the authentication path is recomputed with
     * {@link #computeAuth}; otherwise it is copied from the cache with {@link #readAuth}. The expanded
     * WOTS+ private key is zeroed before this method returns.
     *
     * @param config            XMSS parameter set
     * @param digest            digest to sign
     * @param index             leaf index
     * @param compactPrivateKey compact private key material
     * @param privateKeyOffset  offset of the compact key inside {@code compactPrivateKey}
     * @param publicSeed        public seed
     * @param adrs              address object; overwritten by this call
     * @param nodeCache         full-tree node cache, or {@code null}
     * @param nodeCacheOffset   offset of the cache inside {@code nodeCache}
     * @param otsSigOut         receives the WOTS+ signature
     * @param otsSigOutOffset   offset into {@code otsSigOut}
     * @param authOut           receives the authentication path
     * @param authOutOffset     offset into {@code authOut}
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static void treeSig(XMSSConfig config, byte[] digest, int index, byte[] compactPrivateKey,
            int privateKeyOffset, byte[] publicSeed, Adrs adrs, byte[] nodeCache, int nodeCacheOffset,
            byte[] otsSigOut, int otsSigOutOffset, byte[] authOut, int authOutOffset)
            throws HashFunction.Instance.HashFunctionException {
        if (nodeCache == null) {
            computeAuth(config, index, compactPrivateKey, privateKeyOffset, publicSeed, adrs, authOut, authOutOffset);
        } else {
            readAuth(config, index, nodeCache, nodeCacheOffset, authOut, authOutOffset);
        }

        adrs.setType(Adrs.Type.OTS);
        adrs.setOTSAddress(index);
        final byte[] wotsPrivateKey = new byte[config.getWOTSPConfig().getKeyLength()];
        try {
            inflateCompactWOTSPPrivateKey(config, index, compactPrivateKey, privateKeyOffset, wotsPrivateKey, 0);
            WOTSpRFC.sign(config.getWOTSPConfig(), digest, wotsPrivateKey, publicSeed, adrs, otsSigOut, otsSigOutOffset);
        } finally {
            // Do not leave the expanded one-time secret key on the heap longer than needed.
            Arrays.fill(wotsPrivateKey, (byte) 0);
        }
    }

    /**
     * Recomputes the authentication path of leaf {@code index} (RFC 8391, buildAuth). Node k is the
     * root of the sibling subtree of height k, obtained with {@link #treeHash}.
     *
     * @param config            XMSS parameter set
     * @param index             leaf index
     * @param compactPrivateKey compact private key material
     * @param privateKeyOffset  offset of the compact key inside {@code compactPrivateKey}
     * @param publicSeed        public seed
     * @param adrs              address object; overwritten by this call
     * @param authOut           receives h nodes, node k at {@code authOutOffset + k * n}
     * @param authOutOffset     offset into {@code authOut}
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static void computeAuth(XMSSConfig config, int index, byte[] compactPrivateKey, int privateKeyOffset,
            byte[] publicSeed, Adrs adrs, byte[] authOut, int authOutOffset)
            throws HashFunction.Instance.HashFunctionException {
        final int h = config.getH();
        for (int k = 0; k < h; k++) {
            final int subtreeSize = pow2(k);
            // First leaf of the sibling of the height-k subtree containing `index`.
            final int siblingStart = (WOTSpRFC.floorDiv(index, subtreeSize) ^ 1) * subtreeSize;
            final byte[] node = treeHash(config, siblingStart, k, compactPrivateKey, privateKeyOffset, publicSeed, adrs,
                    null, -1);
            System.arraycopy(node, 0, authOut, authOutOffset + k * node.length, node.length);
        }
    }

    /**
     * Copies the authentication path of leaf {@code index} out of a node cache that covers the whole
     * tree (start leaf 0, height h).
     *
     * @param config          XMSS parameter set
     * @param index           leaf index
     * @param nodeCache       full-tree node cache
     * @param nodeCacheOffset offset of the cache inside {@code nodeCache}
     * @param authOut         receives h nodes, node k at {@code authOutOffset + k * n}
     * @param authOutOffset   offset into {@code authOut}
     */
    public static void readAuth(XMSSConfig config, int index, byte[] nodeCache, int nodeCacheOffset, byte[] authOut,
            int authOutOffset) {
        final int h = config.getH();
        final int n = config.getWOTSPConfig().getN();
        for (int k = 0; k < h; k++) {
            final int sibling = WOTSpRFC.floorDiv(index, pow2(k)) ^ 1;
            System.arraycopy(nodeCache, nodeCacheOffset + flatTreeIndex(k, sibling, 0, h, n), authOut,
                    authOutOffset + k * n, n);
        }
    }

    /**
     * Computes the root of the whole tree with {@link #treeHash} (start leaf 0, height h).
     *
     * @param config            XMSS parameter set
     * @param compactPrivateKey compact private key material
     * @param privateKeyOffset  offset of the compact key inside {@code compactPrivateKey}
     * @param publicSeed        public seed
     * @return the n-byte root
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static byte[] computeRoot(XMSSConfig config, byte[] compactPrivateKey, int privateKeyOffset,
            byte[] publicSeed) throws HashFunction.Instance.HashFunctionException {
        return treeHash(config, 0, config.getH(), compactPrivateKey, privateKeyOffset, publicSeed, new Adrs(), null, -1);
    }

    /**
     * Copies the tree root out of a full-tree node cache. The root is the last node of the cache, at
     * node position {@code config.getTreeNodeCount() - 1}.
     *
     * @param config          XMSS parameter set
     * @param nodeCache       full-tree node cache
     * @param nodeCacheOffset offset of the cache inside {@code nodeCache}
     * @param rootOut         receives the n-byte root
     * @param rootOutOffset   offset into {@code rootOut}
     */
    public static void readRoot(XMSSConfig config, byte[] nodeCache, int nodeCacheOffset, byte[] rootOut,
            int rootOutOffset) {
        final int n = config.getWOTSPConfig().getN();
        // The cache offset must be honoured here exactly as readAuth/treeHash do.
        System.arraycopy(nodeCache, nodeCacheOffset + (config.getTreeNodeCount() - 1) * n, rootOut, rootOutOffset, n);
    }

    /**
     * Computes the root of the subtree of height {@code subtreeHeight} whose leftmost leaf is
     * {@code start} (RFC 8391, treeHash), using the classic stack-based algorithm.
     *
     * <p>If {@code nodeCache} is non-null, every node computed is also written into it at
     * {@code nodeCacheOffset + flatTreeIndex(height, index, start, subtreeHeight, n)}. The expanded
     * WOTS+ private key buffer is zeroed before this method returns.
     *
     * @param config            XMSS parameter set
     * @param start             leftmost leaf index; must be a multiple of 2^subtreeHeight
     * @param subtreeHeight     height of the subtree
     * @param compactPrivateKey compact private key material
     * @param privateKeyOffset  offset of the compact key inside {@code compactPrivateKey}
     * @param publicSeed        public seed
     * @param adrs              address object; overwritten by this call
     * @param nodeCache         optional node cache to fill, or {@code null}
     * @param nodeCacheOffset   offset of the cache inside {@code nodeCache}; ignored when it is {@code null}
     * @return the n-byte subtree root
     * @throws IllegalArgumentException if {@code start} is not a multiple of 2^subtreeHeight
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static byte[] treeHash(XMSSConfig config, int start, int subtreeHeight, byte[] compactPrivateKey,
            int privateKeyOffset, byte[] publicSeed, Adrs adrs, byte[] nodeCache, int nodeCacheOffset)
            throws HashFunction.Instance.HashFunctionException {
        final int leafCount = pow2(subtreeHeight);
        if (start % leafCount != 0) {
            throw new IllegalArgumentException("Inputs must verify s % 2^t == 0");
        }
        final WOTSpConfig wotsConfig = config.getWOTSPConfig();
        final int n = wotsConfig.getN();
        final byte[] node = new byte[n];
        final byte[] wotsPrivateKey = new byte[wotsConfig.getKeyLength()];
        final byte[] wotsPublicKey = new byte[wotsConfig.getKeyLength()];
        final Deque<TreeNode> stack = new ArrayDeque<>();

        try {
            for (int offset = 0; offset < leafCount; offset++) {
                final int leaf = start + offset;

                // Leaf = ltree(WOTS+ public key of this leaf).
                adrs.setType(Adrs.Type.OTS);
                adrs.setOTSAddress(leaf);
                inflateCompactWOTSPPrivateKey(config, leaf, compactPrivateKey, privateKeyOffset, wotsPrivateKey, 0);
                WOTSpRFC.publicKey(wotsConfig, wotsPrivateKey, publicSeed, adrs, wotsPublicKey, 0);

                adrs.setType(Adrs.Type.LTree);
                adrs.setLTreeAddress(leaf);
                ltree(config, wotsPublicKey, publicSeed, adrs, node, 0);

                adrs.setType(Adrs.Type.HashTree);
                adrs.setTreeHeight(0);
                adrs.setTreeIndex(leaf);
                TreeNode current = new TreeNode(Arrays.copyOf(node, node.length), adrs.getTreeHeight());
                if (nodeCache != null) {
                    System.arraycopy(node, 0, nodeCache,
                            nodeCacheOffset + flatTreeIndex(current.getHeight(), leaf, start, subtreeHeight, n), n);
                }

                // Merge with the stack top while it has the same height: it is the left sibling.
                while (!stack.isEmpty() && stack.peek().getHeight() == current.getHeight()) {
                    adrs.setTreeIndex((adrs.getTreeIndex() - 1) >> 1);
                    randHash(config, stack.pop().value(), 0, current.value(), 0, publicSeed, adrs, node, 0);
                    adrs.setTreeHeight(adrs.getTreeHeight() + 1);
                    current = new TreeNode(Arrays.copyOf(node, node.length), adrs.getTreeHeight());
                    if (nodeCache != null) {
                        System.arraycopy(node, 0, nodeCache, nodeCacheOffset
                                + flatTreeIndex(current.getHeight(), adrs.getTreeIndex(), start, subtreeHeight, n), n);
                    }
                }
                stack.push(current);
            }
        } finally {
            // Do not leave the last expanded one-time secret key on the heap.
            Arrays.fill(wotsPrivateKey, (byte) 0);
        }

        if (stack.size() != 1) {
            throw new IllegalStateException("Stack size must be 1");
        }
        return stack.pop().value();
    }

    /**
     * Returns the byte offset of a node inside a flat node cache for the subtree of height
     * {@code subtreeHeight} whose leftmost leaf is {@code start}.
     *
     * <p>The cache stores the subtree level by level from the leaves up. Level k holds
     * 2^(subtreeHeight - k) nodes of {@code n} bytes each. A node's position is the number of nodes
     * on all lower levels plus its offset within its own level.
     *
     * @param height        node height (0 for leaves)
     * @param index         node index at that height, counted across the whole tree
     * @param start         leftmost leaf index of the cached subtree
     * @param subtreeHeight height of the cached subtree
     * @param n             node size in bytes
     * @return byte offset of the node relative to the start of the cache
     * @throws IllegalArgumentException if the node lies to the left of the cached subtree
     */
    public static int flatTreeIndex(int height, int index, int start, int subtreeHeight, int n) {
        final int offsetInLevel = index - (start / pow2(height));
        if (offsetInLevel < 0) {
            throw new IllegalArgumentException("Bad index");
        }
        int nodesBelow = 0;
        for (int level = 0; level < height; level++) {
            nodesBelow += pow2(subtreeHeight - level);
        }
        return (nodesBelow + offsetInLevel) * n;
    }

    /**
     * Compresses a WOTS+ public key of {@code len} n-byte chains into a single n-byte leaf
     * (RFC 8391, ltree).
     *
     * <p>Nodes are paired and hashed with {@link #randHash}. An unpaired last node is carried up
     * unchanged, until one node remains. The input array is not modified; a working copy is used.
     *
     * @param config        XMSS parameter set
     * @param wotsPublicKey WOTS+ public key (at least {@code getKeyLength()} bytes)
     * @param publicSeed    public seed
     * @param adrs          address object with type and L-tree address already set; tree height and
     *                      index are overwritten
     * @param out           receives the n-byte leaf
     * @param outOffset     offset into {@code out}
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static void ltree(XMSSConfig config, byte[] wotsPublicKey, byte[] publicSeed, Adrs adrs, byte[] out,
            int outOffset) throws HashFunction.Instance.HashFunctionException {
        final WOTSpConfig wotsConfig = config.getWOTSPConfig();
        final int n = wotsConfig.getN();
        final byte[] nodes = new byte[wotsConfig.getKeyLength()];
        System.arraycopy(wotsPublicKey, 0, nodes, 0, nodes.length);

        int count = wotsConfig.getLen();
        adrs.setTreeHeight(0);
        while (count > 1) {
            final int pairs = WOTSpRFC.floorDiv(count, 2);
            for (int p = 0; p < pairs; p++) {
                adrs.setTreeIndex(p);
                // In-place is safe: node p (at p*n) only overwrites pairs already consumed, and
                // randHash copies both inputs before writing its output.
                randHash(config, nodes, 2 * p * n, nodes, 2 * p * n + n, publicSeed, adrs, nodes, p * n);
            }
            if (count % 2 == 1) {
                System.arraycopy(nodes, (count - 1) * n, nodes, pairs * n, n);
            }
            count = WOTSpRFC.ceilDiv(count, 2);
            adrs.setTreeHeight(adrs.getTreeHeight() + 1);
        }
        System.arraycopy(nodes, 0, out, outOffset, n);
    }

    /**
     * Keyed, masked two-to-one hash (RFC 8391, RAND_HASH).
     *
     * <p>Derives a key and two bitmasks with PRF over {@code adrs} with keyAndMask set to 0, 1 and 2.
     * It XORs the masks into the left and right inputs and hashes the concatenation with {@link #h}.
     * Both inputs are copied before the output is written, so {@code out} may alias either input. On
     * return {@code adrs} has keyAndMask set to 2.
     *
     * @param config      XMSS parameter set
     * @param left        left input (n bytes at {@code leftOffset})
     * @param leftOffset  offset into {@code left}
     * @param right       right input (n bytes at {@code rightOffset})
     * @param rightOffset offset into {@code right}
     * @param publicSeed  PRF key (used in full)
     * @param adrs        address object; keyAndMask is overwritten
     * @param out         receives the n-byte result
     * @param outOffset   offset into {@code out}
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static void randHash(XMSSConfig config, byte[] left, int leftOffset, byte[] right, int rightOffset,
            byte[] publicSeed, Adrs adrs, byte[] out, int outOffset) throws HashFunction.Instance.HashFunctionException {
        final WOTSpConfig wotsConfig = config.getWOTSPConfig();
        final int n = wotsConfig.getN();
        final byte[] key = new byte[n];
        final byte[] leftMask = new byte[n];
        final byte[] rightMask = new byte[n];
        final byte[] masked = new byte[2 * n];

        adrs.setKeyAndMask(0);
        WOTSpRFC.prf(wotsConfig, publicSeed, 0, publicSeed.length, adrs.toBytes(), 0, 32, key, 0);
        adrs.setKeyAndMask(1);
        WOTSpRFC.prf(wotsConfig, publicSeed, 0, publicSeed.length, adrs.toBytes(), 0, 32, leftMask, 0);
        adrs.setKeyAndMask(2);
        WOTSpRFC.prf(wotsConfig, publicSeed, 0, publicSeed.length, adrs.toBytes(), 0, 32, rightMask, 0);

        for (int j = 0; j < n; j++) {
            masked[j] = (byte) (left[leftOffset + j] ^ leftMask[j]);
            masked[n + j] = (byte) (right[rightOffset + j] ^ rightMask[j]);
        }
        h(config, key, masked, out, outOffset);
    }

    /**
     * Tree hash function H. Calls {@code WOTSpRFC.functionTemplate} with selector 1. RFC 8391 uses
     * toByte(1, 32) as the domain separator of H; presumably that is what the selector chooses.
     *
     * @param config    XMSS parameter set
     * @param key       hash key (used in full)
     * @param message   data to hash (used in full)
     * @param out       receives the result
     * @param outOffset offset into {@code out}
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static void h(XMSSConfig config, byte[] key, byte[] message, byte[] out, int outOffset)
            throws HashFunction.Instance.HashFunctionException {
        WOTSpRFC.functionTemplate(config.getWOTSPConfig(), 1, key, 0, key.length, message, 0, message.length, out, outOffset);
    }

    /**
     * Message hash function H_msg. Calls {@code WOTSpRFC.functionTemplate} with selector 2. RFC 8391
     * uses toByte(2, 32) as the domain separator of H_msg; presumably that is what the selector chooses.
     *
     * @param config    XMSS parameter set
     * @param key       hash key (used in full), e.g. {@code r || root || toByte(idx, n)}
     * @param message   message to hash (used in full)
     * @param out       receives the digest
     * @param outOffset offset into {@code out}
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static void hmsg(XMSSConfig config, byte[] key, byte[] message, byte[] out, int outOffset)
            throws HashFunction.Instance.HashFunctionException {
        WOTSpRFC.functionTemplate(config.getWOTSPConfig(), 2, key, 0, key.length, message, 0, message.length, out, outOffset);
    }

    /**
     * Expands the compact private key value of leaf {@code index} into a full WOTS+ private key with
     * {@code WOTSpRFC.inflatePrivateKey}. The leaf's compact value starts at
     * {@code privateKeyOffset + index * n}.
     *
     * @param config            XMSS parameter set
     * @param index             leaf index
     * @param compactPrivateKey compact private key material
     * @param privateKeyOffset  offset of the compact key inside {@code compactPrivateKey}
     * @param out               receives the expanded WOTS+ private key
     * @param outOffset         offset into {@code out}
     * @throws HashFunction.Instance.HashFunctionException if the underlying hash function fails
     */
    public static void inflateCompactWOTSPPrivateKey(XMSSConfig config, int index, byte[] compactPrivateKey,
            int privateKeyOffset, byte[] out, int outOffset) throws HashFunction.Instance.HashFunctionException {
        final int leafOffset = privateKeyOffset + index * config.getWOTSPConfig().getN();
        WOTSpRFC.inflatePrivateKey(config.getWOTSPConfig(), compactPrivateKey, leafOffset, out, outOffset);
    }

    /**
     * Returns 2^exponent as an {@code int}. Only meaningful for {@code 0 <= exponent <= 30}, because
     * Java masks the shift distance to its low 5 bits.
     */
    public static int pow2(int exponent) {
        return 1 << exponent;
    }
}
