/**
 * OW2 Util
 * Copyright (C) 2009 Bull S.A.S.
 * Contact: easybeans@ow2.org
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with this library; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA
 *
 * --------------------------------------------------------------------------
 * $Id: SubClassServiceGenerator.java 5360 2010-02-24 14:29:45Z benoitf $
 * --------------------------------------------------------------------------
 */

package org.ow2.util.ee.builder.webserviceref.factory;

import java.lang.reflect.InvocationTargetException;

import javax.xml.ws.Service;

import org.ow2.util.asm.ClassWriter;
import org.ow2.util.asm.FieldVisitor;
import org.ow2.util.asm.MethodVisitor;
import org.ow2.util.asm.Opcodes;
import org.ow2.util.asm.Type;


/**
 * Generates a subclass of a JAX-WS {@link Service} class that intercepts the
 * getPort(...) methods: every port obtained from the super class is given to an
 * {@link IPortProcessor} before being returned to the caller.
 * @author Florent Benoit
 */
public class SubClassServiceGenerator implements Opcodes {

    /**
     * Version used for generated class.
     */
    private static final int GENERATED_CLASS_VERSION = V1_5;

    /**
     * Constant 3.
     */
    private static final int THREE = 3;

    /**
     * Prefix prepended to the simple name of the subclassed Service to build the generated class name.
     */
    private static final String GENERATED_CLASS_PREFIX = "OW2$ServiceRef$Generated$";

    /**
     * Interfaces implemented by the generated class (none).
     */
    private static final String[] INTERFACES = new String[] {};

    /**
     * Descriptor of the IPortProcessor interface: L......./IPortProcessor;
     */
    private static final String PORTPROCESSOR_ITF_DESCRIPTOR = Type.getType(IPortProcessor.class).getDescriptor();

    /**
     * Internal name of the IPortProcessor interface: the/package/......./IPortProcessor
     */
    private static final String PORTPROCESSOR_ITF_INTERNALNAME = Type.getType(IPortProcessor.class).getInternalName();

    /**
     * Name of the generated field holding the port processor.
     */
    private static final String PORTPROCESSOR_FIELD = "portProcessor";

    /**
     * Name of the generated hook method called after each getPort(...).
     */
    private static final String POSTPROCESS_METHOD = "postProcess";

    /**
     * Erased descriptor of the generated postProcess(Class, Object) method.
     */
    private static final String POSTPROCESS_DESCRIPTOR = "(Ljava/lang/Class;Ljava/lang/Object;)V";

    /**
     * ClassWriter used by this generator.
     */
    private final ClassWriter classWriter;

    /**
     * Internal name (slash-separated) of the generated class.
     */
    private final String name;

    /**
     * Internal name (slash-separated) of the super class to use.
     */
    private final String superClassname;

    /**
     * For the constructor use, know if we have a WSDL or not.
     */
    private final boolean hasWSDL;

    /**
     * Version of JAX-WS.
     */
    private final JAXWSVersion version;

    /**
     * Build a new generator.
     * @param name the generated classname (internal, slash-separated form)
     * @param superClassname the super class name (internal, slash-separated form)
     * @param hasWSDL if true, use another constructor.
     * @param version the JAX-WS version.
     */
    public SubClassServiceGenerator(final String name, final String superClassname, final boolean hasWSDL,
            final JAXWSVersion version) {
        this.classWriter = new ClassWriter(ClassWriter.COMPUTE_MAXS);
        this.name = name;
        this.superClassname = superClassname;
        this.hasWSDL = hasWSDL;
        this.version = version;
    }

    /**
     * Generates (or reuses, if already defined) a subclass of the given class
     * intercepting the getPort() methods. If hasWSDL parameter is true, the
     * constructor of the generated class will take URL and QName parameters.
     * The generated class is defined in the class loader of the given class.
     * @param clazz the given class to subclass
     * @param hasWSDL if true, add parameters to the constructor
     * @param version the JAX-WS version.
     * @return a class
     * @throws IllegalStateException if the generated class cannot be defined
     */
    @SuppressWarnings("unchecked")
    public static synchronized Class<? extends Service> getClass(final Class<? extends Service> clazz,
            final boolean hasWSDL, final JAXWSVersion version) {
        ClassLoader classLoader = clazz.getClassLoader();
        String generatedName = getGeneratedClassName(clazz);

        // Class already present ?
        try {
            return (Class<? extends Service>) classLoader.loadClass(generatedName);
        } catch (ClassNotFoundException ignored) {
            // Not generated yet: generate and define it below
        }

        SubClassServiceGenerator generator = new SubClassServiceGenerator(generatedName.replace('.', '/'),
                clazz.getName().replace('.', '/'), hasWSDL, version);
        generator.generateClass();
        return (Class<? extends Service>) defineClass(classLoader, generatedName, generator.getBytes());
    }

    /**
     * Computes the fully qualified name of the class generated for the given class.
     * The package is derived from the class name rather than from Class.getPackage(),
     * which may return null when the defining class loader did not define a Package.
     * @param clazz the class to subclass
     * @return the generated class name, in the same package as clazz
     */
    private static String getGeneratedClassName(final Class<?> clazz) {
        String generatedSimpleName = GENERATED_CLASS_PREFIX + clazz.getSimpleName();
        String className = clazz.getName();
        int lastDot = className.lastIndexOf('.');
        if (lastDot == -1) {
            // default package
            return generatedSimpleName;
        }
        return className.substring(0, lastDot + 1) + generatedSimpleName;
    }

    /**
     * Defines the given bytecode in the given class loader by invoking the
     * protected ClassLoader.defineClass(String, byte[], int, int) method reflectively.
     * NOTE(review): on recent JDKs with strong encapsulation of java.base, the
     * setAccessible call may be refused unless java.lang is opened to this code.
     * @param classLoader the class loader in which to define the class
     * @param className the binary name of the class
     * @param bytecode the bytecode of the class
     * @return the defined class
     * @throws IllegalStateException if the class cannot be defined
     */
    private static Class<?> defineClass(final ClassLoader classLoader, final String className, final byte[] bytecode) {
        java.lang.reflect.Method method;
        try {
            method = ClassLoader.class.getDeclaredMethod("defineClass", new Class<?>[] {String.class,
                    byte[].class, int.class, int.class});
        } catch (SecurityException se) {
            throw new IllegalStateException("Cannot get defineClass method", se);
        } catch (NoSuchMethodException nsme) {
            throw new IllegalStateException("Cannot get defineClass method", nsme);
        }

        // protected method invocation: only open it for the duration of the call
        method.setAccessible(true);
        try {
            Object[] args = new Object[] {className, bytecode, Integer.valueOf(0), Integer.valueOf(bytecode.length)};
            return (Class<?>) method.invoke(classLoader, args);
        } catch (IllegalArgumentException e) {
            throw new IllegalStateException("Unable to define class on the classloader", e);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException("Unable to define class on the classloader", e);
        } catch (InvocationTargetException e) {
            throw new IllegalStateException("Unable to define class on the classloader", e);
        } finally {
            method.setAccessible(false);
        }
    }

    /**
     * Creates the declaration of the class with the given interfaces.
     */
    protected void addClassDeclaration() {
        // create class
        classWriter.visit(GENERATED_CLASS_VERSION, ACC_PUBLIC + ACC_SUPER, this.name, null, superClassname,
                INTERFACES);
    }

    /**
     * Generate attributes of the class: the private field holding the IPortProcessor.
     */
    protected void addAttributes() {
        FieldVisitor fv = classWriter.visitField(ACC_PRIVATE, PORTPROCESSOR_FIELD, PORTPROCESSOR_ITF_DESCRIPTOR,
                null, null);
        fv.visitEnd();
    }

    /**
     * Add the constructor (one parameter if WSDL is disabled, three parameters if enabled).
     */
    protected void addConstructor() {
        MethodVisitor mv = null;

        // More arguments if WSDL
        if (hasWSDL) {
            // public GeneratedService(IPortProcessor portProcessor, URL wsdlDocumentLocation, QName serviceName) {
            //    super(wsdlDocumentLocation, serviceName);
            mv = classWriter.visitMethod(ACC_PUBLIC, "<init>", "(" + PORTPROCESSOR_ITF_DESCRIPTOR
                    + "Ljava/net/URL;Ljavax/xml/namespace/QName;)V", null, null);
            mv.visitCode();
            mv.visitVarInsn(ALOAD, 0);
            mv.visitVarInsn(ALOAD, 2);
            mv.visitVarInsn(ALOAD, THREE);
            mv.visitMethodInsn(INVOKESPECIAL, superClassname, "<init>", "(Ljava/net/URL;Ljavax/xml/namespace/QName;)V");
        } else {
            // public GeneratedService(IPortProcessor portProcessor) {
            //    super();
            mv = classWriter.visitMethod(ACC_PUBLIC, "<init>", "(" + PORTPROCESSOR_ITF_DESCRIPTOR + ")V", null, null);
            mv.visitCode();
            mv.visitVarInsn(ALOAD, 0);
            mv.visitMethodInsn(INVOKESPECIAL, superClassname, "<init>", "()V");
        }

        //     this.portProcessor = portProcessor;
        // }
        mv.visitVarInsn(ALOAD, 0);
        mv.visitVarInsn(ALOAD, 1);
        mv.visitFieldInsn(PUTFIELD, name, PORTPROCESSOR_FIELD, PORTPROCESSOR_ITF_DESCRIPTOR);
        mv.visitInsn(RETURN);
        mv.visitMaxs(0, 0);
        mv.visitEnd();
    }

    /**
     * Generate getPort() methods. The WebServiceFeature and EndpointReference
     * variants are only generated for JAX-WS 2.1.
     */
    protected void addgetPortMethods() {
        boolean jaxws21 = JAXWSVersion.JAXWS_21 == version;

        // public <T> T getPort(Class<T> serviceEndpointInterface)
        addGetPortMethod(ACC_PUBLIC,
                "(Ljava/lang/Class;)Ljava/lang/Object;",
                "<T:Ljava/lang/Object;>(Ljava/lang/Class<TT;>;)TT;",
                1, 1);

        // public <T> T getPort(Class<T> serviceEndpointInterface, WebServiceFeature... features)
        if (jaxws21) {
            addGetPortMethod(ACC_PUBLIC + ACC_VARARGS,
                    "(Ljava/lang/Class;[Ljavax/xml/ws/WebServiceFeature;)Ljava/lang/Object;",
                    "<T:Ljava/lang/Object;>(Ljava/lang/Class<TT;>;[Ljavax/xml/ws/WebServiceFeature;)TT;",
                    2, 1);
        }

        // public <T> T getPort(QName portName, Class<T> serviceEndpointInterface)
        addGetPortMethod(ACC_PUBLIC,
                "(Ljavax/xml/namespace/QName;Ljava/lang/Class;)Ljava/lang/Object;",
                "<T:Ljava/lang/Object;>(Ljavax/xml/namespace/QName;Ljava/lang/Class<TT;>;)TT;",
                2, 2);

        if (jaxws21) {
            // public <T> T getPort(QName portName, Class<T> serviceEndpointInterface, WebServiceFeature... features)
            addGetPortMethod(ACC_PUBLIC + ACC_VARARGS,
                    "(Ljavax/xml/namespace/QName;Ljava/lang/Class;[Ljavax/xml/ws/WebServiceFeature;)Ljava/lang/Object;",
                    "<T:Ljava/lang/Object;>(Ljavax/xml/namespace/QName;Ljava/lang/Class<TT;>;[Ljavax/xml/ws/WebServiceFeature;)TT;",
                    THREE, 2);

            // public <T> T getPort(EndpointReference endpointReference, Class<T> serviceEndpointInterface,
            //                      WebServiceFeature... features)
            addGetPortMethod(ACC_PUBLIC + ACC_VARARGS,
                    "(Ljavax/xml/ws/EndpointReference;Ljava/lang/Class;[Ljavax/xml/ws/WebServiceFeature;)Ljava/lang/Object;",
                    "<T:Ljava/lang/Object;>(Ljavax/xml/ws/EndpointReference;Ljava/lang/Class<TT;>;[Ljavax/xml/ws/WebServiceFeature;)TT;",
                    THREE, 2);
        }
    }

    /**
     * Generates one getPort(...) override with the shape:
     * <pre>
     * public &lt;T&gt; T getPort(arg1, ..., argN) {
     *     T port = super.getPort(arg1, ..., argN);
     *     postProcess(serviceEndpointInterface, port);
     *     return port;
     * }
     * </pre>
     * All arguments are references, so argument i lives in local slot i and the
     * result is stored in slot argCount + 1.
     * @param access the access flags of the generated method
     * @param descriptor the erased descriptor, identical for the override and the super call
     * @param signature the generic signature of the generated method
     * @param argCount the number of arguments of the method
     * @param classArgIndex the local slot of the Class&lt;T&gt; serviceEndpointInterface argument
     */
    private void addGetPortMethod(final int access, final String descriptor, final String signature,
            final int argCount, final int classArgIndex) {
        MethodVisitor mv = classWriter.visitMethod(access, "getPort", descriptor, signature, null);
        mv.visitCode();

        // T port = super.getPort(arg1, ..., argN);
        mv.visitVarInsn(ALOAD, 0);
        for (int i = 1; i <= argCount; i++) {
            mv.visitVarInsn(ALOAD, i);
        }
        mv.visitMethodInsn(INVOKESPECIAL, superClassname, "getPort", descriptor);
        int portIndex = argCount + 1;
        mv.visitVarInsn(ASTORE, portIndex);

        // postProcess(serviceEndpointInterface, port);
        mv.visitVarInsn(ALOAD, 0);
        mv.visitVarInsn(ALOAD, classArgIndex);
        mv.visitVarInsn(ALOAD, portIndex);
        mv.visitMethodInsn(INVOKEVIRTUAL, name, POSTPROCESS_METHOD, POSTPROCESS_DESCRIPTOR);

        // return port;
        mv.visitVarInsn(ALOAD, portIndex);
        mv.visitInsn(ARETURN);
        mv.visitMaxs(0, 0);
        mv.visitEnd();
    }

    /**
     * Generate postProcess() method:
     * <pre>
     * protected &lt;T&gt; void postProcess(Class&lt;T&gt; serviceEndpointInterface, T port) {
     *     this.portProcessor.postProcess(port, serviceEndpointInterface);
     * }
     * </pre>
     */
    protected void addpostProcessMethod() {
        MethodVisitor mv = classWriter.visitMethod(ACC_PROTECTED, POSTPROCESS_METHOD, POSTPROCESS_DESCRIPTOR,
                "<T:Ljava/lang/Object;>(Ljava/lang/Class<TT;>;TT;)V", null);
        mv.visitCode();
        mv.visitVarInsn(ALOAD, 0);
        mv.visitFieldInsn(GETFIELD, name, PORTPROCESSOR_FIELD, PORTPROCESSOR_ITF_DESCRIPTOR);
        // note the argument order of IPortProcessor.postProcess: (port, interface)
        mv.visitVarInsn(ALOAD, 2);
        mv.visitVarInsn(ALOAD, 1);
        mv.visitMethodInsn(INVOKEINTERFACE, PORTPROCESSOR_ITF_INTERNALNAME, "postProcess",
                "(Ljava/lang/Object;Ljava/lang/Class;)V");
        mv.visitInsn(RETURN);
        mv.visitMaxs(0, 0);
        mv.visitEnd();
    }

    /**
     * Called when the generated class is done.
     */
    private void endClass() {
        classWriter.visitEnd();
    }

    /**
     * Generate the class.
     */
    public void generateClass() {
        addClassDeclaration();
        addAttributes();
        addConstructor();
        addgetPortMethods();
        addpostProcessMethod();
        endClass();
    }

    /**
     * @return the bytecode of this class.
     */
    public byte[] getBytes() {
        return classWriter.toByteArray();
    }
}
