美文网首页
34 - ASM之修改已有的方法

34 - ASM之修改已有的方法

作者: 舍是境界 | 来源:发表于2022-02-14 07:00 被阅读0次

    预期目标

    假如有一个HelloWorld类,代码如下:

    public class HelloWorld {
        public void test() {
            System.out.println("this is a test method.");
        }
    }
    

    我们想实现的预期目标:对于test()方法,在“方法进入”时和“方法退出”时,添加一条打印语句。

    • 第一种情况,在“方法进入”时,预期目标如下所示:
    public class HelloWorld {
        public void test() {
            System.out.println("Method Enter...");
            System.out.println("this is a test method.");
        }
    }
    
    • 第二种情况,在“方法退出”时,预期目标如下所示:
    public class HelloWorld {
        public void test() {
            System.out.println("this is a test method.");
            System.out.println("Method Exit...");
        }
    }
    

    现在,我们有了明确的预期目标;接下来,就是将这个预期目标转换成具体的ASM代码。那么,应该怎么实现呢?从哪里着手呢?

    实现思路

    我们知道,现在的内容是Class Transformation的操作,其中涉及到三个主要的类:ClassReader、ClassVisitor和ClassWriter。其中,ClassReader负责读取Class文件,ClassWriter负责生成Class文件,而具体的ClassVisitor负责进行Transformation的操作。换句话说,我们还是应该从ClassVisitor类开始。

    第一步,回顾一下ClassVisitor类当中主要的visitXxx()方法有哪些。在ClassVisitor类当中,有visit()、visitField()、visitMethod()和visitEnd()方法;这些visitXxx()方法与.class文件里的不同部分之间是有对应关系的,如下图:

    visitXxx关系

    根据我们的预期目标,现在想要修改的是“方法”的部分,那么就对应着ClassVisitor类的visitMethod()方法。ClassVisitor.visitMethod()会返回一个MethodVisitor类的实例;而MethodVisitor类就是用来生成方法的“方法体”。

    第二步,回顾一下MethodVisitor类当中定义了哪些visitXxx()方法。

    methodVisitor方法

    在MethodVisitor类当中,定义的visitXxx()方法比较多,但是我们可以将这些visitXxx()方法进行分组:

    • 第一组,visitCode()方法,标志着方法体(method body)的开始。
    • 第二组,visitXxxInsn()方法,对应方法体(method body)本身,这里包含多个方法。
    • 第三组,visitMaxs()方法,标志着方法体(method body)的结束。
    • 第四组,visitEnd()方法,是最后调用的方法。

    另外,我们也回顾一下,在MethodVisitor类中,visitXxx()方法的调用顺序:

    • 第一步,调用visitCode()方法,调用一次。
    • 第二步,调用visitXxxInsn()方法,可以调用多次。
    • 第三步,调用visitMaxs()方法,调用一次。
    • 第四步,调用visitEnd()方法,调用一次。

    到了这一步,我们基本上就知道了:需要修改的内容就位于visitCode()和visitMaxs()方法之间,这是一个大概的范围。

    第三步,精确定位。也就是说,在MethodVisitor类当中,要确定出要在哪一个visitXxx()方法里进行修改。

    方法进入

    如果我们想在“方法进入”时,添加一些打印语句,那么我们有两个位置可以添加打印语句:

    • 第一个位置,就是在visitCode()方法中。
    • 第二个位置,就是在第1个visitXxxInsn()方法中。

    在这两个位置当中,我们推荐使用visitCode()方法。因为visitCode()方法总是位于方法体(method body)的前面,而第1个visitXxxInsn()方法是不稳定的。

    public void visitCode() {
        // 首先,处理自己的代码逻辑
        // TODO: 添加“方法进入”时的代码
    
        // 其次,调用父类的方法实现
        super.visitCode();
    }
    

    方法退出

    如果我们在“方法退出”时想添加的代码,是否可以添加到visitMaxs()方法内呢?这样做是不行的。因为在执行visitMaxs()方法之前,方法体(method body)已经执行过了:在方法体(method body)当中,里面会包含return语句;如果return语句一执行,后面的任何语句都不会再执行了;换句话说,如果在visitMaxs()方法内添加的打印输出语句,由于前面方法体(method body)中已经执行了return语句,后面的任何语句就执行不到了。

    那么,到底是应该在哪里添加代码呢?为了回答这个问题,我们需要知道“方法退出”有哪几种情况。方法的退出,有两种情况,一种是正常退出(执行return语句),另一种是异常退出(执行throw语句);接下来,就是将这两种退出情况应用到ASM的代码层面。

    在MethodVisitor类当中,无论是执行return语句,还是执行throw语句,都是通过visitInsn(opcode)方法来实现的。所以,如果我们想在“方法退出”时,添加一些语句,那么这些语句放到visitInsn(opcode)方法中就可以了。

    public void visitInsn(int opcode) {
        // 首先,处理自己的代码逻辑
        if (opcode == Opcodes.ATHROW || (opcode >= Opcodes.IRETURN && opcode <= Opcodes.RETURN)) {
            // TODO: 添加“方法退出”时的代码
        }
    
        // 其次,调用父类的方法实现
        super.visitInsn(opcode);
    }
    

    推荐做法:在编写ASM代码的时候,如果写了一个类,它继承自ClassVisitor,那么就命名成XxxVisitor;如果写了一个类,它继承自MethodVisitor,那么就命名成XxxAdapter。通过类的名字,我就可以区分出哪些类是继承自ClassVisitor,哪些类是继承自MethodVisitor。

    示例一:方法进入

    编码实现:

    import org.objectweb.asm.ClassVisitor;
    import org.objectweb.asm.MethodVisitor;
    import org.objectweb.asm.Opcodes;
    
    public class MethodEnterVisitor extends ClassVisitor {
        public MethodEnterVisitor(int api, ClassVisitor classVisitor) {
            super(api, classVisitor);
        }
    
        @Override
        public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
            MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
            if (mv != null && !"<init>".equals(name)) {
                mv = new MethodEnterAdapter(api, mv);
            }
            return mv;
        }
    
        private static class MethodEnterAdapter extends MethodVisitor {
            public MethodEnterAdapter(int api, MethodVisitor methodVisitor) {
                super(api, methodVisitor);
            }
    
            @Override
            public void visitCode() {
                // 首先,处理自己的代码逻辑
                super.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                super.visitLdcInsn("Method Enter...");
                super.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
    
                // 其次,调用父类的方法实现
                super.visitCode();
            }
        }
    }
    

    在上面MethodEnterAdapter类的visitCode()方法中,主要是做两件事情:

    • 首先,处理自己的代码逻辑。
    • 其次,调用父类的方法实现。

    在处理自己的代码逻辑中,有3行代码。这3条语句的作用就是添加System.out.println("Method Enter...");语句:

    super.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
    super.visitLdcInsn("Method Enter...");
    super.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
    

    注意,上面的代码中使用了super关键字。

    事实上,在MethodVisitor类当中,定义了一个protected MethodVisitor mv;字段。我们也可以使用mv这个字段,代码也可以这样写:

    mv.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
    mv.visitLdcInsn("Method Enter...");
    mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
    

    但是这样写,可能会遇到mv为null的情况,这样就会出现NullPointerException异常。

    如果使用super,就会避免NullPointerException异常的情况。因为使用super的情况下,就是调用父类定义的方法,在本例中其实就是调用MethodVisitor类里定义的方法。在MethodVisitor类里的visitXxx()方法中,会先对mv进行是否为null的判断,所以就不会出现NullPointerException的情况。

    public abstract class MethodVisitor {
        protected MethodVisitor mv;
    
        public void visitCode() {
            if (mv != null) {
                mv.visitCode();
            }
        }
    
        public void visitInsn(final int opcode) {
            if (mv != null) {
                mv.visitInsn(opcode);
            }
        }
    
        public void visitIntInsn(final int opcode, final int operand) {
            if (mv != null) {
                mv.visitIntInsn(opcode, operand);
            }
        }
    
        public void visitVarInsn(final int opcode, final int var) {
            if (mv != null) {
                mv.visitVarInsn(opcode, var);
            }
        }
    
        public void visitFieldInsn(final int opcode, final String owner, final String name, final String descriptor) {
            if (mv != null) {
                mv.visitFieldInsn(opcode, owner, name, descriptor);
            }
        }
    
        // ......
    
        public void visitMaxs(final int maxStack, final int maxLocals) {
            if (mv != null) {
                mv.visitMaxs(maxStack, maxLocals);
            }
        }
    
        public void visitEnd() {
            if (mv != null) {
                mv.visitEnd();
            }
        }
    }
    

    进行转换:

    import lsieun.utils.FileUtils;
    import org.objectweb.asm.*;
    
    public class HelloWorldTransformCore {
        public static void main(String[] args) {
            String relative_path = "sample/HelloWorld.class";
            String filepath = FileUtils.getFilePath(relative_path);
            byte[] bytes1 = FileUtils.readBytes(filepath);
    
            //(1)构建ClassReader
            ClassReader cr = new ClassReader(bytes1);
    
            //(2)构建ClassWriter
            ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
    
            //(3)串连ClassVisitor
            int api = Opcodes.ASM9;
            ClassVisitor cv = new MethodEnterVisitor(api, cw);
    
            //(4)结合ClassReader和ClassVisitor
            int parsingOptions = ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES;
            cr.accept(cv, parsingOptions);
    
            //(5)生成byte[]
            byte[] bytes2 = cw.toByteArray();
    
            FileUtils.writeBytes(filepath, bytes2);
        }
    }
    

    结果验证:

    import java.lang.reflect.Method;
    
    public class HelloWorldRun {
        public static void main(String[] args) throws Exception {
            Class<?> clazz = Class.forName("sample.HelloWorld");
            Method m = clazz.getDeclaredMethod("test");
    
            Object instance = clazz.newInstance();
            m.invoke(instance);
        }
    }
    

    特殊情况:<init>()方法

    在.class文件中,<init>()方法,就表示类当中的构造方法。
    我们在“方法进入”时,有一个对于<init>的判断:

    if (mv != null && !"<init>".equals(name)) {
        // ......
    }
    

    Java requires that if you call this() or super() in a constructor, it must be the first statement.

    public class HelloWorld {
        public HelloWorld() {
            System.out.println("Method Enter...");
            super(); // 报错:Call to 'super()' must be first statement in constructor body
        }
    }
    

    去掉对于<init>()方法的判断,会发现它好像也是可以正常执行的。
    但是,如果我们换一下添加的语句,就会出错了:

    super.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
    super.visitVarInsn(Opcodes.ALOAD, 0);
    super.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "toString", "()Ljava/lang/String;", false);
    super.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
    

    示例二:方法退出

    代码实现:

    import org.objectweb.asm.ClassVisitor;
    import org.objectweb.asm.MethodVisitor;
    import org.objectweb.asm.Opcodes;
    
    public class MethodExitVisitor extends ClassVisitor {
        public MethodExitVisitor(int api, ClassVisitor classVisitor) {
            super(api, classVisitor);
        }
    
        @Override
        public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
            MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
            if (mv != null && !"<init>".equals(name)) {
                mv = new MethodExitAdapter(api, mv);
            }
            return mv;
        }
    
        private static class MethodExitAdapter extends MethodVisitor {
            public MethodExitAdapter(int api, MethodVisitor methodVisitor) {
                super(api, methodVisitor);
            }
    
            @Override
            public void visitInsn(int opcode) {
                // 首先,处理自己的代码逻辑
                if (opcode == Opcodes.ATHROW || (opcode >= Opcodes.IRETURN && opcode <= Opcodes.RETURN)) {
                    super.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                    super.visitLdcInsn("Method Exit...");
                    super.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
                }
    
                // 其次,调用父类的方法实现
                super.visitInsn(opcode);
            }
        }
    }
    

    进行转换:

    import lsieun.utils.FileUtils;
    import org.objectweb.asm.*;
    
    public class HelloWorldTransformCore {
        public static void main(String[] args) {
            String relative_path = "sample/HelloWorld.class";
            String filepath = FileUtils.getFilePath(relative_path);
            byte[] bytes1 = FileUtils.readBytes(filepath);
    
            //(1)构建ClassReader
            ClassReader cr = new ClassReader(bytes1);
    
            //(2)构建ClassWriter
            ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
    
            //(3)串连ClassVisitor
            int api = Opcodes.ASM9;
            ClassVisitor cv = new MethodExitVisitor(api, cw);
    
            //(4)结合ClassReader和ClassVisitor
            int parsingOptions = ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES;
            cr.accept(cv, parsingOptions);
    
            //(5)生成byte[]
            byte[] bytes2 = cw.toByteArray();
    
            FileUtils.writeBytes(filepath, bytes2);
        }
    }
    

    结果验证:

    import java.lang.reflect.Method;
    
    public class HelloWorldRun {
        public static void main(String[] args) throws Exception {
            Class<?> clazz = Class.forName("sample.HelloWorld");
            Method m = clazz.getDeclaredMethod("test");
    
            Object instance = clazz.newInstance();
            m.invoke(instance);
        }
    }
    

    输出结果:

    this is a test method.
    Method Exit...
    

    示例三:方法进入和方法退出

    第一种方式

    第一种方式,就是将多个ClassVisitor类串联起来。

    import lsieun.utils.FileUtils;
    import org.objectweb.asm.*;
    
    public class HelloWorldTransformCore {
        public static void main(String[] args) {
            String relative_path = "sample/HelloWorld.class";
            String filepath = FileUtils.getFilePath(relative_path);
            byte[] bytes1 = FileUtils.readBytes(filepath);
    
            //(1)构建ClassReader
            ClassReader cr = new ClassReader(bytes1);
    
            //(2)构建ClassWriter
            ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
    
            //(3)串连ClassVisitor
            int api = Opcodes.ASM9;
            ClassVisitor cv1 = new MethodEnterVisitor(api, cw);
            ClassVisitor cv2 = new MethodExitVisitor(api, cv1);
            ClassVisitor cv = cv2;
    
            //(4)结合ClassReader和ClassVisitor
            int parsingOptions = ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES;
            cr.accept(cv, parsingOptions);
    
            //(5)生成byte[]
            byte[] bytes2 = cw.toByteArray();
    
            FileUtils.writeBytes(filepath, bytes2);
        }
    }
    
    第二种方式

    第二种方式,就是将所有的代码都放到一个ClassVisitor类里面。

    编码实现:

    import org.objectweb.asm.ClassVisitor;
    import org.objectweb.asm.MethodVisitor;
    import org.objectweb.asm.Opcodes;
    
    public class MethodAroundVisitor extends ClassVisitor {
        public MethodAroundVisitor(int api, ClassVisitor classVisitor) {
            super(api, classVisitor);
        }
    
        @Override
        public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
            MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
            if (mv != null && !"<init>".equals(name)) {
                boolean isAbstractMethod = (access & Opcodes.ACC_ABSTRACT) == Opcodes.ACC_ABSTRACT;
                boolean isNativeMethod = (access & Opcodes.ACC_NATIVE) == Opcodes.ACC_NATIVE;
                if (!isAbstractMethod && !isNativeMethod) {
                    mv = new MethodAroundAdapter(api, mv);
                }
            }
            return mv;
        }
    
        private static class MethodAroundAdapter extends MethodVisitor {
            public MethodAroundAdapter(int api, MethodVisitor methodVisitor) {
                super(api, methodVisitor);
            }
    
            @Override
            public void visitCode() {
                // 首先,处理自己的代码逻辑
                super.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                super.visitLdcInsn("Method Enter...");
                super.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
    
                // 其次,调用父类的方法实现
                super.visitCode();
            }
    
            @Override
            public void visitInsn(int opcode) {
                // 首先,处理自己的代码逻辑
                if (opcode == Opcodes.ATHROW || (opcode >= Opcodes.IRETURN && opcode <= Opcodes.RETURN)) {
                    super.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                    super.visitLdcInsn("Method Exit...");
                    super.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
                }
    
                // 其次,调用父类的方法实现
                super.visitInsn(opcode);
            }
        }
    }
    

    进行转换:

    import lsieun.utils.FileUtils;
    import org.objectweb.asm.*;
    
    public class HelloWorldTransformCore {
        public static void main(String[] args) {
            String relative_path = "sample/HelloWorld.class";
            String filepath = FileUtils.getFilePath(relative_path);
            byte[] bytes1 = FileUtils.readBytes(filepath);
    
            //(1)构建ClassReader
            ClassReader cr = new ClassReader(bytes1);
    
            //(2)构建ClassWriter
            ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
    
            //(3)串连ClassVisitor
            int api = Opcodes.ASM9;
            ClassVisitor cv = new MethodAroundVisitor(api, cw);
    
            //(4)结合ClassReader和ClassVisitor
            int parsingOptions = ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES;
            cr.accept(cv, parsingOptions);
    
            //(5)生成byte[]
            byte[] bytes2 = cw.toByteArray();
    
            FileUtils.writeBytes(filepath, bytes2);
        }
    }
    

    总结

    本文主要是对“方法进入”和“方法退出”添加代码进行介绍,内容总结如下:

    • 第一点,在“方法进入”时和“方法退出”时添加代码,应该如何实现?
      • 在“方法进入”时添加代码,是在visitCode()方法当中完成;
      • 在“方法退出”添加代码时,是在visitInsn(opcode)方法中,判断opcode为return或throw的情况下完成。
    • 第二点,在“方法进入”时和“方法退出”时添加代码,有一些特殊的情况,需要小心处理:
      • 接口,是否需要处理?接口当中的抽象方法没有方法体,但也可能有带有方法体的default方法。
      • 带有特殊修饰符的方法:
        • 抽象方法,是否需要处理?不只是接口当中有抽象方法,抽象类里也可能有抽象方法。抽象方法,是没有方法体的。
        • native方法,是否需要处理?native方法是没有方法体的。
      • 名字特殊的方法,例如,构造方法(<init>())和静态初始化方法(<clinit>()),是否需要处理?

    另外,在编写代码的时候,我们遵循一个“规则”:如果是ClassVisitor的子类,就取名为XxxVisitor类;如果是MethodVisitor的子类,就取名为XxxAdapter类。

    本文的介绍方式侧重于让大家理解“工作原理”,而后续介绍的AdviceAdapter则侧重于“应用”,AdviceAdapter的实现也是基于visitCode()和visitInsn(opcode)方法实现的,在理解上有一个步步递进的关系。

    本文也是基础方式,可以应对各种场景,大家需进行掌握。

    相关文章

      网友评论

          本文标题:34 - ASM之修改已有的方法

          本文链接:https://www.haomeiwen.com/subject/icvxlrtx.html