/*
 * Decompiled with CFR 0.152.
 */
package com.intellij.rt.debugger;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

/**
 * Enumerates virtual threads by walking the JDK-internal thread-container tree
 * ({@code jdk.internal.vm.ThreadContainers.root()} and {@code ThreadContainer.children()/threads()}).
 * <p>
 * Every internal member is reached through a {@link MethodHandle} resolved from a caller-supplied
 * {@link MethodHandles.Lookup}, so no JDK-internal (or post-Java-7) type is referenced statically.
 * If resolution fails, the public entry point reports "unsupported" by returning {@code null}.
 */
public final class VirtualThreadDumper {
    // True once init() has run, whether or not it succeeded. Volatile so that a thread reading true
    // also sees `successfully` and the handle fields, which are all written before it.
    // Concurrent first calls may each run the resolution; the last writer wins.
    static volatile boolean initialized = false;
    // Outcome of the resolution in init(); only meaningful once `initialized` is true.
    static boolean successfully = false;
    // Stream.iterator()
    static MethodHandle streamIteratorHandle;
    // static ThreadContainers.root()
    static MethodHandle containersRootHandle;
    // ThreadContainer.children() -> Stream
    static MethodHandle containerChildrenHandle;
    // ThreadContainer.threads() -> Stream
    static MethodHandle containerThreadsHandle;
    // Thread.isVirtual()
    static MethodHandle threadIsVirtualHandle;
    // Thread.threadState()
    static MethodHandle threadThreadState;

    /**
     * Resolves all method handles on the first call and caches the outcome. Later calls return the
     * cached result and ignore their {@code lookup} argument.
     *
     * @param lookup lookup used to resolve the handles; it needs access to the members listed on the fields
     * @return {@code true} if every handle was resolved
     */
    private static boolean init(MethodHandles.Lookup lookup) {
        if (!initialized) {
            try {
                // Loaded by name so the class does not depend on java.util.stream at compile time.
                Class<?> streamClass = Class.forName("java.util.stream.Stream");
                streamIteratorHandle = lookup.findVirtual(streamClass, "iterator", MethodType.methodType(Iterator.class));
                Class<?> threadContainersClass = Class.forName("jdk.internal.vm.ThreadContainers");
                Class<?> threadContainerClass = Class.forName("jdk.internal.vm.ThreadContainer");
                containersRootHandle = lookup.findStatic(threadContainersClass, "root", MethodType.methodType(threadContainerClass));
                containerChildrenHandle = lookup.findVirtual(threadContainerClass, "children", MethodType.methodType(streamClass));
                containerThreadsHandle = lookup.findVirtual(threadContainerClass, "threads", MethodType.methodType(streamClass));
                threadIsVirtualHandle = lookup.findVirtual(Thread.class, "isVirtual", MethodType.methodType(Boolean.TYPE));
                threadThreadState = lookup.findVirtual(Thread.class, "threadState", MethodType.methodType(Thread.State.class));
                successfully = true;
            }
            catch (ClassNotFoundException | IllegalAccessException | NoSuchMethodException e) {
                // Expected when the running JDK lacks these internals or `lookup` cannot access them:
                // the feature is reported as unavailable instead of failing the caller.
                successfully = false;
            }
            initialized = true;
        }
        return successfully;
    }

    /**
     * Collects all virtual threads and groups them by their textual dump (name, state, stack frames).
     *
     * @param lookup lookup used to resolve the internal handles on the first call (see {@code init})
     * @return {@code null} if the handles could not be resolved or no virtual thread was found.
     *         Otherwise returns a two-element array {@code {Object[] groups, long[] tids}}:
     *         <ul>
     *           <li>{@code groups} is a flat sequence of runs {@code text, thread, thread, ..., null},
     *               one run per distinct text.</li>
     *           <li>{@code tids} holds {@code Thread.getId()} of every thread, in the order the threads
     *               appear in {@code groups}.</li>
     *         </ul>
     * @throws Throwable anything thrown by the underlying method handles
     */
    public static Object[] getAllVirtualThreadsWithStackTraces(MethodHandles.Lookup lookup) throws Throwable {
        if (!init(lookup)) {
            return null;
        }
        ArrayList<Thread> threads = getAllVirtualThreads(lookup);
        if (threads.isEmpty()) {
            return null;
        }
        // The key includes the thread name, so only threads with equal name, state and frames share a group.
        Map<String, ArrayList<Thread>> groupedByStackTrace = new HashMap<>();
        for (Thread t : threads) {
            String stackTrace = describe(t);
            ArrayList<Thread> similarThreads = groupedByStackTrace.get(stackTrace);
            if (similarThreads == null) {
                similarThreads = new ArrayList<>();
                groupedByStackTrace.put(stackTrace, similarThreads);
            }
            similarThreads.add(t);
        }

        long[] tids = new long[threads.size()];
        int tidIdx = 0;
        // Each group contributes its text, its threads and one null terminator.
        Object[] allStackTraceAndThreads = new Object[threads.size() + groupedByStackTrace.size() * 2];
        int stIdx = 0;
        for (Map.Entry<String, ArrayList<Thread>> e : groupedByStackTrace.entrySet()) {
            allStackTraceAndThreads[stIdx++] = e.getKey();
            for (Thread t : e.getValue()) {
                allStackTraceAndThreads[stIdx++] = t;
                tids[tidIdx++] = t.getId();
            }
            allStackTraceAndThreads[stIdx++] = null;
        }
        assert stIdx == allStackTraceAndThreads.length;
        return new Object[]{allStackTraceAndThreads, tids};
    }

    /** Renders a thread as {@code name \n state} followed by one {@code \n\tat frame} line per stack element. */
    private static String describe(Thread t) throws Throwable {
        StringBuilder buffer = new StringBuilder();
        String name = t.getName();
        // The cast fixes the call-site return type of the signature-polymorphic invoke.
        Thread.State javaThreadState = (Thread.State) threadThreadState.invoke(t);
        buffer.append(name).append('\n').append(javaThreadState);
        for (StackTraceElement ste : t.getStackTrace()) {
            buffer.append("\n\tat ").append(ste);
        }
        return buffer.toString();
    }

    /** Returns every virtual thread found in any container; empty if the handles are unavailable. */
    private static ArrayList<Thread> getAllVirtualThreads(MethodHandles.Lookup lookup) throws Throwable {
        ArrayList<Thread> result = new ArrayList<>();
        if (!init(lookup)) {
            return result;
        }
        for (Object container : getAllContainers()) {
            Object threads = containerThreadsHandle.invoke(container);
            Iterator<?> it = (Iterator<?>) streamIteratorHandle.invoke(threads);
            while (it.hasNext()) {
                Thread t = (Thread) it.next();
                // Containers also report platform threads; keep only virtual ones.
                if ((boolean) threadIsVirtualHandle.invoke(t)) {
                    result.add(t);
                }
            }
        }
        return result;
    }

    /** Returns the root container followed by all of its descendants, in depth-first pre-order. */
    private static ArrayList<Object> getAllContainers() throws Throwable {
        ArrayList<Object> allContainers = new ArrayList<>();
        Object rootContainer = containersRootHandle.invoke();
        collectContainers(allContainers, rootContainer);
        return allContainers;
    }

    /** Appends {@code container}, then recursively each of its children, to {@code allContainers}. */
    private static void collectContainers(ArrayList<Object> allContainers, Object container) throws Throwable {
        allContainers.add(container);
        Object children = containerChildrenHandle.invoke(container);
        Iterator<?> it = (Iterator<?>) streamIteratorHandle.invoke(children);
        while (it.hasNext()) {
            collectContainers(allContainers, it.next());
        }
    }
}

