don't add @NotNull instrumentation for TYPE_USE places where @Nullable is also present (IDEA-211172)

GitOrigin-RevId: 46e7571d34c08de56fbe5ae960c67bb64bc398bd
This commit is contained in:
peter
2019-05-13 08:38:10 +02:00
committed by intellij-monorepo-bot
parent 7969535387
commit e68bbc919e
5 changed files with 209 additions and 147 deletions

View File

@@ -7,10 +7,7 @@ import org.jetbrains.org.objectweb.asm.*;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Set;
import java.util.*;
/**
* @author ven
@@ -25,20 +22,17 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode
private static final String[] EMPTY_STRING_ARRAY = new String[0];
private final MethodData myMethodData;
private String myClassName;
private boolean myIsModification = false;
private RuntimeException myPostponedError;
private final AuxiliaryMethodGenerator myAuxGenerator;
private final Set<String> myNotNullAnnotations = new HashSet<String>();
private boolean myEnum;
private boolean myInner;
private NotNullVerifyingInstrumenter(ClassVisitor classVisitor, ClassReader reader, String[] notNullAnnotations) {
super(Opcodes.API_VERSION, classVisitor);
Set<String> annoSet = new HashSet<String>();
for (String annotation : notNullAnnotations) {
myNotNullAnnotations.add('L' + annotation.replace('.', '/') + ';');
annoSet.add('L' + annotation.replace('.', '/') + ';');
}
myMethodData = collectMethodData(reader, myNotNullAnnotations);
myMethodData = collectMethodData(reader, annoSet);
myAuxGenerator = new AuxiliaryMethodGenerator(reader);
}
@@ -48,49 +42,70 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode
return instrumenter.myIsModification;
}
private static class MethodInfo {
final NotNullState nullability = new NotNullState();
final Map<Integer, String> paramNames = new HashMap<Integer, String>();
final Map<Integer, NotNullState> paramNullability = new LinkedHashMap<Integer, NotNullState>();
boolean isStable;
int paramAnnotationOffset;
NotNullState obtainParameterNullability(int index) {
NotNullState state = paramNullability.get(index);
if (state == null) {
state = new NotNullState();
paramNullability.put(index, state);
}
return state;
}
}
private static final class MethodData {
private String myClassName;
final Map<String, Map<Integer, String>> paramNames = new LinkedHashMap<String, Map<Integer, String>>();
final Set<String> alwaysNotNullMethods = new HashSet<String>(); // methods we are 100% sure return a non-null value
public void setClassName(String className) {
myClassName = className;
}
private final Map<String, MethodInfo> myMethodInfos = new HashMap<String, MethodInfo>();
static String key(String methodName, String desc) {
return methodName + desc;
}
String lookupParamName(String methodName, String desc, Integer num) {
final Map<Integer, String> names = paramNames.get(key(methodName, desc));
return names != null? names.get(num) : null;
}
void markNotNull(String methodName, String desc) {
alwaysNotNullMethods.add(key(methodName, desc));
MethodInfo info = myMethodInfos.get(key(methodName, desc));
Map<Integer, String> names = info == null ? null : info.paramNames;
return names != null ? names.get(num) : null;
}
boolean isAlwaysNotNull(String className, String methodName, String desc) {
return myClassName.equals(className) && alwaysNotNullMethods.contains(key(methodName, desc));
if (myClassName.equals(className)) {
MethodInfo info = myMethodInfos.get(key(methodName, desc));
return info != null && info.isStable && info.nullability.isNotNull();
}
return false;
}
}
private static MethodData collectMethodData(ClassReader reader, final Set<String> notNullAnnotations) {
final MethodData result = new MethodData();
reader.accept(new ClassVisitor(Opcodes.API_VERSION) {
private boolean myEnum, myInner;
@Override
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
result.setClassName(name);
super.visit(version, access, name, signature, superName, interfaces);
result.myClassName = name;
myEnum = (access & ACC_ENUM) != 0;
}
@Override
public void visitInnerClass(String name, String outerName, String innerName, int access) {
super.visitInnerClass(name, outerName, innerName, access);
if (result.myClassName.equals(name)) {
myInner = (access & ACC_STATIC) == 0;
}
}
@Override
public MethodVisitor visitMethod(int access, final String name, final String desc, String signature, String[] exceptions) {
final Map<Integer, String> names = new LinkedHashMap<Integer, String>();
result.paramNames.put(MethodData.key(name, desc), names);
Type[] args = Type.getArgumentTypes(desc);
final boolean shouldRegisterNotNull = isReferenceType(Type.getReturnType(desc)) &&
(access & (Opcodes.ACC_FINAL | Opcodes.ACC_STATIC | Opcodes.ACC_PRIVATE)) != 0;
final Type[] args = Type.getArgumentTypes(desc);
final boolean methodCanHaveNullability = isReferenceType(Type.getReturnType(desc));
final Map<Integer, Integer> paramSlots = new LinkedHashMap<Integer, Integer>(); // map: localVariableSlot -> methodParameterIndex
int slotIndex = isStatic(access) ? 0 : 1;
@@ -100,28 +115,98 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode
slotIndex += arg.getSize();
}
final MethodInfo methodInfo = new MethodInfo();
methodInfo.isStable = (access & (Opcodes.ACC_FINAL | Opcodes.ACC_STATIC | Opcodes.ACC_PRIVATE)) != 0;
methodInfo.paramAnnotationOffset = !"<init>".equals(name) ? 0 : myEnum ? 2 : myInner ? 1 : 0;
result.myMethodInfos.put(MethodData.key(name, desc), methodInfo);
return new MethodVisitor(api) {
private int myParamAnnotationOffset = methodInfo.paramAnnotationOffset;
@Override
public void visitAnnotableParameterCount(int parameterCount, boolean visible) {
if (myParamAnnotationOffset != 0 && parameterCount == args.length) {
myParamAnnotationOffset = 0;
}
super.visitAnnotableParameterCount(parameterCount, visible);
}
@Override
public AnnotationVisitor visitParameterAnnotation(int parameter, String anno, boolean visible) {
AnnotationVisitor base = super.visitParameterAnnotation(parameter, anno, visible);
return checkParameterNullability(parameter + myParamAnnotationOffset, anno, base, false);
}
@Override
public AnnotationVisitor visitAnnotation(String anno, boolean isRuntime) {
if (shouldRegisterNotNull && notNullAnnotations.contains(anno)) {
result.markNotNull(name, desc);
AnnotationVisitor base = super.visitAnnotation(anno, isRuntime);
if (methodCanHaveNullability && notNullAnnotations.contains(anno)) {
return collectNotNullArgs(base, methodInfo.nullability.withNotNull(anno, ISE_CLASS_NAME));
}
return super.visitAnnotation(anno, isRuntime);
return base;
}
@Override
public AnnotationVisitor visitTypeAnnotation(int typeRef, TypePath typePath, String anno, boolean visible) {
if (shouldRegisterNotNull && new TypeReference(typeRef).getSort() == TypeReference.METHOD_RETURN && notNullAnnotations.contains(anno)) {
result.markNotNull(name, desc);
AnnotationVisitor base = super.visitTypeAnnotation(typeRef, typePath, anno, visible);
if (typePath != null) return base;
TypeReference ref = new TypeReference(typeRef);
if (methodCanHaveNullability && ref.getSort() == TypeReference.METHOD_RETURN) {
if (notNullAnnotations.contains(anno)) {
return collectNotNullArgs(base, methodInfo.nullability.withNotNull(anno, ISE_CLASS_NAME));
}
else if (seemsNullable(anno)) {
methodInfo.nullability.hasTypeUseNullable = true;
}
}
return super.visitTypeAnnotation(typeRef, typePath, anno, visible);
else if (ref.getSort() == TypeReference.METHOD_FORMAL_PARAMETER) {
return checkParameterNullability(ref.getFormalParameterIndex() + methodInfo.paramAnnotationOffset, anno, base, true);
}
return base;
}
private boolean seemsNullable(String anno) {
String shortName = getAnnoShortName(anno);
// use hardcoded short names until it causes trouble
// this is to avoid cumbersome passing of configured nullable names from the IDE
return shortName.contains("Nullable") || shortName.equals("CheckForNull");
}
private AnnotationVisitor collectNotNullArgs(AnnotationVisitor base, final NotNullState state) {
return new AnnotationVisitor(Opcodes.API_VERSION, base) {
@Override
public void visit(String methodName, Object o) {
if (ANNOTATION_DEFAULT_METHOD.equals(methodName) && !((String) o).isEmpty()) {
state.message = (String) o;
}
else if ("exception".equals(methodName) && o instanceof Type && !((Type)o).getClassName().equals(Exception.class.getName())) {
state.exceptionType = ((Type)o).getInternalName();
}
super.visit(methodName, o);
}
};
}
private AnnotationVisitor checkParameterNullability(int parameter, String anno, AnnotationVisitor av, boolean typeUse) {
if (parameter >= 0 && parameter < args.length && isReferenceType(args[parameter])) {
if (notNullAnnotations.contains(anno)) {
return collectNotNullArgs(av, methodInfo.obtainParameterNullability(parameter).withNotNull(anno, IAE_CLASS_NAME));
}
else if (typeUse && seemsNullable(anno)) {
methodInfo.obtainParameterNullability(parameter).hasTypeUseNullable = true;
}
}
return av;
}
@Override
public void visitLocalVariable(String name2, String desc, String signature, Label start, Label end, int slotIndex) {
Integer paramIndex = paramSlots.get(slotIndex);
if (paramIndex != null) {
names.put(paramIndex, name2);
methodInfo.paramNames.put(paramIndex, name2);
}
}
};
@@ -130,142 +215,66 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode
return result;
}
@Override
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
super.visit(version, access, name, signature, superName, interfaces);
myClassName = name;
myEnum = (access & ACC_ENUM) != 0;
}
@Override
public void visitInnerClass(String name, String outerName, String innerName, int access) {
super.visitInnerClass(name, outerName, innerName, access);
if (myClassName.equals(name)) {
myInner = (access & ACC_STATIC) == 0;
}
}
private static class NotNullState {
String message;
String exceptionType;
final String notNullAnno;
String notNullAnno;
boolean hasTypeUseNullable;
NotNullState(String notNullAnno, String exceptionType) {
NotNullState withNotNull(String notNullAnno, String exceptionType) {
this.notNullAnno = notNullAnno;
this.exceptionType = exceptionType;
return this;
}
boolean isNotNull() {
return notNullAnno != null && !hasTypeUseNullable;
}
String getNullParamMessage(String paramName) {
if (message != null) return message;
String shortName = getAnnoShortName();
String shortName = getAnnoShortName(notNullAnno);
if (paramName != null) return "Argument for @" + shortName + " parameter '%s' of %s.%s must not be null";
return "Argument %s for @" + shortName + " parameter of %s.%s must not be null";
}
String getNullResultMessage() {
if (message != null) return message;
String shortName = getAnnoShortName();
String shortName = getAnnoShortName(notNullAnno);
return "@" + shortName + " method %s.%s must not return null";
}
}
private String getAnnoShortName() {
String fullName = notNullAnno.substring(1, notNullAnno.length() - 1); // "Lpk/name;" -> "pk/name"
return fullName.substring(fullName.lastIndexOf('/') + 1);
}
private static String getAnnoShortName(String anno) {
String fullName = anno.substring(1, anno.length() - 1); // "Lpk/name;" -> "pk/name"
return fullName.substring(fullName.lastIndexOf('/') + 1);
}
@Override
public MethodVisitor visitMethod(int access, final String name, final String desc, String signature, String[] exceptions) {
if ((access & Opcodes.ACC_BRIDGE) != 0) {
final MethodInfo info = myMethodData.myMethodInfos.get(MethodData.key(name, desc));
if ((access & Opcodes.ACC_BRIDGE) != 0 || info == null) {
return new FailSafeMethodVisitor(Opcodes.API_VERSION, super.visitMethod(access, name, desc, signature, exceptions));
}
final boolean isStatic = isStatic(access);
final Type[] args = Type.getArgumentTypes(desc);
final int paramAnnotationOffset = !"<init>".equals(name) ? 0 : myEnum ? 2 : myInner ? 1 : 0;
final Type returnType = Type.getReturnType(desc);
final NotNullInstructionTracker instrTracker = new NotNullInstructionTracker(cv.visitMethod(access, name, desc, signature, exceptions));
return new FailSafeMethodVisitor(Opcodes.API_VERSION, instrTracker) {
private final Map<Integer, NotNullState> myNotNullParams = new LinkedHashMap<Integer, NotNullState>();
private int myParamAnnotationOffset = paramAnnotationOffset;
private NotNullState myMethodNotNull;
private Label myStartGeneratedCodeLabel;
private AnnotationVisitor collectNotNullArgs(AnnotationVisitor base, final NotNullState state) {
return new AnnotationVisitor(Opcodes.API_VERSION, base) {
@Override
public void visit(String methodName, Object o) {
if (ANNOTATION_DEFAULT_METHOD.equals(methodName) && !((String) o).isEmpty()) {
state.message = (String) o;
}
else if ("exception".equals(methodName) && o instanceof Type && !((Type)o).getClassName().equals(Exception.class.getName())) {
state.exceptionType = ((Type)o).getInternalName();
}
super.visit(methodName, o);
}
};
}
@Override
public AnnotationVisitor visitTypeAnnotation(int typeRef, TypePath typePath, String desc, boolean visible) {
AnnotationVisitor base = mv.visitTypeAnnotation(typeRef, typePath, desc, visible);
if (typePath != null) return base;
TypeReference ref = new TypeReference(typeRef);
if (ref.getSort() == TypeReference.METHOD_RETURN) {
return checkNotNullMethod(desc, base);
}
if (ref.getSort() == TypeReference.METHOD_FORMAL_PARAMETER) {
return checkNotNullParameter(ref.getFormalParameterIndex() + paramAnnotationOffset, desc, base);
}
return base;
}
@Override
public void visitAnnotableParameterCount(int parameterCount, boolean visible) {
if (myParamAnnotationOffset != 0 && parameterCount == args.length) {
myParamAnnotationOffset = 0;
}
super.visitAnnotableParameterCount(parameterCount, visible);
}
@Override
public AnnotationVisitor visitParameterAnnotation(int parameter, String anno, boolean visible) {
AnnotationVisitor base = mv.visitParameterAnnotation(parameter, anno, visible);
return checkNotNullParameter(parameter + myParamAnnotationOffset, anno, base);
}
private AnnotationVisitor checkNotNullParameter(int parameter, String anno, AnnotationVisitor av) {
if (parameter >= 0 && parameter < args.length && isReferenceType(args[parameter]) && myNotNullAnnotations.contains(anno)) {
NotNullState state = new NotNullState(anno, IAE_CLASS_NAME);
myNotNullParams.put(parameter, state);
return collectNotNullArgs(av, state);
}
return av;
}
@Override
public AnnotationVisitor visitAnnotation(String anno, boolean isRuntime) {
return checkNotNullMethod(anno, mv.visitAnnotation(anno, isRuntime));
}
private AnnotationVisitor checkNotNullMethod(String anno, AnnotationVisitor base) {
if (isReferenceType(returnType) && myNotNullAnnotations.contains(anno)) {
myMethodNotNull = new NotNullState(anno, ISE_CLASS_NAME);
return collectNotNullArgs(base, myMethodNotNull);
}
return base;
}
@Override
public void visitCode() {
if (myNotNullParams.size() > 0) {
for (Iterator<NotNullState> iterator = info.paramNullability.values().iterator(); iterator.hasNext(); ) {
if (!iterator.next().isNotNull()) {
iterator.remove();
}
}
if (info.paramNullability.size() > 0) {
myStartGeneratedCodeLabel = new Label();
mv.visitLabel(myStartGeneratedCodeLabel);
}
for (Map.Entry<Integer, NotNullState> entry : myNotNullParams.entrySet()) {
for (Map.Entry<Integer, NotNullState> entry : info.paramNullability.entrySet()) {
Integer param = entry.getKey();
int var = isStatic ? 0 : 1;
for (int i = 0; i < param; ++i) {
@@ -281,7 +290,7 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode
String descrPattern = state.getNullParamMessage(paramName);
String[] args = state.message != null
? EMPTY_STRING_ARRAY
: new String[]{paramName != null ? paramName : String.valueOf(param - paramAnnotationOffset), myClassName, name};
: new String[]{paramName != null ? paramName : String.valueOf(param - info.paramAnnotationOffset), myMethodData.myClassName, name};
reportError(state.exceptionType, end, descrPattern, args);
}
}
@@ -295,20 +304,20 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode
@Override
public void visitInsn(int opcode) {
if (opcode == ARETURN && myMethodNotNull != null && instrTracker.canBeNull()) {
if (opcode == ARETURN && instrTracker.canBeNull() && info.nullability.isNotNull()) {
mv.visitInsn(DUP);
Label skipLabel = new Label();
mv.visitJumpInsn(IFNONNULL, skipLabel);
String descrPattern = myMethodNotNull.getNullResultMessage();
String[] args = myMethodNotNull.message != null ? EMPTY_STRING_ARRAY : new String[]{myClassName, name};
reportError(myMethodNotNull.exceptionType, skipLabel, descrPattern, args);
String descrPattern = info.nullability.getNullResultMessage();
String[] args = info.nullability.message != null ? EMPTY_STRING_ARRAY : new String[]{myMethodData.myClassName, name};
reportError(info.nullability.exceptionType, skipLabel, descrPattern, args);
}
mv.visitInsn(opcode);
}
private void reportError(String exceptionClass, Label end, String descrPattern, String[] args) {
myAuxGenerator.reportError(mv, myClassName, exceptionClass, descrPattern, args);
myAuxGenerator.reportError(mv, myMethodData.myClassName, exceptionClass, descrPattern, args);
mv.visitLabel(end);
myIsModification = true;
processPostponedErrors();
@@ -352,7 +361,7 @@ public class NotNullVerifyingInstrumenter extends ClassVisitor implements Opcode
t.printStackTrace(new PrintWriter(writer));
StringBuilder text = new StringBuilder();
text.append("Operation '").append(operationName).append("' failed for ").append(myClassName).append(".").append(methodName).append("(): ");
text.append("Operation '").append(operationName).append("' failed for ").append(myMethodData.myClassName).append(".").append(methodName).append("(): ");
if (message != null) text.append(message);
text.append('\n').append(writer.getBuffer());
myPostponedError = new RuntimeException(text.toString(), cause);

View File

@@ -0,0 +1,16 @@
import org.jetbrains.annotations.*;
public class TypeUseAndMemberAnnotationsOnArrays {
public void nullableArray(@NotNull String @Nullable [] query) {}
public void notNullArray(@Nullable String @NotNull [] query) {}
public @Nullable String @NotNull [] notNullReturn() {
return null;
}
public @NotNull String @Nullable [] nullableReturn() {
return null;
}
}

View File

@@ -0,0 +1,9 @@
package org.jetbrains.annotations;
import java.lang.annotation.*;
@Retention(RetentionPolicy.CLASS)
@Target({ElementType.METHOD, ElementType.FIELD, ElementType.PARAMETER, ElementType.LOCAL_VARIABLE, ElementType.TYPE_USE})
public @interface Nullable {
String value() default "";
}

View File

@@ -0,0 +1,9 @@
package org.jetbrains.annotations;
import java.lang.annotation.*;
@Retention(RetentionPolicy.CLASS)
@Target({ElementType.TYPE_USE})
public @interface Nullable {
String value() default "";
}

View File

@@ -50,10 +50,11 @@ public abstract class NotNullVerifyingInstrumenterTest {
public static class MembersTargetTest extends NotNullVerifyingInstrumenterTest { }
@TestDirectory("types")
public static class TypesTargetTest extends NotNullVerifyingInstrumenterTest { }
public static class TypesTargetTest extends WithTypeUse { }
@TestDirectory("mixed")
public static class MixedTargetTest extends NotNullVerifyingInstrumenterTest { }
public static class MixedTargetTest extends WithTypeUse {
}
private static final String TEST_DATA_PATH = "/compiler/notNullVerification/";
@@ -64,10 +65,12 @@ public abstract class NotNullVerifyingInstrumenterTest {
public Statement apply(Statement base, Description description) {
TestDirectory annotation = description.getAnnotation(TestDirectory.class);
if (annotation == null) throw new IllegalArgumentException("Class " + description.getTestClass() + " misses @TestDirectory annotation");
File source = new File(JavaTestUtil.getJavaTestDataPath() + TEST_DATA_PATH + annotation.value() + "/NotNull.java");
if (!source.isFile()) throw new IllegalArgumentException("Cannot find annotation file at " + source);
File source = new File(JavaTestUtil.getJavaTestDataPath() + TEST_DATA_PATH + annotation.value());
if (!source.isDirectory() || source.listFiles().length == 0) throw new IllegalArgumentException("Cannot find annotation file at " + source);
classes = IoTestUtil.createTestDir("test-notNullInstrumenter-" + annotation.value());
IdeaTestUtil.compileFile(source, classes);
for (File file : source.listFiles()) {
IdeaTestUtil.compileFile(file, classes);
}
return super.apply(base, description);
}
@@ -288,6 +291,22 @@ public abstract class NotNullVerifyingInstrumenterTest {
assertEquals(1, returnType.getAnnotatedReturnType().getAnnotations().length);
}
public static abstract class WithTypeUse extends NotNullVerifyingInstrumenterTest {
@Test
public void testTypeUseAndMemberAnnotationsOnArrays() throws Exception {
Class<?> test = prepareTest();
Object instance = test.newInstance();
Object[] singleNullArg = {null};
verifyCallThrowsException("Argument 0 for @NotNull parameter of TypeUseAndMemberAnnotationsOnArrays.notNullArray must not be null", instance, test.getMethod("notNullArray", String[].class), singleNullArg);
test.getMethod("nullableArray", String[].class).invoke(instance, singleNullArg);
verifyCallThrowsException("@NotNull method TypeUseAndMemberAnnotationsOnArrays.notNullReturn must not return null", instance, test.getMethod("notNullReturn"));
assertNull(test.getMethod("nullableReturn").invoke(instance));
}
}
@Test
public void testMalformedBytecode() throws Exception {
Class<?> testClass = prepareTest();