Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion AssemblyToProcess/AssemblyToProcess.csproj
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
<?xml version="1.0" encoding="utf-8"?>
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>net472;net6.0</TargetFrameworks>
<TargetFrameworks>net472;net10.0;net11.0</TargetFrameworks>
<DisableFody>true</DisableFody>
<Nullable>enable</Nullable>
</PropertyGroup>
<PropertyGroup Condition="'$(TargetFramework)'=='net11.0'">
<EnablePreviewFeatures>true</EnablePreviewFeatures>
<Features>$(Features);runtime-async=on</Features>
<DefineConstants>$(DefineConstants);NET11_0</DefineConstants>
</PropertyGroup>
<ItemGroup>
<Using Remove="System.Net.Http" />
<ProjectReference Include="..\ConfigureAwait\ConfigureAwait.csproj" />
Expand Down
2 changes: 1 addition & 1 deletion AssemblyToProcess/CatchAndFinally.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ public async Task Finally3()
}
}

#if NETCOREAPP2_0
#if NET
public async Task Catch1_WithValueTask()
{
try
Expand Down
17 changes: 12 additions & 5 deletions AssemblyToProcess/ClassWithAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ public async Task<int> AsyncMethodWithReturn(SynchronizationContext context)
public async Task AsyncGenericMethod(SynchronizationContext context)
{
SynchronizationContext.SetSynchronizationContext(context);
await Task.Run(() => 10);
await Task.Run(async () => await Return10());
}

public async Task<int> AsyncGenericMethodWithReturn(SynchronizationContext context)
{
SynchronizationContext.SetSynchronizationContext(context);
return await Task.Run(() => 10);
return await Task.Run(async () => await Return10());
}

#if NETCOREAPP2_0
#if NET
public async Task AsyncMethod_WithValueTask(SynchronizationContext context)
{
SynchronizationContext.SetSynchronizationContext(context);
Expand All @@ -45,13 +45,20 @@ public async Task<int> AsyncMethodWithReturn_WithValueTask(SynchronizationContex
public async Task AsyncGenericMethod_WithValueTask(SynchronizationContext context)
{
SynchronizationContext.SetSynchronizationContext(context);
await new ValueTask(Task.Run(() => 10));
await new ValueTask(Task.Run(async () => await Return10()));
}

public async Task<int> AsyncGenericMethodWithReturn_WithValueTask(SynchronizationContext context)
{
SynchronizationContext.SetSynchronizationContext(context);
return await new ValueTask<int>(Task.Run(() => 10));
return await new ValueTask<int>(Task.Run(async () => await Return10()));
}
#endif

// using some more complex task than async () => 10, to make sure the method is not optimized away by the compiler, which would make the test fail;
async Task<int> Return10()
{
await Task.Delay(10).ConfigureAwait(false);
return 10;
}
}
18 changes: 13 additions & 5 deletions AssemblyToProcess/DoNotWeave.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ public async Task<int> AsyncMethodWithReturn(SynchronizationContext context)
public async Task AsyncGenericMethod(SynchronizationContext context)
{
SynchronizationContext.SetSynchronizationContext(context);
await Task.Run(() => 10).ConfigureAwait(true);
await Task.Run(async () => await Return10()).ConfigureAwait(true);
}

public async Task<int> AsyncGenericMethodWithReturn(SynchronizationContext context)
{
SynchronizationContext.SetSynchronizationContext(context);
return await Task.Run(() => 10).ConfigureAwait(true);
return await Task.Run(async () => await Return10()).ConfigureAwait(true);
}

#if NETCOREAPP2_0
#if NET
public async Task AsyncMethod_WithValueTask(SynchronizationContext context)
{
SynchronizationContext.SetSynchronizationContext(context);
Expand All @@ -45,13 +45,21 @@ public async Task<int> AsyncMethodWithReturn_WithValueTask(SynchronizationContex
public async Task AsyncGenericMethod_WithValueTask(SynchronizationContext context)
{
SynchronizationContext.SetSynchronizationContext(context);
await new ValueTask(Task.Run(() => 10)).ConfigureAwait(true);
await new ValueTask(Task.Run(async () => await Return10())).ConfigureAwait(true);
}

public async Task<int> AsyncGenericMethodWithReturn_WithValueTask(SynchronizationContext context)
{
SynchronizationContext.SetSynchronizationContext(context);
return await new ValueTask<int>(Task.Run(() => 10)).ConfigureAwait(true);
return await new ValueTask<int>(Task.Run(async () => await Return10())).ConfigureAwait(true);
}

#endif

// using some more complex task than () => 10, to make sure the method is not optimized away by the compiler, which would make the test fail;
async Task<int> Return10()
{
await Task.Delay(10).ConfigureAwait(false);
return 10;
}
}
2 changes: 1 addition & 1 deletion AssemblyToProcess/Example.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public async Task<Example> AsyncMethod12()
return result;
}

#if NETCOREAPP2_0
#if NET
public async Task AsyncMethod1_WithValueTask()
{
await new ValueTask(Task.Delay(1));
Expand Down
3 changes: 1 addition & 2 deletions AssemblyToProcess/FlagSynchronizationContext.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
public class FlagSynchronizationContext :
SynchronizationContext
public class FlagSynchronizationContext : SynchronizationContext
{
public bool Flag { get; set; }

Expand Down
4 changes: 2 additions & 2 deletions AssemblyToProcess/GenericIssue.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public async Task Method(Task<TItem> itemTask)
var item = await itemTask;
}

#if NETCOREAPP2_0
#if NET
[ConfigureAwait(false)]
public async Task Method_WithValueTask(Task<TItem> itemTask)
{
Expand All @@ -26,7 +26,7 @@ public async Task Method<TItem>(Task<TItem> itemTask)
var item = await itemTask;
}

#if NETCOREAPP2_0
#if NET
[ConfigureAwait(false)]
public async Task Method_WithValueTask<TItem>(Task<TItem> itemTask)
{
Expand Down
4 changes: 2 additions & 2 deletions AssemblyToProcess/Issue1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ async Task WithReaderAndWriter(TextWriter writer, StreamReader reader)
}
}

#if NETCOREAPP2_0
#if NET
[ConfigureAwait(false)]
async Task WithReaderAndWriter_WithValueTask(TextWriter writer, StreamReader reader)
{
while (await new ValueTask<string>(reader.ReadLineAsync()) is { } line)
while (await new ValueTask<string?>(reader.ReadLineAsync()) is { } line)
{
await new ValueTask(writer.WriteLineAsync(line));
}
Expand Down
2 changes: 1 addition & 1 deletion AssemblyToProcess/MethodWithAttribute.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ public async Task AsyncMethod(SynchronizationContext context)
await Task.Delay(0);
}

#if NETCOREAPP2_0
#if NET
[ConfigureAwait(false)]
public async Task AsyncMethod_WithValueTask(SynchronizationContext context)
{
Expand Down
2 changes: 1 addition & 1 deletion AssemblyToProcess/MethodWithUsing.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ static async Task<IDisposable> NewMethod()
return new MyDisposable();
}

#if NETCOREAPP2_0
#if NET
[ConfigureAwait(false)]
public async Task AsyncMethod_WithValueTask()
{
Expand Down
31 changes: 23 additions & 8 deletions ConfigureAwait.Fody/CecilExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,18 @@
using Mono.Cecil;
using Mono.Cecil.Cil;

enum AsyncStateMachineKind
{
None,
StateMachine,
CompilerService
}

static class CecilExtensions
{
// Not yet defined in Cecil, remove later when update is available.
const MethodImplAttributes MethodImplAttributes_Async = (MethodImplAttributes)0x2000;

public static bool IsIAsyncStateMachine(this TypeDefinition type)
{
if (type is not {HasInterfaces: true})
Expand All @@ -12,12 +22,12 @@ public static bool IsIAsyncStateMachine(this TypeDefinition type)
}

return type.Interfaces
.Any(_ => _.InterfaceType.FullName == "System.Runtime.CompilerServices.IAsyncStateMachine");
.Any(item => item.InterfaceType.FullName == "System.Runtime.CompilerServices.IAsyncStateMachine");
}

public static MethodDefinition Method(this TypeDefinition type, MethodReference reference)
{
return type.Methods.FirstOrDefault(_ => _.Name == reference.Name);
return type.Methods.FirstOrDefault(item => item.Name == reference.Name);
}

public static bool IsCompilerGenerated(this ICustomAttributeProvider provider)
Expand All @@ -39,10 +49,15 @@ public static void InsertBefore(this ILProcessor processor, Instruction target,
}
}

public static bool IsAsyncStateMachineType(this ICustomAttributeProvider provider)
public static AsyncStateMachineKind GetAsyncStateMachineKind(this MethodDefinition method)
{
return provider.CustomAttributes
.Any(a => a.AttributeType.FullName == "System.Runtime.CompilerServices.AsyncStateMachineAttribute");
if (method.CustomAttributes.Any(a => a.AttributeType.FullName == "System.Runtime.CompilerServices.AsyncStateMachineAttribute"))
return AsyncStateMachineKind.StateMachine;

if (method.ImplAttributes.HasFlag(MethodImplAttributes_Async))
return AsyncStateMachineKind.CompilerService;

return AsyncStateMachineKind.None;
}

public static TypeDefinition GetAsyncStateMachineType(this ICustomAttributeProvider provider)
Expand All @@ -53,7 +68,7 @@ public static TypeDefinition GetAsyncStateMachineType(this ICustomAttributeProvi
return (TypeDefinition)attribute?.ConstructorArguments[0].Value;
}

public static CustomAttribute GetConfigureAwaitAttribute(this ICustomAttributeProvider value)
static CustomAttribute GetConfigureAwaitAttribute(this ICustomAttributeProvider value)
{
return value.CustomAttributes.FirstOrDefault(a => a.AttributeType.FullName == "Fody.ConfigureAwaitAttribute");
}
Expand All @@ -66,11 +81,11 @@ public static CustomAttribute GetConfigureAwaitAttribute(this ICustomAttributePr
return defaultValue;
}

if (value is MethodDefinition method &&
!method.IsAsyncStateMachineType())
if (value is MethodDefinition method && method.GetAsyncStateMachineKind() == AsyncStateMachineKind.None)
{
throw new WeavingException($"ConfigureAwaitAttribute applied to non-async method '{method.FullName}'.");
}

return (bool?)attribute.ConstructorArguments[0].Value;
}

Expand Down
2 changes: 1 addition & 1 deletion ConfigureAwait.Fody/ConfigureAwait.Fody.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@
</PropertyGroup>
<ItemGroup>
<Using Remove="System.Net.Http" />
<PackageReference Include="FodyHelpers" Version="6.8.0" />
<PackageReference Include="FodyHelpers" Version="6.9.3" />
</ItemGroup>
</Project>
33 changes: 26 additions & 7 deletions ConfigureAwait.Fody/ModuleWeaver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ public override void Execute()
{
ReadConfig();

FindTypes();
FindRuntimeTypes();

var configureAwaitValue = ModuleDefinition.Assembly.GetConfigureAwaitConfig(continueOnCapturedContext);

Expand Down Expand Up @@ -51,20 +51,23 @@ void ProcessType(bool? assemblyConfigureAwaitValue, TypeDefinition type)
return;
}

var configureAwaitValue = type.GetConfigureAwaitConfig(assemblyConfigureAwaitValue);
var typeConfigureAwaitValue = type.GetConfigureAwaitConfig(assemblyConfigureAwaitValue);

foreach (var method in type.Methods)
{
var localConfigureAwaitValue = method.GetConfigureAwaitConfig(configureAwaitValue);
if (localConfigureAwaitValue == null)
{
var methodConfigureAwaitValue = method.GetConfigureAwaitConfig(typeConfigureAwaitValue);

if (!methodConfigureAwaitValue.HasValue)
continue;
}

var asyncStateMachineType = method.GetAsyncStateMachineType();
if (asyncStateMachineType != null)
{
AddAwaitConfigToAsyncMethod(asyncStateMachineType, localConfigureAwaitValue.Value);
AddAwaitConfigToAsyncMethod(asyncStateMachineType, methodConfigureAwaitValue.Value);
}
else if (method.GetAsyncStateMachineKind() == AsyncStateMachineKind.CompilerService)
{
AddAwaitConfigToAsyncMethod(method, methodConfigureAwaitValue.Value);
}
}
}
Expand Down Expand Up @@ -114,6 +117,10 @@ void TryRedirectMethodInstruction(MethodReference method, Instruction instructio
var declaringType = method.DeclaringType;
if (declaringType.FullName == "System.Threading.Tasks.Task")
{
// Only redirect GetAwaiter; constructors and other members must not be redirected
if (method.Name != "GetAwaiter")
return;

var newOperand = configuredTaskAwaitableTypeDef.Method(method);
if (newOperand != null)
{
Expand All @@ -126,6 +133,10 @@ void TryRedirectMethodInstruction(MethodReference method, Instruction instructio

if (declaringType.FullName == "System.Threading.Tasks.ValueTask")
{
// Only redirect GetAwaiter; constructors and other members must not be redirected
if (method.Name != "GetAwaiter")
return;

var newOperand = configuredValueTaskAwaitableTypeDef.Method(method);
if (newOperand != null)
{
Expand Down Expand Up @@ -168,6 +179,10 @@ void TryRedirectMethodInstruction(MethodReference method, Instruction instructio
// Change Task`1 to ConfiguredTaskAwaitable`1
if (declaringType.FullName.StartsWith("System.Threading.Tasks.Task`1"))
{
// Only redirect GetAwaiter; other members (including constructors) must not be redirected
if (method.Name != "GetAwaiter")
return;

var newOperand = genericConfiguredTaskAwaitableTypeDef.Method(method);
if (newOperand != null)
{
Expand All @@ -183,6 +198,10 @@ void TryRedirectMethodInstruction(MethodReference method, Instruction instructio
// Change Task`1 to ConfiguredTaskAwaitable`1
if (declaringType.FullName.StartsWith("System.Threading.Tasks.ValueTask`1"))
{
// Only redirect GetAwaiter; other members (including constructors) must not be redirected
if (method.Name != "GetAwaiter")
return;

var newOperand = genericConfiguredValueTaskAwaitableTypeDef.Method(method);
if (newOperand != null)
{
Expand Down
Loading