diff --git a/java/compiler/instrumentation-util/src/com/intellij/compiler/notNullVerification/NotNullVerifyingInstrumenter.java b/java/compiler/instrumentation-util/src/com/intellij/compiler/notNullVerification/NotNullVerifyingInstrumenter.java index 764efb7a6ee1..41076a6c41a5 100644 --- a/java/compiler/instrumentation-util/src/com/intellij/compiler/notNullVerification/NotNullVerifyingInstrumenter.java +++ b/java/compiler/instrumentation-util/src/com/intellij/compiler/notNullVerification/NotNullVerifyingInstrumenter.java @@ -7,10 +7,7 @@ import org.jetbrains.org.objectweb.asm.*; import java.io.PrintWriter; import java.io.StringWriter; -import java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.Map; -import java.util.Set; +import java.util.*; /** * @author ven @@ -25,20 +22,17 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode private static final String[] EMPTY_STRING_ARRAY = new String[0]; private final MethodData myMethodData; - private String myClassName; private boolean myIsModification = false; private RuntimeException myPostponedError; private final AuxiliaryMethodGenerator myAuxGenerator; - private final Set myNotNullAnnotations = new HashSet(); - private boolean myEnum; - private boolean myInner; private NotNullVerifyingInstrumenter(ClassVisitor classVisitor, ClassReader reader, String[] notNullAnnotations) { super(Opcodes.API_VERSION, classVisitor); + Set annoSet = new HashSet(); for (String annotation : notNullAnnotations) { - myNotNullAnnotations.add('L' + annotation.replace('.', '/') + ';'); + annoSet.add('L' + annotation.replace('.', '/') + ';'); } - myMethodData = collectMethodData(reader, myNotNullAnnotations); + myMethodData = collectMethodData(reader, annoSet); myAuxGenerator = new AuxiliaryMethodGenerator(reader); } @@ -48,49 +42,70 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode return instrumenter.myIsModification; } + private static class MethodInfo { + final NotNullState nullability = new NotNullState(); + final Map paramNames = new HashMap(); + final Map paramNullability = new LinkedHashMap(); + boolean isStable; + int paramAnnotationOffset; + + NotNullState obtainParameterNullability(int index) { + NotNullState state = paramNullability.get(index); + if (state == null) { + state = new NotNullState(); + paramNullability.put(index, state); + } + return state; + } + } + private static final class MethodData { private String myClassName; - final Map> paramNames = new LinkedHashMap>(); - final Set alwaysNotNullMethods = new HashSet(); // methods we are 100% sure return a non-null value - - public void setClassName(String className) { - myClassName = className; - } + private final Map myMethodInfos = new HashMap(); static String key(String methodName, String desc) { return methodName + desc; } String lookupParamName(String methodName, String desc, Integer num) { - final Map names = paramNames.get(key(methodName, desc)); - return names != null? names.get(num) : null; - } - - void markNotNull(String methodName, String desc) { - alwaysNotNullMethods.add(key(methodName, desc)); + MethodInfo info = myMethodInfos.get(key(methodName, desc)); + Map names = info == null ? null : info.paramNames; + return names != null ? names.get(num) : null; } boolean isAlwaysNotNull(String className, String methodName, String desc) { - return myClassName.equals(className) && alwaysNotNullMethods.contains(key(methodName, desc)); + if (myClassName.equals(className)) { + MethodInfo info = myMethodInfos.get(key(methodName, desc)); + return info != null && info.isStable && info.nullability.isNotNull(); + } + return false; } } private static MethodData collectMethodData(ClassReader reader, final Set notNullAnnotations) { final MethodData result = new MethodData(); reader.accept(new ClassVisitor(Opcodes.API_VERSION) { + private boolean myEnum, myInner; @Override public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { - result.setClassName(name); + super.visit(version, access, name, signature, superName, interfaces); + result.myClassName = name; + myEnum = (access & ACC_ENUM) != 0; + } + + @Override + public void visitInnerClass(String name, String outerName, String innerName, int access) { + super.visitInnerClass(name, outerName, innerName, access); + if (result.myClassName.equals(name)) { + myInner = (access & ACC_STATIC) == 0; + } } @Override public MethodVisitor visitMethod(int access, final String name, final String desc, String signature, String[] exceptions) { - final Map names = new LinkedHashMap(); - result.paramNames.put(MethodData.key(name, desc), names); - Type[] args = Type.getArgumentTypes(desc); - final boolean shouldRegisterNotNull = isReferenceType(Type.getReturnType(desc)) && - (access & (Opcodes.ACC_FINAL | Opcodes.ACC_STATIC | Opcodes.ACC_PRIVATE)) != 0; + final Type[] args = Type.getArgumentTypes(desc); + final boolean methodCanHaveNullability = isReferenceType(Type.getReturnType(desc)); final Map paramSlots = new LinkedHashMap(); // map: localVariableSlot -> methodParameterIndex int slotIndex = isStatic(access) ? 0 : 1; @@ -100,28 +115,98 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode slotIndex += arg.getSize(); } + final MethodInfo methodInfo = new MethodInfo(); + methodInfo.isStable = (access & (Opcodes.ACC_FINAL | Opcodes.ACC_STATIC | Opcodes.ACC_PRIVATE)) != 0; + methodInfo.paramAnnotationOffset = !"".equals(name) ? 0 : myEnum ? 2 : myInner ? 1 : 0; + result.myMethodInfos.put(MethodData.key(name, desc), methodInfo); + return new MethodVisitor(api) { + private int myParamAnnotationOffset = methodInfo.paramAnnotationOffset; + + @Override + public void visitAnnotableParameterCount(int parameterCount, boolean visible) { + if (myParamAnnotationOffset != 0 && parameterCount == args.length) { + myParamAnnotationOffset = 0; + } + super.visitAnnotableParameterCount(parameterCount, visible); + } + + @Override + public AnnotationVisitor visitParameterAnnotation(int parameter, String anno, boolean visible) { + AnnotationVisitor base = super.visitParameterAnnotation(parameter, anno, visible); + return checkParameterNullability(parameter + myParamAnnotationOffset, anno, base, false); + } + @Override public AnnotationVisitor visitAnnotation(String anno, boolean isRuntime) { - if (shouldRegisterNotNull && notNullAnnotations.contains(anno)) { - result.markNotNull(name, desc); + AnnotationVisitor base = super.visitAnnotation(anno, isRuntime); + if (methodCanHaveNullability && notNullAnnotations.contains(anno)) { + return collectNotNullArgs(base, methodInfo.nullability.withNotNull(anno, ISE_CLASS_NAME)); } - return super.visitAnnotation(anno, isRuntime); + return base; } @Override public AnnotationVisitor visitTypeAnnotation(int typeRef, TypePath typePath, String anno, boolean visible) { - if (shouldRegisterNotNull && new TypeReference(typeRef).getSort() == TypeReference.METHOD_RETURN && notNullAnnotations.contains(anno)) { - result.markNotNull(name, desc); + AnnotationVisitor base = super.visitTypeAnnotation(typeRef, typePath, anno, visible); + if (typePath != null) return base; + + TypeReference ref = new TypeReference(typeRef); + if (methodCanHaveNullability && ref.getSort() == TypeReference.METHOD_RETURN) { + if (notNullAnnotations.contains(anno)) { + return collectNotNullArgs(base, methodInfo.nullability.withNotNull(anno, ISE_CLASS_NAME)); + } + else if (seemsNullable(anno)) { + methodInfo.nullability.hasTypeUseNullable = true; + } } - return super.visitTypeAnnotation(typeRef, typePath, anno, visible); + else if (ref.getSort() == TypeReference.METHOD_FORMAL_PARAMETER) { + return checkParameterNullability(ref.getFormalParameterIndex() + methodInfo.paramAnnotationOffset, anno, base, true); + } + + return base; + } + + private boolean seemsNullable(String anno) { + String shortName = getAnnoShortName(anno); + // use hardcoded short names until it causes trouble + // this is to avoid cumbersome passing of configured nullable names from the IDE + return shortName.contains("Nullable") || shortName.equals("CheckForNull"); + } + + private AnnotationVisitor collectNotNullArgs(AnnotationVisitor base, final NotNullState state) { + return new AnnotationVisitor(Opcodes.API_VERSION, base) { + @Override + public void visit(String methodName, Object o) { + if (ANNOTATION_DEFAULT_METHOD.equals(methodName) && !((String) o).isEmpty()) { + state.message = (String) o; + } + else if ("exception".equals(methodName) && o instanceof Type && !((Type)o).getClassName().equals(Exception.class.getName())) { + state.exceptionType = ((Type)o).getInternalName(); + } + super.visit(methodName, o); + } + }; + } + + private AnnotationVisitor checkParameterNullability(int parameter, String anno, AnnotationVisitor av, boolean typeUse) { + if (parameter >= 0 && parameter < args.length && isReferenceType(args[parameter])) { + if (notNullAnnotations.contains(anno)) { + return collectNotNullArgs(av, methodInfo.obtainParameterNullability(parameter).withNotNull(anno, IAE_CLASS_NAME)); + } + else if (typeUse && seemsNullable(anno)) { + methodInfo.obtainParameterNullability(parameter).hasTypeUseNullable = true; + } + } + + return av; } @Override public void visitLocalVariable(String name2, String desc, String signature, Label start, Label end, int slotIndex) { Integer paramIndex = paramSlots.get(slotIndex); if (paramIndex != null) { - names.put(paramIndex, name2); + methodInfo.paramNames.put(paramIndex, name2); } } }; @@ -130,142 +215,66 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode return result; } - @Override - public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) { - super.visit(version, access, name, signature, superName, interfaces); - myClassName = name; - myEnum = (access & ACC_ENUM) != 0; - } - - @Override - public void visitInnerClass(String name, String outerName, String innerName, int access) { - super.visitInnerClass(name, outerName, innerName, access); - if (myClassName.equals(name)) { - myInner = (access & ACC_STATIC) == 0; - } - } - private static class NotNullState { String message; String exceptionType; - final String notNullAnno; + String notNullAnno; + boolean hasTypeUseNullable; - NotNullState(String notNullAnno, String exceptionType) { + NotNullState withNotNull(String notNullAnno, String exceptionType) { this.notNullAnno = notNullAnno; this.exceptionType = exceptionType; + return this; + } + + boolean isNotNull() { + return notNullAnno != null && !hasTypeUseNullable; } String getNullParamMessage(String paramName) { if (message != null) return message; - String shortName = getAnnoShortName(); + String shortName = getAnnoShortName(notNullAnno); if (paramName != null) return "Argument for @" + shortName + " parameter '%s' of %s.%s must not be null"; return "Argument %s for @" + shortName + " parameter of %s.%s must not be null"; } String getNullResultMessage() { if (message != null) return message; - String shortName = getAnnoShortName(); + String shortName = getAnnoShortName(notNullAnno); return "@" + shortName + " method %s.%s must not return null"; } + } - private String getAnnoShortName() { - String fullName = notNullAnno.substring(1, notNullAnno.length() - 1); // "Lpk/name;" -> "pk/name" - return fullName.substring(fullName.lastIndexOf('/') + 1); - } + private static String getAnnoShortName(String anno) { + String fullName = anno.substring(1, anno.length() - 1); // "Lpk/name;" -> "pk/name" + return fullName.substring(fullName.lastIndexOf('/') + 1); } @Override public MethodVisitor visitMethod(int access, final String name, final String desc, String signature, String[] exceptions) { - if ((access & Opcodes.ACC_BRIDGE) != 0) { + final MethodInfo info = myMethodData.myMethodInfos.get(MethodData.key(name, desc)); + if ((access & Opcodes.ACC_BRIDGE) != 0 || info == null) { return new FailSafeMethodVisitor(Opcodes.API_VERSION, super.visitMethod(access, name, desc, signature, exceptions)); } final boolean isStatic = isStatic(access); final Type[] args = Type.getArgumentTypes(desc); - final int paramAnnotationOffset = !"".equals(name) ? 0 : myEnum ? 2 : myInner ? 1 : 0; - - final Type returnType = Type.getReturnType(desc); final NotNullInstructionTracker instrTracker = new NotNullInstructionTracker(cv.visitMethod(access, name, desc, signature, exceptions)); return new FailSafeMethodVisitor(Opcodes.API_VERSION, instrTracker) { - private final Map myNotNullParams = new LinkedHashMap(); - private int myParamAnnotationOffset = paramAnnotationOffset; - private NotNullState myMethodNotNull; private Label myStartGeneratedCodeLabel; - private AnnotationVisitor collectNotNullArgs(AnnotationVisitor base, final NotNullState state) { - return new AnnotationVisitor(Opcodes.API_VERSION, base) { - @Override - public void visit(String methodName, Object o) { - if (ANNOTATION_DEFAULT_METHOD.equals(methodName) && !((String) o).isEmpty()) { - state.message = (String) o; - } - else if ("exception".equals(methodName) && o instanceof Type && !((Type)o).getClassName().equals(Exception.class.getName())) { - state.exceptionType = ((Type)o).getInternalName(); - } - super.visit(methodName, o); - } - }; - } - - @Override - public AnnotationVisitor visitTypeAnnotation(int typeRef, TypePath typePath, String desc, boolean visible) { - AnnotationVisitor base = mv.visitTypeAnnotation(typeRef, typePath, desc, visible); - if (typePath != null) return base; - - TypeReference ref = new TypeReference(typeRef); - if (ref.getSort() == TypeReference.METHOD_RETURN) { - return checkNotNullMethod(desc, base); - } - if (ref.getSort() == TypeReference.METHOD_FORMAL_PARAMETER) { - return checkNotNullParameter(ref.getFormalParameterIndex() + paramAnnotationOffset, desc, base); - } - return base; - } - - @Override - public void visitAnnotableParameterCount(int parameterCount, boolean visible) { - if (myParamAnnotationOffset != 0 && parameterCount == args.length) { - myParamAnnotationOffset = 0; - } - super.visitAnnotableParameterCount(parameterCount, visible); - } - - @Override - public AnnotationVisitor visitParameterAnnotation(int parameter, String anno, boolean visible) { - AnnotationVisitor base = mv.visitParameterAnnotation(parameter, anno, visible); - return checkNotNullParameter(parameter + myParamAnnotationOffset, anno, base); - } - - private AnnotationVisitor checkNotNullParameter(int parameter, String anno, AnnotationVisitor av) { - if (parameter >= 0 && parameter < args.length && isReferenceType(args[parameter]) && myNotNullAnnotations.contains(anno)) { - NotNullState state = new NotNullState(anno, IAE_CLASS_NAME); - myNotNullParams.put(parameter, state); - return collectNotNullArgs(av, state); - } - - return av; - } - - @Override - public AnnotationVisitor visitAnnotation(String anno, boolean isRuntime) { - return checkNotNullMethod(anno, mv.visitAnnotation(anno, isRuntime)); - } - - private AnnotationVisitor checkNotNullMethod(String anno, AnnotationVisitor base) { - if (isReferenceType(returnType) && myNotNullAnnotations.contains(anno)) { - myMethodNotNull = new NotNullState(anno, ISE_CLASS_NAME); - return collectNotNullArgs(base, myMethodNotNull); - } - return base; - } - @Override public void visitCode() { - if (myNotNullParams.size() > 0) { + for (Iterator iterator = info.paramNullability.values().iterator(); iterator.hasNext(); ) { + if (!iterator.next().isNotNull()) { + iterator.remove(); + } + } + if (info.paramNullability.size() > 0) { myStartGeneratedCodeLabel = new Label(); mv.visitLabel(myStartGeneratedCodeLabel); } - for (Map.Entry entry : myNotNullParams.entrySet()) { + for (Map.Entry entry : info.paramNullability.entrySet()) { Integer param = entry.getKey(); int var = isStatic ? 0 : 1; for (int i = 0; i < param; ++i) { @@ -281,7 +290,7 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode String descrPattern = state.getNullParamMessage(paramName); String[] args = state.message != null ? EMPTY_STRING_ARRAY - : new String[]{paramName != null ? paramName : String.valueOf(param - paramAnnotationOffset), myClassName, name}; + : new String[]{paramName != null ? paramName : String.valueOf(param - info.paramAnnotationOffset), myMethodData.myClassName, name}; reportError(state.exceptionType, end, descrPattern, args); } } @@ -295,20 +304,20 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode @Override public void visitInsn(int opcode) { - if (opcode == ARETURN && myMethodNotNull != null && instrTracker.canBeNull()) { + if (opcode == ARETURN && instrTracker.canBeNull() && info.nullability.isNotNull()) { mv.visitInsn(DUP); Label skipLabel = new Label(); mv.visitJumpInsn(IFNONNULL, skipLabel); - String descrPattern = myMethodNotNull.getNullResultMessage(); - String[] args = myMethodNotNull.message != null ? EMPTY_STRING_ARRAY : new String[]{myClassName, name}; - reportError(myMethodNotNull.exceptionType, skipLabel, descrPattern, args); + String descrPattern = info.nullability.getNullResultMessage(); + String[] args = info.nullability.message != null ? EMPTY_STRING_ARRAY : new String[]{myMethodData.myClassName, name}; + reportError(info.nullability.exceptionType, skipLabel, descrPattern, args); } mv.visitInsn(opcode); } private void reportError(String exceptionClass, Label end, String descrPattern, String[] args) { - myAuxGenerator.reportError(mv, myClassName, exceptionClass, descrPattern, args); + myAuxGenerator.reportError(mv, myMethodData.myClassName, exceptionClass, descrPattern, args); mv.visitLabel(end); myIsModification = true; processPostponedErrors(); @@ -352,7 +361,7 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode t.printStackTrace(new PrintWriter(writer)); StringBuilder text = new StringBuilder(); - text.append("Operation '").append(operationName).append("' failed for ").append(myClassName).append(".").append(methodName).append("(): "); + text.append("Operation '").append(operationName).append("' failed for ").append(myMethodData.myClassName).append(".").append(methodName).append("(): "); if (message != null) text.append(message); text.append('\n').append(writer.getBuffer()); myPostponedError = new RuntimeException(text.toString(), cause); diff --git a/java/java-tests/testData/compiler/notNullVerification/TypeUseAndMemberAnnotationsOnArrays.java b/java/java-tests/testData/compiler/notNullVerification/TypeUseAndMemberAnnotationsOnArrays.java new file mode 100644 index 000000000000..1f18fc0a8e44 --- /dev/null +++ b/java/java-tests/testData/compiler/notNullVerification/TypeUseAndMemberAnnotationsOnArrays.java @@ -0,0 +1,16 @@ +import org.jetbrains.annotations.*; + +public class TypeUseAndMemberAnnotationsOnArrays { + + public void nullableArray(@NotNull String @Nullable [] query) {} + + public void notNullArray(@Nullable String @NotNull [] query) {} + + public @Nullable String @NotNull [] notNullReturn() { + return null; + } + + public @NotNull String @Nullable [] nullableReturn() { + return null; + } +} \ No newline at end of file diff --git a/java/java-tests/testData/compiler/notNullVerification/mixed/Nullable.java b/java/java-tests/testData/compiler/notNullVerification/mixed/Nullable.java new file mode 100644 index 000000000000..cbf8e107a6df --- /dev/null +++ b/java/java-tests/testData/compiler/notNullVerification/mixed/Nullable.java @@ -0,0 +1,9 @@ +package org.jetbrains.annotations; + +import java.lang.annotation.*; + +@Retention(RetentionPolicy.CLASS) +@Target({ElementType.METHOD, ElementType.FIELD, ElementType.PARAMETER, ElementType.LOCAL_VARIABLE, ElementType.TYPE_USE}) +public @interface Nullable { + String value() default ""; +} \ No newline at end of file diff --git a/java/java-tests/testData/compiler/notNullVerification/types/Nullable.java b/java/java-tests/testData/compiler/notNullVerification/types/Nullable.java new file mode 100644 index 000000000000..aa386cda4c21 --- /dev/null +++ b/java/java-tests/testData/compiler/notNullVerification/types/Nullable.java @@ -0,0 +1,9 @@ +package org.jetbrains.annotations; + +import java.lang.annotation.*; + +@Retention(RetentionPolicy.CLASS) +@Target({ElementType.TYPE_USE}) +public @interface Nullable { + String value() default ""; +} \ No newline at end of file diff --git a/java/java-tests/testSrc/com/intellij/java/compiler/notNullVerification/NotNullVerifyingInstrumenterTest.java b/java/java-tests/testSrc/com/intellij/java/compiler/notNullVerification/NotNullVerifyingInstrumenterTest.java index 79404d7bcbf9..dd296ceb99f9 100644 --- a/java/java-tests/testSrc/com/intellij/java/compiler/notNullVerification/NotNullVerifyingInstrumenterTest.java +++ b/java/java-tests/testSrc/com/intellij/java/compiler/notNullVerification/NotNullVerifyingInstrumenterTest.java @@ -50,10 +50,11 @@ public abstract class NotNullVerifyingInstrumenterTest { public static class MembersTargetTest extends NotNullVerifyingInstrumenterTest { } @TestDirectory("types") - public static class TypesTargetTest extends NotNullVerifyingInstrumenterTest { } + public static class TypesTargetTest extends WithTypeUse { } @TestDirectory("mixed") - public static class MixedTargetTest extends NotNullVerifyingInstrumenterTest { } + public static class MixedTargetTest extends WithTypeUse { + } private static final String TEST_DATA_PATH = "/compiler/notNullVerification/"; @@ -64,10 +65,12 @@ public abstract class NotNullVerifyingInstrumenterTest { public Statement apply(Statement base, Description description) { TestDirectory annotation = description.getAnnotation(TestDirectory.class); if (annotation == null) throw new IllegalArgumentException("Class " + description.getTestClass() + " misses @TestDirectory annotation"); - File source = new File(JavaTestUtil.getJavaTestDataPath() + TEST_DATA_PATH + annotation.value() + "/NotNull.java"); - if (!source.isFile()) throw new IllegalArgumentException("Cannot find annotation file at " + source); + File source = new File(JavaTestUtil.getJavaTestDataPath() + TEST_DATA_PATH + annotation.value()); + if (!source.isDirectory() || source.listFiles().length == 0) throw new IllegalArgumentException("Cannot find annotation file at " + source); classes = IoTestUtil.createTestDir("test-notNullInstrumenter-" + annotation.value()); - IdeaTestUtil.compileFile(source, classes); + for (File file : source.listFiles()) { + IdeaTestUtil.compileFile(file, classes); + } return super.apply(base, description); } @@ -288,6 +291,22 @@ public abstract class NotNullVerifyingInstrumenterTest { assertEquals(1, returnType.getAnnotatedReturnType().getAnnotations().length); } + public static abstract class WithTypeUse extends NotNullVerifyingInstrumenterTest { + @Test + public void testTypeUseAndMemberAnnotationsOnArrays() throws Exception { + Class test = prepareTest(); + Object instance = test.newInstance(); + + Object[] singleNullArg = {null}; + verifyCallThrowsException("Argument 0 for @NotNull parameter of TypeUseAndMemberAnnotationsOnArrays.notNullArray must not be null", instance, test.getMethod("notNullArray", String[].class), singleNullArg); + test.getMethod("nullableArray", String[].class).invoke(instance, singleNullArg); + + verifyCallThrowsException("@NotNull method TypeUseAndMemberAnnotationsOnArrays.notNullReturn must not return null", instance, test.getMethod("notNullReturn")); + assertNull(test.getMethod("nullableReturn").invoke(instance)); + } + + } + @Test public void testMalformedBytecode() throws Exception { Class testClass = prepareTest();