Don't fall into an interface extension loop!

🕳️🚶 ☛ 🗬🕳️ ☛ 🚧🕳️🚧 ☛ 🔁

 January 31, 2018


Extension methods are a great feature of C# (and other .NET languages), but using them on interfaces recently caused me some problems. I made an assumption about how the object would be "seen" at runtime and ended up in a loop. A quick fix was to cast the interface instance to a known type, but this ultimately wouldn't scale, since I needed to check the type manually at runtime and build either if/else or switch logic. Here's a modified form of the initial problem and the solution.

Creating a union of two objects

I'm working to release an Alexa library for .NET and needed a way to combine two or more response representations. These representations implement an interface whose property types are themselves interfaces, and that's where the problem begins. The derived representations don't (or shouldn't) know what those property types are, just that they need to be able to union them. Additionally, this "union" behavior is not part of the "core" package, so everything is implemented as extension methods.

Simplified version of the various objects involved:

using System;

public interface IResponse
{
    IOutput Output { get; }
}

public interface IOutput
{
    string Text { get; }
}

public class Response : IResponse
{
    public IOutput Output { get; set; }
}

public static class ResponseExtensions
{
    public static IResponse Union(this Response value, Response response)
    {
        return new Response()
        {
            Output = value.Output.Union(response.Output)
        };
    }
}

public class Output : IOutput
{
    public string Text { get; set; }
}

public static class OutputExtensions
{
    public static Output Union(this Output value, Output output)
    {
        return new Output()
        {
            Text = value.Text + output.Text
        };
    }
}

// This is the class we will address in this post.
public static class IOutputExtensions
{
    public static IOutput Union(this IOutput value, IOutput output)
    {
        return value.Union(output);
    }
}

public class Program
{
    public static void Main()
    {
        var response1 = new Response()
        {
            Output = new Output()
            {
                Text = "Hello, "
            }
        };

        var response2 = new Response()
        {
            Output = new Output()
            {
                Text = "World."
            }
        };

        var finalResponse = response1.Union(response2);

        Console.WriteLine(finalResponse.Output.Text);

        // Expected output: "Hello, World."
        // Instead, `IOutputExtensions.Union` keeps calling itself until the stack overflows.
    }
}

What I expected to happen ⚗️

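When Union is called on a Response, it also calls Union on the IOutput property. I knew this would call the IOutputExtensions.Union extension method. But I also expected that, within this method, the derived IOutput type, Output, would get resolved at runtime and OutputExtensions.Union would get called. That's not how extension methods work: they are static methods, and the compiler picks which one to call from the compile-time type of the receiver, not from the runtime type of the object. Inside IOutputExtensions.Union, value is declared as IOutput, so value.Union(output) binds to IOutputExtensions.Union itself, and the method keeps calling itself until the stack overflows.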
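Extension methods are resolved at compile time from the static type of the receiver expression, not from the object's runtime type. Here's a minimal sketch of that, reusing the post's IOutput and Output types with a hypothetical Describe extension method standing in for the two Union overloads: the same object reaches a different extension method depending on the type of the variable it's accessed through.

```csharp
using System;

public interface IOutput { string Text { get; } }

public class Output : IOutput { public string Text { get; set; } }

// Hypothetical extension methods standing in for the two Union overloads.
public static class IOutputExtensions
{
    public static string Describe(this IOutput value) => "IOutputExtensions.Describe";
}

public static class OutputExtensions
{
    public static string Describe(this Output value) => "OutputExtensions.Describe";
}

public static class Program
{
    public static void Main()
    {
        var concrete = new Output { Text = "Hello" };
        IOutput viaInterface = concrete; // same object, different compile-time type

        // The compiler binds each call using the declared type of the receiver.
        Console.WriteLine(concrete.Describe());     // OutputExtensions.Describe
        Console.WriteLine(viaInterface.Describe()); // IOutputExtensions.Describe

        // The runtime type is still Output; it just isn't consulted when binding.
        Console.WriteLine(viaInterface.GetType().Name); // Output
    }
}
```

For the concrete variable both methods apply, and the Output overload wins because an identity conversion beats a conversion to the interface. Through the IOutput variable, only the interface overload applies.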

Testing the loop

We can update the IOutputExtensions.Union method to throw an exception if it gets called more than once. In our simple example, a second call means we are stuck in a loop:

public static class IOutputExtensions
{
    private static int count = 0;

    public static IOutput Union(this IOutput value, IOutput output)
    {
        if (++count > 1)
        {
            throw new Exception("We are in a loop!");
        }

        return value.Union(output);
    }
}

Simple fix 🔨

As I alluded to, a simple fix is to check the type at runtime and call the extension method of the derived type. But this means every new derived type would need to be added here to prevent things from breaking. It's just a ticking 💣:

public static class IOutputExtensions
{
    public static IOutput Union(this IOutput value, IOutput output)
    {
        if (value.GetType() == typeof(Output))
        {
            return (value as Output).Union(output as Output);
        }
        else
        {
            // Any other IOutput type still binds to this same method and recurses.
            return value.Union(output);
        }
    }
}

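Yes, both value and output need to be cast. If only value is cast, output is still typed as IOutput, so OutputExtensions.Union (which takes two Output arguments) isn't applicable; the compiler picks IOutputExtensions.Union instead, and you're back in loop-mode. This won't work:
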
public static class IOutputExtensions
{
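    public static IOutput Union(this IOutput value, IOutput output)
    {
        if (value.GetType() == typeof(Output))
        {
            // ⚠️ DON'T DO THIS! `output` is still an IOutput, so this binds to IOutputExtensions.Union again.
            return (value as Output).Union(output);
        }

        return value.Union(output);
    }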
}
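The if/else or switch logic from the intro can also be written with C# 7 pattern matching, which type-tests and casts both arguments in one step, so the call can only bind to OutputExtensions.Union. This is a sketch reusing the post's IOutput, Output, and OutputExtensions types; unlike the version above, it throws NotSupportedException for unknown types instead of falling back to the recursive call. It's still a ticking bomb, because every new IOutput type needs another branch.

```csharp
using System;

public interface IOutput { string Text { get; } }

public class Output : IOutput { public string Text { get; set; } }

public static class OutputExtensions
{
    public static Output Union(this Output value, Output output) =>
        new Output { Text = value.Text + output.Text };
}

public static class IOutputExtensions
{
    public static IOutput Union(this IOutput value, IOutput output)
    {
        // Both arguments must be typed as Output so the compiler binds to OutputExtensions.Union.
        if (value is Output left && output is Output right)
        {
            return left.Union(right);
        }

        // Every new IOutput type needs another branch above; failing loudly beats recursing.
        throw new NotSupportedException(
            $"No Union implementation for {value?.GetType().Name} and {output?.GetType().Name}.");
    }
}

public static class Program
{
    public static void Main()
    {
        IOutput hello = new Output { Text = "Hello, " };
        IOutput world = new Output { Text = "World." };

        // IOutput receiver binds to IOutputExtensions.Union, which dispatches once to OutputExtensions.Union.
        Console.WriteLine(hello.Union(world).Text); // Hello, World.
    }
}
```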

Better fix 🛠️

Since the Union logic is implemented as extension methods, we can't use the regular way of getting the method at runtime. An extension method is a static method on a separate static class, not a member of the object's type, so value.GetType().GetMethod("Union", BindingFlags.Public | BindingFlags.Instance) returns null. Instead, we need to search for static methods that are marked as extension methods:

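using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

// The containing class name is illustrative; any public static class works.
public static class ObjectExtensions
{
    public static IEnumerable<MethodInfo> GetExtensionMethods(this object value, Assembly assembly)
    {
        // `assembly` is a parameter since I want to make sure my extension implementations are used.
        // However, someone else implementing their own version could also pass in their assembly.
        assembly = assembly ?? throw new ArgumentNullException(nameof(assembly));

        return
            from type in assembly.GetTypes()

            // We now use the `Static`-flag instead of `Instance`.
            from method in type.GetMethods(BindingFlags.Static | BindingFlags.Public)

            // The compiler marks extension methods with [Extension]; their first parameter is the `this` one.
            where method.IsDefined(typeof(ExtensionAttribute))
            where method.GetParameters().First().ParameterType == value.GetType()
            select method;
    }
}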
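To see why GetExtensionMethods has to look at static methods marked with ExtensionAttribute, here's a small sketch reusing the post's Output and OutputExtensions types. The instance lookup on Output finds nothing, while the static Union on OutputExtensions carries the attribute and can be invoked through reflection with a null target, which is what the final version relies on.

```csharp
using System;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

public interface IOutput { string Text { get; } }

public class Output : IOutput { public string Text { get; set; } }

public static class OutputExtensions
{
    public static Output Union(this Output value, Output output) =>
        new Output { Text = value.Text + output.Text };
}

public static class Program
{
    public static void Main()
    {
        // The extension method isn't a member of Output, so an instance lookup finds nothing.
        MethodInfo onType = typeof(Output).GetMethod("Union", BindingFlags.Public | BindingFlags.Instance);
        Console.WriteLine(onType == null); // True

        // It's a public static method on OutputExtensions, marked with [Extension] by the compiler.
        MethodInfo onExtensions = typeof(OutputExtensions).GetMethod("Union", BindingFlags.Public | BindingFlags.Static);
        Console.WriteLine(onExtensions.IsDefined(typeof(ExtensionAttribute))); // True
        Console.WriteLine(onExtensions.GetParameters().First().ParameterType == typeof(Output)); // True

        // Static method: the target is null and the `this` argument is passed first.
        var result = (Output)onExtensions.Invoke(
            null, new object[] { new Output { Text = "Hello, " }, new Output { Text = "World." } });
        Console.WriteLine(result.Text); // Hello, World.
    }
}
```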

And now, the final version of the interface extension method:

using System;
using System.Linq;

public static class IOutputExtensions
{
    public static IOutput Union(this IOutput value, IOutput output)
    {
        var unionMethod = value.GetExtensionMethods(typeof(IOutputExtensions).Assembly)
            .FirstOrDefault(
                // We need to do some additional filtering to make sure it's exactly the method we want to use.
                // This method itself is skipped, since its first parameter is IOutput, not the runtime type.
                x => x.Name == nameof(Union) &&
                x.GetParameters().Length == 2 &&
                x.GetParameters().All(y => y.ParameterType == value.GetType()));

        // Extension methods are static, so the target is null and `value` is passed as the first argument.
        return unionMethod?.Invoke(null, new object[] { value, output }) as IOutput
            ?? throw new ArgumentException(
                $"Does not implement expected `{ nameof(Union) }` method.",
                nameof(value));
    }
}
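A self-contained sketch that exercises the final approach end to end. It condenses the post's types and adds a hypothetical second IOutput type, SsmlOutput, with a hypothetical SsmlOutputExtensions.Union that joins with a space, to check that IOutputExtensions.Union dispatches to whichever Union matches the runtime type without being edited. ReflectionExtensions is an illustrative name for the class holding GetExtensionMethods.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

public interface IOutput { string Text { get; } }

public class Output : IOutput { public string Text { get; set; } }

// Hypothetical second IOutput type, added without touching IOutputExtensions.
public class SsmlOutput : IOutput { public string Text { get; set; } }

public static class OutputExtensions
{
    public static Output Union(this Output value, Output output) =>
        new Output { Text = value.Text + output.Text };
}

public static class SsmlOutputExtensions
{
    // Hypothetical: joins with a space instead of plain concatenation.
    public static SsmlOutput Union(this SsmlOutput value, SsmlOutput output) =>
        new SsmlOutput { Text = value.Text + " " + output.Text };
}

public static class ReflectionExtensions
{
    // Same lookup as the post's GetExtensionMethods (null check omitted for brevity).
    public static IEnumerable<MethodInfo> GetExtensionMethods(this object value, Assembly assembly) =>
        from type in assembly.GetTypes()
        from method in type.GetMethods(BindingFlags.Static | BindingFlags.Public)
        where method.IsDefined(typeof(ExtensionAttribute))
        where method.GetParameters().First().ParameterType == value.GetType()
        select method;
}

public static class IOutputExtensions
{
    public static IOutput Union(this IOutput value, IOutput output)
    {
        var unionMethod = value.GetExtensionMethods(typeof(IOutputExtensions).Assembly)
            .FirstOrDefault(x =>
                x.Name == nameof(Union) &&
                x.GetParameters().Length == 2 &&
                x.GetParameters().All(y => y.ParameterType == value.GetType()));

        return unionMethod?.Invoke(null, new object[] { value, output }) as IOutput
            ?? throw new ArgumentException(
                $"Does not implement expected `{nameof(Union)}` method.", nameof(value));
    }
}

public static class Program
{
    public static void Main()
    {
        IOutput hello = new Output { Text = "Hello, " };
        IOutput world = new Output { Text = "World." };
        Console.WriteLine(hello.Union(world).Text); // Hello, World.

        IOutput ssmlHello = new SsmlOutput { Text = "Hello," };
        IOutput ssmlWorld = new SsmlOutput { Text = "World." };
        Console.WriteLine(ssmlHello.Union(ssmlWorld).Text); // Hello, World.
    }
}
```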

🏁 Conclusions 🏁

It's a bit more work, but now I can implement any number of IOutput versions without touching IOutputExtensions again, as long as each one has its own Union extension method in the assembly being searched. And extension methods are 👍! But sometimes, using them on interfaces can lead to unexpected results ⚡⚡.