package name.remal.gradle_plugins.plugins.classes_relocation

import edu.umd.cs.findbugs.annotations.SuppressFBWarnings
import name.remal.*
import name.remal.gradle_plugins.api.AutoService
import name.remal.gradle_plugins.api.BuildTimeConstants.getClassDescriptor
import name.remal.gradle_plugins.api.RelocateClasses
import name.remal.gradle_plugins.api.RelocatePackages
import name.remal.gradle_plugins.api.classes_processing.BytecodeModifier
import name.remal.gradle_plugins.api.classes_processing.ClassesProcessor
import name.remal.gradle_plugins.api.classes_processing.ClassesProcessor.RELOCATION_STAGE
import name.remal.gradle_plugins.api.classes_processing.ClassesProcessorsGradleTaskFactory
import name.remal.gradle_plugins.api.classes_processing.ProcessContext
import name.remal.gradle_plugins.dsl.Plugin
import name.remal.gradle_plugins.dsl.artifact.ArtifactEntryNotFoundException
import name.remal.gradle_plugins.dsl.artifact.CachedArtifactsCollection
import name.remal.gradle_plugins.dsl.extensions.isPluginApplied
import name.remal.gradle_plugins.dsl.extensions.toHasEntries
import name.remal.gradle_plugins.dsl.internal.Generated
import name.remal.gradle_plugins.dsl.internal.RelocatedClass
import name.remal.gradle_plugins.dsl.utils.getGradleLogger
import org.gradle.api.artifacts.Configuration
import org.gradle.api.tasks.compile.AbstractCompile
import org.objectweb.asm.*
import org.objectweb.asm.ClassReader.*
import org.objectweb.asm.Type.getDescriptor
import org.objectweb.asm.commons.ClassRemapper
import org.objectweb.asm.commons.Method
import org.objectweb.asm.commons.MethodRemapper
import org.objectweb.asm.commons.Remapper
import org.objectweb.asm.tree.AnnotationNode
import java.util.concurrent.ForkJoinPool
import java.util.concurrent.ForkJoinTask
import java.util.concurrent.RecursiveAction
import java.util.stream.Stream
import kotlin.text.Charsets.UTF_8
import kotlin.text.isNotEmpty
import kotlin.text.trim

/** JVM type descriptor of [RelocateClasses]; compared against annotation descriptors while reading bytecode. */
private val relocateClassesDesc: String = getDescriptor(RelocateClasses::class.java)

/** JVM type descriptor of [RelocatePackages]; compared against annotation descriptors while reading bytecode. */
private val relocatePackagesDesc: String = getDescriptor(RelocatePackages::class.java)

/**
 * [ClassesProcessor] that relocates classes into the package `relocatedClassesPackageName`.
 *
 * For every processed class it:
 *  1. reads `@RelocateClasses` / `@RelocatePackages` from the class, from its methods and (for nested classes) from its
 *     enclosing classes, to find classes that must be force-relocated;
 *  2. rewrites references to those classes, and to classes coming from [relocateClassesConf], to the relocated names;
 *  3. copies every referenced class from [classpathArtifacts] into the output under its relocated name. This also
 *     covers the classes those copies reference in turn, and relocated service entries.
 *
 * Each class is copied at most once per processor instance (tracked in [alreadyRelocatedInternalClassNames]).
 */
class RelocateClassesClassesProcessor(
    relocatedClassesPackageName: String,
    private val classpathArtifacts: CachedArtifactsCollection,
    private val relocateClassesConf: Configuration,
    private val excludeFromClassesRelocationConf: Configuration,
    private val excludeFromForcedClassesRelocationConf: Configuration,
    private val isRelocateClassesAnnotationInClasspath: Boolean
) : ClassesProcessor {

    companion object {
        private val logger = getGradleLogger(RelocateClassesClassesProcessor::class.java)
        private val relocatedClassDesc = getDescriptor(RelocatedClass::class.java)
        private val generatedDesc = getDescriptor(Generated::class.java)
        private val suppressFBWarningsDesc = getClassDescriptor(SuppressFBWarnings::class.java)

        // Matches the last `$...` segment of an internal class name, i.e. the nested-class suffix.
        private val outerClassInternalNameRegex = Regex("\\\$[^\$/]*\$")
    }

    private val relocatedClassNamePrefix = relocatedClassesPackageName + '.'
    private val relocatedInternalClassNamePrefix = relocatedClassNamePrefix.replace('.', '/')

    /** Internal names of classes that must never be force-relocated by `@RelocateClasses` / `@RelocatePackages`. */
    private val internalClassNamesExcludedFromForcedRelocation: Set<String> by lazy {
        CachedArtifactsCollection(excludeFromForcedClassesRelocationConf).classNames.mapTo(mutableSetOf(), { it.replace('.', '/') })
    }

    /** Internal names of classpath classes that annotations may request to relocate. */
    private val possibleInternalClassNamesForForcedRelocation: Set<String> by lazy {
        classpathArtifacts.classNames.stream()
            .map { it.replace('.', '/') }
            .filter { it !in internalClassNamesExcludedFromForcedRelocation }
            .toSet()
    }

    /** Internal names of classes that must never be relocated through the `relocateClasses` configuration. */
    private val internalClassNamesExcludedFromRelocation: Set<String> by lazy {
        CachedArtifactsCollection(excludeFromClassesRelocationConf).classNames.mapTo(mutableSetOf(), { it.replace('.', '/') })
    }

    /** Internal names of classes from the `relocateClasses` configuration; references to them are always relocated. */
    private val possibleInternalClassNamesForRelocation: Set<String> by lazy {
        CachedArtifactsCollection(relocateClassesConf).classNames.stream()
            .map { it.replace('.', '/') }
            .filter { it !in internalClassNamesExcludedFromRelocation }
            .toSet()
    }


    /** Internal names of classes already scheduled for copying; shared by all (possibly parallel) relocation tasks. */
    private val alreadyRelocatedInternalClassNames = concurrentSetOf<String>()


    override fun process(bytecode: ByteArray, bytecodeModifier: BytecodeModifier, className: String, resourceName: String, context: ProcessContext) {
        var isChanged = false

        // Referenced classes to copy: "forced" ones are requested by annotations,
        // the others come from the relocateClasses configuration.
        val internalClassNamesForForcedRelocation = mutableSetOf<String>()
        val internalClassNamesForRelocation = mutableSetOf<String>()

        val relocationInfo = readRelocationInfo(bytecode, context)
        val rootRemapper = object : RelocationRemapper() {
            override fun map(internalClassName: String): String? {
                if (internalClassName in relocationInfo.internalClassNamesForRelocation) {
                    isChanged = true
                    internalClassNamesForForcedRelocation.add(internalClassName)
                    return relocatedInternalClassNamePrefix + internalClassName
                }
                if (internalClassName in possibleInternalClassNamesForRelocation) {
                    isChanged = true
                    internalClassNamesForRelocation.add(internalClassName)
                    return relocatedInternalClassNamePrefix + internalClassName
                }
                return null
            }
        }

        val classReader = ClassReader(bytecode)
        val classWriter = ClassWriter(classReader, 0)
        val classRemapper = object : ClassRemapper(classWriter, rootRemapper) {
            override fun visitMethod(access: Int, name: String, desc: String, signature: String?, exceptions: Array<String>?): MethodVisitor? {
                // Methods annotated with @RelocateClasses/@RelocatePackages get a remapper that additionally relocates
                // the classes requested by those annotations, falling back to the class-level remapper otherwise.
                val methodRemapper = relocationInfo.methodsInternalClassNamesForRelocation[Method(name, desc)]
                    ?.let { methodInternalClassNamesForRelocation ->
                        object : RelocationRemapper() {
                            override fun map(internalClassName: String): String? {
                                if (internalClassName in methodInternalClassNamesForRelocation) {
                                    isChanged = true
                                    internalClassNamesForForcedRelocation.add(internalClassName)
                                    return relocatedInternalClassNamePrefix + internalClassName
                                }
                                return remapper.map(internalClassName)
                            }
                        }
                    }
                    ?: remapper

                // Delegate straight to the next visitor: ClassRemapper.visitMethod() (super) would remap the already
                // remapped name/descriptor/signature again and wrap the result in a second MethodRemapper.
                // `this.className` is the internal name of the visited class (the `className` parameter of process()
                // would shadow it otherwise).
                val mv = cv?.visitMethod(
                    access,
                    methodRemapper.mapMethodName(this.className, name, desc),
                    methodRemapper.mapMethodDesc(desc),
                    methodRemapper.mapSignature(signature, false),
                    exceptions?.let(methodRemapper::mapTypes)
                )
                return mv?.let { MethodRemapper(it, methodRemapper) }
            }
        }

        classReader.accept(classRemapper)

        if (isChanged) {
            bytecodeModifier.modify(classWriter.toByteArray())
        }

        relocateReferencedClasses(internalClassNamesForForcedRelocation, context, true)
        relocateReferencedClasses(internalClassNamesForRelocation, context, false)
    }

    /**
     * Copies every class of [internalClassNames] that has not been scheduled yet, running the copies in the common
     * fork-join pool and blocking until all of them have finished.
     */
    private fun relocateReferencedClasses(internalClassNames: Set<String>, context: ProcessContext, force: Boolean) {
        internalClassNames
            .filter { alreadyRelocatedInternalClassNames.add(it) }
            .map { ForkJoinPool.commonPool().submit(RelocateRecursiveAction(it, classpathArtifacts, context, force)) }
            .forEach { it.get() }
    }


    /**
     * Classes requested for forced relocation by annotations.
     *
     * @property internalClassNamesForRelocation classes requested by class-level annotations (including those of
     * enclosing classes)
     * @property methodsInternalClassNamesForRelocation classes requested by annotations of a specific method
     */
    private data class RelocationInfo(
        val internalClassNamesForRelocation: Set<String>,
        val methodsInternalClassNamesForRelocation: Map<Method, Set<String>> = emptyMap()
    ) {
        companion object {
            val EMPTY = RelocationInfo(emptySet(), emptyMap())
        }
    }

    /**
     * Reads `@RelocateClasses` / `@RelocatePackages` annotations of the class in [bytecode].
     *
     * Class-level annotations of the class and of its enclosing classes are merged into
     * [RelocationInfo.internalClassNamesForRelocation]. The enclosing class is found by stripping the last `$...`
     * segment of the internal name and reading it through [ProcessContext.readBinaryResource]. Method-level annotations
     * are kept per method. Returns [RelocationInfo.EMPTY] when the annotations are not on the compile classpath.
     */
    private fun readRelocationInfo(bytecode: ByteArray, context: ProcessContext): RelocationInfo {
        if (!isRelocateClassesAnnotationInClasspath) return RelocationInfo.EMPTY

        var visitedClassInternalName: String? = null
        val classAnnotations = mutableListOf<AnnotationNode>()
        val methodAnnotations = mutableMapOf<Method, MutableList<AnnotationNode>>()
        ClassReader(bytecode).accept(
            object : ClassVisitor(ASM_API) {
                override fun visit(version: Int, access: Int, name: String?, signature: String?, superName: String?, interfaces: Array<String>?) {
                    visitedClassInternalName = name
                    super.visit(version, access, name, signature, superName, interfaces)
                }

                override fun visitAnnotation(desc: String?, visible: Boolean): AnnotationVisitor? {
                    if (relocateClassesDesc == desc || relocatePackagesDesc == desc) {
                        return AnnotationNode(desc).also { classAnnotations.add(it) }
                    }
                    return null
                }

                override fun visitMethod(access: Int, name: String, desc: String, signature: String?, exceptions: Array<out String>?): MethodVisitor {
                    return object : MethodVisitor(api) {
                        override fun visitAnnotation(annotationDesc: String?, annotationVisible: Boolean): AnnotationVisitor? {
                            // Compare the *annotation* descriptor; `desc` here is the method descriptor.
                            if (relocateClassesDesc == annotationDesc || relocatePackagesDesc == annotationDesc) {
                                return AnnotationNode(annotationDesc).also {
                                    methodAnnotations.getOrPut(Method(name, desc)) { mutableListOf() }.add(it)
                                }
                            }
                            return null
                        }
                    }
                }
            },
            SKIP_DEBUG or SKIP_FRAMES or SKIP_CODE
        )
        val classInternalName = visitedClassInternalName

        val outerClassRelocationInfo = readOuterClassRelocationInfo(classInternalName, context)

        val classLevelInternalClassNames: Set<String> = classAnnotations
            .flatMapTo(mutableSetOf()) { getClassInternalNamesFromRelocateAnnotation(classInternalName, it) }
            .plus(outerClassRelocationInfo.internalClassNamesForRelocation)
        val methodLevelInternalClassNames: Map<Method, Set<String>> = methodAnnotations.mapValues { (_, annotations) ->
            annotations.flatMapTo(mutableSetOf()) { getClassInternalNamesFromRelocateAnnotation(classInternalName, it) }
        }
        return RelocationInfo(classLevelInternalClassNames, methodLevelInternalClassNames)
    }

    /**
     * Relocation info of the directly enclosing class of [classInternalName]. Returns [RelocationInfo.EMPTY] for a
     * top-level class, an unknown name, or an enclosing class that cannot be read from [context].
     */
    private fun readOuterClassRelocationInfo(classInternalName: String?, context: ProcessContext): RelocationInfo {
        if (classInternalName == null) return RelocationInfo.EMPTY
        val outerClassInternalName = outerClassInternalNameRegex.replace(classInternalName, "")
        if (outerClassInternalName == classInternalName) return RelocationInfo.EMPTY
        val outerClassBytecode = context.readBinaryResource(outerClassInternalName + CLASS_FILE_NAME_SUFFIX)
            ?: return RelocationInfo.EMPTY
        return readRelocationInfo(outerClassBytecode, context)
    }

    /**
     * Resolves one relocation annotation to the internal names of the classes it requests.
     *
     * `@RelocateClasses` yields its listed classes, minus those excluded from forced relocation (each exclusion is
     * logged as an error). `@RelocatePackages` yields every classpath class whose internal name starts with one of the
     * listed packages, given by name or via `basePackageClasses`.
     *
     * @param classInternalName annotated class, used only in log messages
     * @throws UnsupportedOperationException for any other annotation
     */
    private fun getClassInternalNamesFromRelocateAnnotation(classInternalName: String?, annotationNode: AnnotationNode): Set<String> {
        return when (annotationNode.desc) {
            relocateClassesDesc -> annotationNode[RelocateClasses::value]?.stream()
                ?.map { it.internalName }
                ?.filter { internalName ->
                    val isExcluded = internalName in internalClassNamesExcludedFromForcedRelocation
                    if (isExcluded) {
                        logger.error("${classInternalName?.replace('/', '.')}: $internalName is excluded from forced relocation")
                    }
                    !isExcluded
                }
                ?.toSet()
                ?: emptySet()

            relocatePackagesDesc -> Stream.concat(
                annotationNode[RelocatePackages::value]?.stream()
                    ?.map { it.replace('.', '/') }
                    ?: emptyStream(),
                annotationNode[RelocatePackages::basePackageClasses]?.stream()
                    ?.map { it.internalName.substringBeforeLast('/', "") }
                    ?: emptyStream()
            )
                .filter(String::isNotEmpty)
                .distinct()
                .flatMap { packageInternalName ->
                    val prefix = "$packageInternalName/"
                    possibleInternalClassNamesForForcedRelocation.stream()
                        .filter { it.startsWith(prefix) }
                }
                .toSet()

            else -> throw UnsupportedOperationException("Unsupported annotation: ${annotationNode.desc}")
        }
    }


    /**
     * Copies [rootInternalClassName] from [classpathArtifacts] into the output under the relocated package. The copy
     * has its references to relocatable classes rewritten and marker annotations added. The action also relocates
     * the service file named after the class, then recursively copies every newly referenced relocatable class.
     *
     * @param force when `true`, classes from [possibleInternalClassNamesForForcedRelocation] count as relocatable too,
     * not only those from the `relocateClasses` configuration
     */
    private inner class RelocateRecursiveAction(
        private val rootInternalClassName: String,
        private val classpathArtifacts: CachedArtifactsCollection,
        private val context: ProcessContext,
        private val force: Boolean
    ) : RecursiveAction() {

        /** Whether a reference to [internalClassName] from a relocated class must be relocated as well. */
        private fun isRelocatable(internalClassName: String): Boolean =
            internalClassName in possibleInternalClassNamesForRelocation ||
                (force && internalClassName in possibleInternalClassNamesForForcedRelocation)

        override fun compute() {
            val bytecode = try {
                classpathArtifacts.readBytecode(rootInternalClassName.replace('/', '.'))
            } catch (ignored: ArtifactEntryNotFoundException) {
                // The class is not on the classpath, so there is nothing to copy.
                return
            }

            logger.debug(
                "Relocating class {} from artifact {}",
                rootInternalClassName,
                classpathArtifacts.entryMapping[rootInternalClassName + CLASS_FILE_NAME_SUFFIX]
            )

            val internalClassNamesForRelocation = mutableSetOf<String>()
            val resultInternalClassName = relocatedInternalClassNamePrefix + rootInternalClassName
            context.writeBinaryResource(
                resultInternalClassName + CLASS_FILE_NAME_SUFFIX,
                relocateBytecode(bytecode, internalClassNamesForRelocation)
            )

            relocateServiceFiles(resultInternalClassName, internalClassNamesForRelocation)

            val newInternalClassNames = internalClassNamesForRelocation.filter { alreadyRelocatedInternalClassNames.add(it) }
            if (newInternalClassNames.isNotEmpty()) {
                ForkJoinTask.invokeAll(newInternalClassNames.map {
                    RelocateRecursiveAction(it, classpathArtifacts, context, force)
                })
            }
        }

        /**
         * Returns [bytecode] with relocatable references rewritten. The class is marked with `@RelocatedClass`,
         * `@Generated` and `@SuppressFBWarnings`, each added with `visible = false`. Every relocated referenced class
         * is added to [collectedInternalClassNames].
         */
        private fun relocateBytecode(bytecode: ByteArray, collectedInternalClassNames: MutableSet<String>): ByteArray {
            val classReader = ClassReader(bytecode)
            val classWriter = ClassWriter(classReader, 0)

            val relocatedClassAnnotationAdder = object : ClassVisitor(ASM_API, classWriter) {
                override fun visit(version: Int, access: Int, name: String?, signature: String?, superName: String?, interfaces: Array<String>?) {
                    super.visit(version, access, name, signature, superName, interfaces)
                    super.visitAnnotation(relocatedClassDesc, false)?.visitEnd()
                    super.visitAnnotation(generatedDesc, false)?.visitEnd()
                    super.visitAnnotation(suppressFBWarningsDesc, false)?.visitEnd()
                }
            }

            val classRemapper = ClassRemapper(relocatedClassAnnotationAdder, object : RelocationRemapper() {
                override fun map(internalClassName: String): String? {
                    if (!isRelocatable(internalClassName)) return null
                    collectedInternalClassNames.add(internalClassName)
                    return relocatedInternalClassNamePrefix + internalClassName
                }
            })

            classReader.accept(classRemapper)
            return classWriter.toByteArray()
        }

        /**
         * Looks for the service file `$SERVICE_FILE_BASE_PATH/<root class name>` in every classpath artifact. Each
         * listed implementation is registered under the relocated service name; implementations that are themselves
         * relocatable are registered under their relocated name and added to [collectedInternalClassNames].
         */
        private fun relocateServiceFiles(resultInternalClassName: String, collectedInternalClassNames: MutableSet<String>) {
            val rootClassName = rootInternalClassName.replace('/', '.')
            val serviceResourceName = "$SERVICE_FILE_BASE_PATH/$rootClassName"
            val resultClassName = resultInternalClassName.replace('/', '.')

            classpathArtifacts.artifacts.forEach { artifact ->
                if (artifact.contains(serviceResourceName)) {
                    logger.debug(
                        "Relocating service file {} from artifact {}",
                        serviceResourceName,
                        artifact
                    )

                    // One implementation class name per line; '#' starts a comment.
                    val content = artifact.readBytes(serviceResourceName).toString(UTF_8)
                    val serviceClassNames = content.splitToSequence('\n')
                        .map { it.substringBefore('#') }
                        .map(String::trim)
                        .filter(String::isNotEmpty)
                        .toSet()

                    serviceClassNames.forEach { serviceClassName ->
                        val serviceInternalClassName = serviceClassName.replace('.', '/')
                        if (isRelocatable(serviceInternalClassName)) {
                            collectedInternalClassNames.add(serviceInternalClassName)
                            context.writeService(resultClassName, relocatedClassNamePrefix + serviceClassName)
                        } else {
                            context.writeService(resultClassName, serviceClassName)
                        }
                    }
                }
            }
        }
    }


    override fun getStage() = RELOCATION_STAGE

}

/**
 * [Remapper] base for class relocation. Subclasses override [map] to return the relocated internal name of a class,
 * or `null` to leave it unchanged.
 *
 * Besides regular type remapping, string constants naming a relocatable class are rewritten. The result keeps the
 * naming form of the original constant:
 *  * a dotted name (`com.example.Foo`) becomes a dotted relocated name;
 *  * an internal name (`com/example/Foo`) becomes an internal relocated name.
 */
@KotlinAllOpen
class RelocationRemapper : Remapper() {

    override fun mapValue(value: Any?): Any? {
        if (value is String) {
            // Only a value without '/' can be a dotted class name: look it up by its internal form and convert the
            // result back to dotted form. Values containing '/' are internal names and must keep that form, so they
            // are handled by the direct lookup below.
            if ('/' !in value) {
                val internalName = value.replace('.', '/')
                val mappedInternalName = map(internalName)
                if (mappedInternalName != null && mappedInternalName != internalName) {
                    return mappedInternalName.replace('/', '.')
                }
            }
            val mappedValue = map(value)
            if (mappedValue != null && mappedValue != value) return mappedValue
        }
        return super.mapValue(value)
    }
}


/**
 * Provides a [RelocateClassesClassesProcessor] for compile tasks of projects that apply [ClassesRelocationPlugin].
 * Nothing is created when the `relocateClasses` configuration has no dependencies and the relocation annotations
 * are absent from the compile classpath.
 */
@AutoService
class RelocateClassesClassesProcessorFactory : ClassesProcessorsGradleTaskFactory {

    override fun createClassesProcessors(compileTask: AbstractCompile): List<ClassesProcessor> {
        val project = compileTask.project
        if (!project.isPluginApplied(ClassesRelocationPlugin::class.java)) {
            return emptyList()
        }

        val configurations = project.configurations
        val isRelocateClassesAnnotationInClasspath = compileTask.classpath.toHasEntries().containsClass(RelocateClasses::class.java)
        val hasNothingToRelocate = configurations.relocateClasses.allDependencies.isEmpty() && !isRelocateClassesAnnotationInClasspath
        if (hasNothingToRelocate) {
            return emptyList()
        }

        val processor = RelocateClassesClassesProcessor(
            relocatedClassesPackageName = project.relocatedClassesJavaPackageName,
            classpathArtifacts = CachedArtifactsCollection(compileTask.classpath),
            relocateClassesConf = configurations.relocateClasses,
            excludeFromClassesRelocationConf = configurations.excludeFromClassesRelocation,
            excludeFromForcedClassesRelocationConf = configurations.excludeFromForcedClassesRelocation,
            isRelocateClassesAnnotationInClasspath = isRelocateClassesAnnotationInClasspath
        )
        return listOf(processor)
    }

}


/**
 * Fallback used when [ClassesRelocationPlugin] is NOT applied but `RelocateClasses` is on the compile classpath.
 * The returned processor never modifies bytecode. It only logs an error for every class that carries
 * `@RelocateClasses` / `@RelocatePackages` on the class itself or on any of its methods, because nothing would
 * process those annotations.
 */
@AutoService
class DisabledRelocateClassesClassesProcessorFactory : ClassesProcessorsGradleTaskFactory {

    companion object {
        @JvmStatic
        private val logger = getGradleLogger(DisabledRelocateClassesClassesProcessorFactory::class.java)
    }

    override fun createClassesProcessors(compileTask: AbstractCompile): List<ClassesProcessor> {
        val isPluginApplied = compileTask.project.isPluginApplied(ClassesRelocationPlugin::class.java)
        if (isPluginApplied || !compileTask.classpath.toHasEntries().containsClass(RelocateClasses::class.java)) {
            return emptyList()
        }

        val processor = ClassesProcessor { bytecode, _, className, _, _ ->
            var hasRelocateAnnotation = false

            // Remembers whether any visited annotation is one of the relocation annotations.
            fun noteAnnotation(annotationDesc: String?) {
                if (annotationDesc == relocateClassesDesc || annotationDesc == relocatePackagesDesc) {
                    hasRelocateAnnotation = true
                }
            }

            val annotationScanner = object : ClassVisitor(ASM_API) {
                override fun visitAnnotation(desc: String?, visible: Boolean): AnnotationVisitor? {
                    noteAnnotation(desc)
                    return null
                }

                override fun visitMethod(access: Int, name: String?, desc: String?, signature: String?, exceptions: Array<out String>?): MethodVisitor =
                    object : MethodVisitor(api) {
                        override fun visitAnnotation(desc: String?, visible: Boolean): AnnotationVisitor? {
                            noteAnnotation(desc)
                            return null
                        }
                    }
            }
            ClassReader(bytecode).accept(annotationScanner, SKIP_DEBUG or SKIP_FRAMES or SKIP_CODE)

            if (hasRelocateAnnotation) {
                val pluginId = ClassesRelocationPlugin::class.java.getAnnotation(Plugin::class.java)?.id
                logger.error(
                    "{}: Class/constructor/method is annotated by {}/{}. Apply '{}' Gradle plugin to process these annotations.",
                    className,
                    RelocateClasses::class.java.name,
                    RelocatePackages::class.java.name,
                    pluginId
                )
            }
        }
        return listOf(processor)
    }

}
