package org.bouncycastle.mls.TreeKEM;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
import org.bouncycastle.mls.TreeSize;
import org.bouncycastle.mls.codec.HPKECiphertext;
import org.bouncycastle.mls.codec.UpdatePath;
import org.bouncycastle.mls.crypto.MlsCipherSuite;
import org.bouncycastle.mls.crypto.Secret;
import org.bouncycastle.util.Strings;
import org.bouncycastle.util.encoders.Hex;

/* loaded from: input_file:org/bouncycastle/mls/TreeKEM/TreeKEMPrivateKey.class */
/**
 * Private half of a TreeKEM ratchet tree, as held by the member at leaf {@link #index}.
 * <p>
 * State kept:
 * <ul>
 *   <li>{@link #pathSecrets} — path secrets this member knows, keyed by tree node;</li>
 *   <li>{@link #privateKeyCache} — HPKE key pairs already available for a node (set directly,
 *       e.g. the leaf key, or memoized from a path secret); consulted before deriving;</li>
 *   <li>{@link #updateSecret} — the secret derived one step past the last node written by the
 *       most recent {@code implant}.</li>
 * </ul>
 * NOTE(review): no synchronization is performed; treat instances as single-threaded.
 */
public class TreeKEMPrivateKey {
    MlsCipherSuite suite;
    LeafIndex index;
    Secret updateSecret;
    Map<NodeIndex, Secret> pathSecrets = new HashMap<>();
    Map<NodeIndex, AsymmetricCipherKeyPair> privateKeyCache = new HashMap<>();

    /**
     * Creates an empty private tree (no path secrets, no cached keys, no update secret).
     *
     * @param mlsCipherSuite cipher suite used for derivation and HPKE operations
     * @param leafIndex      this member's leaf
     */
    public TreeKEMPrivateKey(MlsCipherSuite mlsCipherSuite, LeafIndex leafIndex) {
        this.suite = mlsCipherSuite;
        this.index = leafIndex;
    }

    /**
     * @return the update secret produced by the last {@code implant}, or {@code null} if no
     *         path secret has been implanted yet
     */
    public Secret getUpdateSecret() {
        return this.updateSecret;
    }

    /** Stores (or replaces) the path secret for {@code nodeIndex} without touching the cache. */
    public void insertPathSecret(NodeIndex nodeIndex, Secret secret) {
        this.pathSecrets.put(nodeIndex, secret);
    }

    /** Stores (or replaces) a cached key pair for {@code nodeIndex}. */
    public void insertPrivateKey(NodeIndex nodeIndex, AsymmetricCipherKeyPair asymmetricCipherKeyPair) {
        this.privateKeyCache.put(nodeIndex, asymmetricCipherKeyPair);
    }

    /**
     * Returns a copy with independent maps (the Secret / key-pair values themselves are shared).
     * The update secret is carried over as well so the clone reflects the full state.
     */
    public TreeKEMPrivateKey copy() {
        TreeKEMPrivateKey clone = new TreeKEMPrivateKey(this.suite, this.index);
        clone.updateSecret = this.updateSecret;
        clone.pathSecrets.putAll(this.pathSecrets);
        clone.privateKeyCache.putAll(this.privateKeyCache);
        return clone;
    }

    /** Builds a private tree that knows only its own leaf key pair. */
    public static TreeKEMPrivateKey solo(MlsCipherSuite mlsCipherSuite, LeafIndex leafIndex, AsymmetricCipherKeyPair asymmetricCipherKeyPair) {
        TreeKEMPrivateKey treeKEMPrivateKey = new TreeKEMPrivateKey(mlsCipherSuite, leafIndex);
        treeKEMPrivateKey.privateKeyCache.put(new NodeIndex(leafIndex), asymmetricCipherKeyPair);
        return treeKEMPrivateKey;
    }

    /**
     * Builds a private tree by implanting {@code secret} at the leaf and ratcheting it up the
     * leaf's filtered direct path in {@code treeKEMPublicKey}.
     */
    public static TreeKEMPrivateKey create(TreeKEMPublicKey treeKEMPublicKey, LeafIndex leafIndex, Secret secret) throws Exception {
        TreeKEMPrivateKey treeKEMPrivateKey = new TreeKEMPrivateKey(treeKEMPublicKey.suite, leafIndex);
        treeKEMPrivateKey.implant(treeKEMPublicKey, new NodeIndex(leafIndex), secret);
        return treeKEMPrivateKey;
    }

    /**
     * Builds a private tree holding the leaf key pair and, if {@code secret} is non-null, the
     * path secrets obtained by implanting it at {@code nodeIndex}.
     */
    public static TreeKEMPrivateKey joiner(TreeKEMPublicKey treeKEMPublicKey, LeafIndex leafIndex, AsymmetricCipherKeyPair asymmetricCipherKeyPair, NodeIndex nodeIndex, Secret secret) throws Exception {
        TreeKEMPrivateKey treeKEMPrivateKey = new TreeKEMPrivateKey(treeKEMPublicKey.suite, leafIndex);
        treeKEMPrivateKey.privateKeyCache.put(new NodeIndex(leafIndex), asymmetricCipherKeyPair);
        if (secret != null) {
            treeKEMPrivateKey.implant(treeKEMPublicKey, nodeIndex, secret);
        }
        return treeKEMPrivateKey;
    }

    /**
     * Renders a short debugging summary: this member's node index, every path secret (first
     * bytes of the secret and of its derived public key), and every cached public key.
     * Does not modify any state.
     */
    public String dump() throws IOException {
        String nl = Strings.lineSeparator();
        StringBuilder sb = new StringBuilder();
        sb.append("Tree (priv)").append(nl);
        sb.append("  Index: ").append(new NodeIndex(this.index).value()).append(nl);
        sb.append("  Secrets: ").append(nl);
        for (Map.Entry<NodeIndex, Secret> entry : this.pathSecrets.entrySet()) {
            Secret secret = entry.getValue();
            AsymmetricCipherKeyPair derived = deriveNodeKeyPair(secret);
            sb.append("    ").append(entry.getKey().value())
                .append(" => ").append(hexPrefix(secret.value()))
                .append(" => ").append(hexPrefix(serializePublic(derived)))
                .append(nl);
        }
        sb.append("  Cached key pairs: ").append(nl);
        for (Map.Entry<NodeIndex, AsymmetricCipherKeyPair> entry : this.privateKeyCache.entrySet()) {
            sb.append("    ").append(entry.getKey().value())
                .append(" => ").append(hexPrefix(serializePublic(entry.getValue())))
                .append(nl);
        }
        return sb.toString();
    }

    /**
     * Drops path secrets and cached key pairs for every node whose index lies beyond the last
     * leaf of a tree of size {@code treeSize}.
     */
    public void truncate(TreeSize treeSize) {
        long lastLeafNode = new NodeIndex(new LeafIndex((int) (treeSize.leafCount() - 1))).value();
        pruneAbove(this.pathSecrets, lastLeafNode);
        // Also prune cache-only entries: keys for nodes outside the tree are unusable and would
        // otherwise be looked up in the public tree by consistent().
        pruneAbove(this.privateKeyCache, lastLeafNode);
    }

    /** Removes every entry of {@code map} whose node index value is greater than {@code limit}. */
    private static void pruneAbove(Map<NodeIndex, ?> map, long limit) {
        for (Iterator<NodeIndex> it = map.keySet().iterator(); it.hasNext(); ) {
            if (it.next().value() > limit) {
                it.remove();
            }
        }
    }

    /**
     * Replaces this member's leaf key: forgets the leaf path secret and caches the key pair
     * deserialized from {@code bArr}.
     */
    public void setLeafKey(byte[] bArr) {
        NodeIndex nodeIndex = new NodeIndex(this.index);
        this.pathSecrets.remove(nodeIndex);
        this.privateKeyCache.put(nodeIndex, this.suite.getHPKE().deserializePrivateKey(bArr, (byte[]) null));
    }

    /**
     * Processes an UpdatePath sent from {@code leafIndex}. It finds the lowest node on the
     * sender's filtered direct path that is an ancestor of this member and decrypts the path
     * secret addressed to a resolution node we hold a key for. It then implants that secret and
     * verifies the result against {@code treeKEMPublicKey}.
     *
     * @param leafIndex        leaf whose filtered direct path {@code updatePath} covers
     * @param treeKEMPublicKey public tree the path is applied to
     * @param bArr             passed as the context argument of {@code decryptWithLabel}
     *                         (presumably the serialized group context — confirm with caller)
     * @param updatePath       the received path
     * @param list             leaves removed from the resolution before indexing ciphertexts
     *                         (presumably leaves added in the same commit)
     * @throws Exception if the path is malformed, does not cover this member, no usable private
     *                   key exists, or the resulting keys disagree with the public tree
     */
    public void decap(LeafIndex leafIndex, TreeKEMPublicKey treeKEMPublicKey, byte[] bArr, UpdatePath updatePath, List<LeafIndex> list) throws Exception {
        NodeIndex self = new NodeIndex(this.index);
        FilteredDirectPath filteredDirectPath = treeKEMPublicKey.getFilteredDirectPath(new NodeIndex(leafIndex));
        int pathLength = filteredDirectPath.parents.size();
        if (pathLength != updatePath.getNodes().size()) {
            throw new Exception("Malformed direct path");
        }

        // Lowest node on the sender's filtered direct path that sits above us.
        int dpi = 0;
        while (dpi < pathLength && !self.isBelow(filteredDirectPath.parents.get(dpi))) {
            dpi++;
        }
        if (dpi == pathLength) {
            throw new Exception("No overlap in path");
        }
        NodeIndex overlap = filteredDirectPath.parents.get(dpi);

        // Work on a copy so that removing excluded leaves does not mutate the resolution list
        // owned by the FilteredDirectPath.
        ArrayList<NodeIndex> resolution = new ArrayList<NodeIndex>(filteredDirectPath.resolutions.get(dpi));
        Utils.removeLeaves(resolution, list);
        if (resolution.size() != updatePath.getNodes().get(dpi).getEncryptedPathSecret().size()) {
            throw new Exception("Malformed direct path node");
        }

        // First resolution node for which we hold (or can derive) a private key.
        int ki = 0;
        while (ki < resolution.size() && !havePrivateKey(resolution.get(ki))) {
            ki++;
        }
        if (ki == resolution.size()) {
            throw new Exception("No private key to decrypt path secret");
        }

        AsymmetricCipherKeyPair decryptionKey = setPrivateKey(resolution.get(ki), false);
        HPKECiphertext ciphertext = updatePath.getNodes().get(dpi).getEncryptedPathSecret().get(ki);
        implant(treeKEMPublicKey, overlap, new Secret(this.suite.decryptWithLabel(
            this.suite.getHPKE().serializePrivateKey(decryptionKey.getPrivate()),
            "UpdatePathNode", bArr, ciphertext.getKemOutput(), ciphertext.getCiphertext())));
        if (!consistent(treeKEMPublicKey)) {
            throw new Exception("TreeKEMPublicKey inconsistant with TreeKEMPrivateKey");
        }
    }

    /** @return whether a path secret or a cached key pair exists for {@code nodeIndex} */
    private boolean havePrivateKey(NodeIndex nodeIndex) {
        return this.pathSecrets.containsKey(nodeIndex) || this.privateKeyCache.containsKey(nodeIndex);
    }

    /**
     * Checks that every private key this member holds matches the public key at the same node
     * of {@code treeKEMPublicKey}. Covers both cached key pairs and key pairs derived from path
     * secrets that have no cached entry. Blank nodes in the public tree are skipped. Does not
     * modify the cache.
     *
     * @return {@code false} on a cipher-suite mismatch or any public-key mismatch
     */
    public final boolean consistent(TreeKEMPublicKey treeKEMPublicKey) throws IOException {
        if (this.suite.getSuiteID() != treeKEMPublicKey.suite.getSuiteID()) {
            return false;
        }
        for (Map.Entry<NodeIndex, AsymmetricCipherKeyPair> entry : this.privateKeyCache.entrySet()) {
            if (!matchesPublicTree(treeKEMPublicKey, entry.getKey(), entry.getValue())) {
                return false;
            }
        }
        // Path secrets without a cached key (e.g. everything just written by implant, which
        // evicts the cache) must be checked too, otherwise freshly decrypted secrets go unverified.
        for (Map.Entry<NodeIndex, Secret> entry : this.pathSecrets.entrySet()) {
            if (this.privateKeyCache.containsKey(entry.getKey())) {
                continue; // the key actually used for this node was checked above
            }
            if (!matchesPublicTree(treeKEMPublicKey, entry.getKey(), deriveNodeKeyPair(entry.getValue()))) {
                return false;
            }
        }
        return true;
    }

    /** @return true if the public-tree node is blank or its public key equals {@code keyPair}'s */
    private boolean matchesPublicTree(TreeKEMPublicKey treeKEMPublicKey, NodeIndex nodeIndex, AsymmetricCipherKeyPair keyPair) throws IOException {
        Node node = treeKEMPublicKey.nodeAt(nodeIndex).node;
        if (node == null) {
            return true;
        }
        return Arrays.equals(node.getPublicKey(), serializePublic(keyPair));
    }

    /**
     * Looks up (or derives) the key pair for {@code nodeIndex}.
     * NOTE(review): the decompiler reports this was originally {@code protected}.
     *
     * @param nodeIndex node whose key is wanted
     * @param isConst   when {@code false}, a found key is stored in the cache; when {@code true}
     *                  the cache is left untouched
     * @return the key pair, or {@code null} if neither a cached key nor a path secret exists
     */
    public AsymmetricCipherKeyPair setPrivateKey(NodeIndex nodeIndex, boolean isConst) throws IOException {
        AsymmetricCipherKeyPair keyPair = getPrivateKey(nodeIndex);
        if (keyPair != null && !isConst) {
            this.privateKeyCache.put(nodeIndex, keyPair);
        }
        return keyPair;
    }

    /** Returns the cached key pair, else one derived from the path secret, else {@code null}. */
    private AsymmetricCipherKeyPair getPrivateKey(NodeIndex nodeIndex) throws IOException {
        AsymmetricCipherKeyPair cached = this.privateKeyCache.get(nodeIndex);
        if (cached != null || this.privateKeyCache.containsKey(nodeIndex)) {
            return cached;
        }
        Secret pathSecret = this.pathSecrets.get(nodeIndex);
        return pathSecret == null ? null : deriveNodeKeyPair(pathSecret);
    }

    /** Derives the HPKE key pair for a node from its path secret (via the "node" label). */
    private AsymmetricCipherKeyPair deriveNodeKeyPair(Secret pathSecret) throws IOException {
        return this.suite.getHPKE().deriveKeyPair(pathSecret.deriveSecret(this.suite, "node").value());
    }

    /** Serializes the public half of {@code keyPair} with this suite's HPKE. */
    private byte[] serializePublic(AsymmetricCipherKeyPair keyPair) throws IOException {
        return this.suite.getHPKE().serializePublicKey(keyPair.getPublic());
    }

    /** Hex of at most the first four bytes of {@code data}; tolerates shorter arrays. */
    private static String hexPrefix(byte[] data) {
        return Hex.toHexString(data, 0, Math.min(4, data.length));
    }

    /**
     * Stores {@code secret} at {@code nodeIndex}, then ratchets it with the "path" label once per
     * node of that node's filtered direct path, storing each result. Cache entries for every node
     * written are evicted so stale keys are not used. Finally sets {@link #updateSecret} to one
     * more "path" derivation past the last stored secret.
     */
    private void implant(TreeKEMPublicKey treeKEMPublicKey, NodeIndex nodeIndex, Secret secret) throws Exception {
        FilteredDirectPath filteredDirectPath = treeKEMPublicKey.getFilteredDirectPath(nodeIndex);
        Secret current = new Secret(secret.value());
        this.pathSecrets.put(nodeIndex, current);
        this.privateKeyCache.remove(nodeIndex);
        for (NodeIndex parent : filteredDirectPath.parents) {
            current = current.deriveSecret(treeKEMPublicKey.suite, "path");
            this.pathSecrets.put(parent, current);
            this.privateKeyCache.remove(parent);
        }
        this.updateSecret = current.deriveSecret(treeKEMPublicKey.suite, "path");
    }

    /**
     * @return the path secret at the common ancestor of this member and {@code leafIndex}, or an
     *         empty Secret if that secret is not known
     */
    public Secret getSharedPathSecret(LeafIndex leafIndex) {
        NodeIndex commonAncestor = this.index.commonAncestor(leafIndex);
        Secret shared = this.pathSecrets.get(commonAncestor);
        if (shared == null && !this.pathSecrets.containsKey(commonAncestor)) {
            return new Secret(new byte[0]);
        }
        return shared;
    }
}
