/*
 * Decompiled with CFR 0.152.
 */
package com.maltego.cloud.crypto.util;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;
import com.maltego.cloud.crypto.Pair;
import com.maltego.cloud.crypto.util.CheckedFunction;
import com.maltego.cloud.crypto.util.PKCS7Padding;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.List;
import java.util.function.BiConsumer;
import java.util.function.Function;
import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;

/**
 * Field-level AES-GCM encryption of selected values inside a Jackson {@link ObjectNode}.
 *
 * <p>Values are addressed by dot-separated key paths. The first segment of every path is dropped
 * before traversal (presumably it names the root object itself — TODO confirm with callers). Each
 * remaining intermediate segment must resolve to an object or to an array of objects. For an
 * array, the rest of the path is applied to every element.
 *
 * <p>Two kinds of paths are supported:
 * <ul>
 *   <li><b>Bulk</b>: every selected value is serialized to JSON and padded with
 *       {@code PKCS7Padding.add(..., 4)}. The padded values are concatenated and encrypted in a
 *       single AES/GCM operation under the caller's IV. Each value is replaced by the Base64 of
 *       its own slice of the ciphertext. The single 16-byte authentication tag is returned
 *       separately.</li>
 *   <li><b>Atomic</b>: every selected value is encrypted on its own under a fresh random 12-byte
 *       IV and replaced by Base64(IV || ciphertext || tag).</li>
 * </ul>
 *
 * <p>Both directions mutate the supplied {@code ObjectNode} in place. The class keeps no mutable
 * static state, so independent calls (each on its own {@code data}) may run concurrently.
 */
public final class JsonEncryptionUtil {
    /** JCA transformation used for both bulk and atomic encryption. */
    private static final String TRANSFORMATION = "AES/GCM/NoPadding";
    /** Length in bytes of the random IV prepended to every atomic value. */
    private static final int ATOMIC_IV_LENGTH = 12;
    /** GCM authentication tag length in bits, as passed to {@link GCMParameterSpec}. */
    private static final int GCM_TAG_BITS = 128;
    /** GCM tag length in bytes; the JCA GCM cipher appends a tag of this size on encryption. */
    private static final int GCM_TAG_LENGTH = GCM_TAG_BITS / 8;
    /** Second argument handed to {@code PKCS7Padding.add} (presumably the padding block size). */
    private static final int PADDING_BLOCK_SIZE = 4;
    /** Shared IV source; {@link SecureRandom} is safe for concurrent use. */
    private static final SecureRandom RANDOM = new SecureRandom();
    /** Shared mapper; only default-configured serialization and tree parsing are used. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    private JsonEncryptionUtil() {
    }

    /**
     * Splits each dotted key path and drops its first segment.
     *
     * <p>Every path is validated up front, so a malformed path fails before {@code data} is
     * modified.
     *
     * @throws IllegalArgumentException if a path has fewer than two segments
     */
    private static List<List<String>> parsePaths(List<String> keyPaths) {
        List<List<String>> parsed = new ArrayList<>(keyPaths.size());
        for (String path : keyPaths) {
            String[] segments = path.split("\\.");
            if (segments.length < 2) {
                throw new IllegalArgumentException("Invalid key path: " + path);
            }
            parsed.add(Arrays.asList(segments).subList(1, segments.length));
        }
        return parsed;
    }

    /**
     * Casts {@code node} to {@link ObjectNode}.
     *
     * @throws IllegalArgumentException if {@code node} is any other type
     */
    private static ObjectNode requireObject(JsonNode node) {
        if (!(node instanceof ObjectNode)) {
            throw new IllegalArgumentException("Invalid Node type " + node.getNodeType());
        }
        return (ObjectNode) node;
    }

    /**
     * Walks {@code segments} down from {@code data} and calls {@code handler} with every
     * (parent object, leaf key) pair the path resolves to.
     *
     * <p>An intermediate array fans out over its elements. A missing intermediate node ends that
     * branch silently. The leaf key itself does not have to exist.
     *
     * @throws IllegalArgumentException if an intermediate node or array element is not an object
     */
    private static void handleValues(ObjectNode data, List<String> segments, BiConsumer<ObjectNode, String> handler) {
        if (segments.size() == 1) {
            handler.accept(data, segments.get(0));
            return;
        }
        JsonNode head = data.get(segments.get(0));
        if (head == null) {
            // Absent branches are skipped the same way by the buffer pass and the write-back
            // pass, so the offset list stays aligned with the visited values.
            return;
        }
        List<String> tail = segments.subList(1, segments.size());
        if (head.isArray()) {
            for (JsonNode element : head) {
                handleValues(requireObject(element), tail, handler);
            }
        } else {
            handleValues(requireObject(head), tail, handler);
        }
    }

    /** Applies {@code handler} to every value addressed by {@code paths}, in path order. */
    private static void forEachValue(ObjectNode data, List<List<String>> paths, BiConsumer<ObjectNode, String> handler) {
        for (List<String> segments : paths) {
            handleValues(data, segments, handler);
        }
    }

    /**
     * Replaces each addressed value with the decoding of its slice of {@code buffer}.
     *
     * <p>Value {@code i} (in traversal order) receives
     * {@code buffer[offsets[i] .. offsets[i + 1])}.
     */
    private static void setValueFromBuffer(ObjectNode data, List<List<String>> paths, byte[] buffer, List<Integer> offsets, Function<byte[], JsonNode> byteToJsonNodeEncoder) {
        // Per-call cursor: the former static counter was shared between concurrent calls and
        // corrupted each other's slicing.
        int[] cursor = {0};
        forEachValue(data, paths, (node, key) -> {
            int from = offsets.get(cursor[0]);
            int to = offsets.get(cursor[0] + 1);
            cursor[0]++;
            node.replace(key, byteToJsonNodeEncoder.apply(Arrays.copyOfRange(buffer, from, to)));
        });
    }

    /**
     * Encodes every addressed value (a missing leaf is passed as {@code null}) and appends the
     * bytes to one buffer.
     *
     * @return the buffer plus the boundary list {@code [0, end0, end1, ...]}
     * @throws RuntimeException wrapping any exception raised by the encoder
     */
    private static Pair<ByteArrayOutputStream, List<Integer>> createValueBuffer(ObjectNode data, List<List<String>> paths, CheckedFunction<JsonNode, byte[]> valueToByteEncoder) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        List<Integer> offsets = new ArrayList<>();
        offsets.add(0);
        forEachValue(data, paths, (node, key) -> {
            try {
                buffer.write(valueToByteEncoder.apply(node.get(key)));
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
            offsets.add(buffer.size());
        });
        return new Pair<>(buffer, offsets);
    }

    /** Serializes {@code node} to JSON bytes and pads them with {@code PKCS7Padding}. */
    private static byte[] serializeAndPad(JsonNode node) throws Exception {
        return PKCS7Padding.add(MAPPER.writeValueAsBytes(node), PADDING_BLOCK_SIZE);
    }

    /** Removes the padding added by {@link #serializeAndPad} and parses the JSON. */
    private static JsonNode unpadAndParse(byte[] bytes) {
        try {
            return MAPPER.readTree(PKCS7Padding.remove(bytes));
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Wraps raw bytes as a Base64 text node. */
    private static JsonNode toBase64Node(byte[] bytes) {
        return new TextNode(Base64.getEncoder().encodeToString(bytes));
    }

    /**
     * Decodes a Base64 text node.
     *
     * @throws IllegalArgumentException if {@code node} is missing, not textual, or not valid Base64
     */
    private static byte[] decodeBase64(JsonNode node) {
        if (node == null || !node.isTextual()) {
            throw new IllegalArgumentException("Expected Base64 text but found "
                    + (node == null ? "no value" : node.getNodeType()));
        }
        return Base64.getDecoder().decode(node.textValue());
    }

    /** Encrypts one value under a fresh random IV and returns {@code IV || ciphertext || tag}. */
    private static byte[] atomicEncryptingJsonEncoder(SecretKey secretKey, JsonNode node, CheckedFunction<JsonNode, byte[]> valueToByteEncoder) throws Exception {
        byte[] ivBytes = new byte[ATOMIC_IV_LENGTH];
        RANDOM.nextBytes(ivBytes);
        Cipher atomicCipher = Cipher.getInstance(TRANSFORMATION);
        atomicCipher.init(Cipher.ENCRYPT_MODE, secretKey, new GCMParameterSpec(GCM_TAG_BITS, ivBytes));
        byte[] atomicBytes = atomicCipher.doFinal(valueToByteEncoder.apply(node));
        return concatArrays(ivBytes, atomicBytes);
    }

    /**
     * Encrypts the selected values of {@code data} in place.
     *
     * <p>Atomic paths are processed first, then bulk paths. The {@code assert} rejects only
     * identical path strings. Nested bulk/atomic paths are not checked, and they do not round-trip
     * through {@link #decryptJson}.
     *
     * @param data           object to encrypt; mutated in place
     * @param secretKey      key for both bulk and atomic encryption (used with AES/GCM)
     * @param ivBytes        IV for the single bulk GCM operation; GCM requires it never be reused
     *                       with the same key
     * @param bulkKeyPaths   dotted paths encrypted together under {@code ivBytes}
     * @param atomicKeyPaths dotted paths encrypted individually under random IVs
     * @return {@code data} together with the 16-byte bulk authentication tag
     * @throws GeneralSecurityException if the cipher cannot be created or initialized
     * @throws IllegalArgumentException if a key path is malformed or traverses a non-object node
     */
    public static Pair<ObjectNode, byte[]> encryptJson(ObjectNode data, SecretKey secretKey, byte[] ivBytes, List<String> bulkKeyPaths, List<String> atomicKeyPaths) throws GeneralSecurityException {
        assert (Collections.disjoint(bulkKeyPaths, atomicKeyPaths));
        List<List<String>> bulkPaths = parsePaths(bulkKeyPaths);
        List<List<String>> atomicPaths = parsePaths(atomicKeyPaths);
        Cipher cipher = Cipher.getInstance(TRANSFORMATION);
        cipher.init(Cipher.ENCRYPT_MODE, secretKey, new GCMParameterSpec(GCM_TAG_BITS, ivBytes));
        if (!atomicPaths.isEmpty()) {
            Pair<ByteArrayOutputStream, List<Integer>> atomic = createValueBuffer(data, atomicPaths,
                    node -> atomicEncryptingJsonEncoder(secretKey, node, JsonEncryptionUtil::serializeAndPad));
            setValueFromBuffer(data, atomicPaths, atomic.first().toByteArray(), atomic.second(), JsonEncryptionUtil::toBase64Node);
        }
        Pair<ByteArrayOutputStream, List<Integer>> bulk = createValueBuffer(data, bulkPaths, JsonEncryptionUtil::serializeAndPad);
        byte[] encryptedBuffer = cipher.doFinal(bulk.first().toByteArray());
        // GCM ciphertext has the plaintext's length, so the plaintext offsets slice it directly;
        // the trailing tag is split off and returned to the caller.
        int tagIdx = encryptedBuffer.length - GCM_TAG_LENGTH;
        byte[] encryptedBytes = Arrays.copyOf(encryptedBuffer, tagIdx);
        byte[] tag = Arrays.copyOfRange(encryptedBuffer, tagIdx, encryptedBuffer.length);
        setValueFromBuffer(data, bulkPaths, encryptedBytes, bulk.second(), JsonEncryptionUtil::toBase64Node);
        return new Pair<>(data, tag);
    }

    /**
     * Decrypts one atomic value of the form Base64(IV || ciphertext || tag).
     *
     * @throws GeneralSecurityException if the value is too short or fails authentication
     */
    private static byte[] atomicDecryptingJsonEncoder(SecretKey secretKey, JsonNode node) throws Exception {
        byte[] encryptedBytes = decodeBase64(node);
        if (encryptedBytes.length < ATOMIC_IV_LENGTH + GCM_TAG_LENGTH) {
            throw new GeneralSecurityException("Atomic value too short: " + encryptedBytes.length + " bytes");
        }
        Cipher atomicCipher = Cipher.getInstance(TRANSFORMATION);
        atomicCipher.init(Cipher.DECRYPT_MODE, secretKey, new GCMParameterSpec(GCM_TAG_BITS, encryptedBytes, 0, ATOMIC_IV_LENGTH));
        return atomicCipher.doFinal(encryptedBytes, ATOMIC_IV_LENGTH, encryptedBytes.length - ATOMIC_IV_LENGTH);
    }

    /**
     * Reverses {@link #encryptJson} in place.
     *
     * <p>Atomic values are decrypted and written back before the bulk tag is verified. If bulk
     * verification fails, {@code data} is therefore left partially decrypted.
     *
     * @param data           object produced by {@link #encryptJson}; mutated in place
     * @param bulkKeyPaths   same bulk paths used for encryption
     * @param atomicKeyPaths same atomic paths used for encryption
     * @param secretKey      key used for encryption
     * @param ivBytes        IV used for the bulk encryption
     * @param tag            bulk authentication tag returned by {@link #encryptJson}
     * @return {@code data}
     * @throws IOException              declared by the in-memory buffer write of {@code tag}
     * @throws GeneralSecurityException if the cipher cannot be set up or the bulk tag does not
     *                                  verify ({@code javax.crypto.AEADBadTagException})
     * @throws RuntimeException         wrapping failures while decrypting or parsing individual
     *                                  values
     */
    public static ObjectNode decryptJson(ObjectNode data, List<String> bulkKeyPaths, List<String> atomicKeyPaths, SecretKey secretKey, byte[] ivBytes, byte[] tag) throws IOException, GeneralSecurityException {
        List<List<String>> bulkPaths = parsePaths(bulkKeyPaths);
        List<List<String>> atomicPaths = parsePaths(atomicKeyPaths);
        Cipher cipher = Cipher.getInstance(TRANSFORMATION);
        cipher.init(Cipher.DECRYPT_MODE, secretKey, new GCMParameterSpec(GCM_TAG_BITS, ivBytes));
        if (!atomicPaths.isEmpty()) {
            Pair<ByteArrayOutputStream, List<Integer>> atomic = createValueBuffer(data, atomicPaths,
                    node -> atomicDecryptingJsonEncoder(secretKey, node));
            setValueFromBuffer(data, atomicPaths, atomic.first().toByteArray(), atomic.second(), JsonEncryptionUtil::unpadAndParse);
        }
        Pair<ByteArrayOutputStream, List<Integer>> bulk = createValueBuffer(data, bulkPaths, JsonEncryptionUtil::decodeBase64);
        ByteArrayOutputStream buffer = bulk.first();
        // JCA GCM decryption expects the tag appended to the ciphertext.
        buffer.write(tag);
        byte[] decryptedBytes = cipher.doFinal(buffer.toByteArray());
        setValueFromBuffer(data, bulkPaths, decryptedBytes, bulk.second(), JsonEncryptionUtil::unpadAndParse);
        return data;
    }

    /**
     * Concatenates {@code arrays}, in argument order, into a newly allocated array.
     *
     * @throws ArithmeticException if the combined length overflows an {@code int}
     */
    public static byte[] concatArrays(byte[] ... arrays) {
        int totalLength = 0;
        for (byte[] array : arrays) {
            totalLength = Math.addExact(totalLength, array.length);
        }
        byte[] result = new byte[totalLength];
        int position = 0;
        for (byte[] array : arrays) {
            System.arraycopy(array, 0, result, position, array.length);
            position += array.length;
        }
        return result;
    }
}

