Finding Extensions

Several years ago I was spelunking through the Entity Framework code and came across the implementation of the Include extension method. That method allows a query against a data context to include a related model in the result set. What I remember, though, was how the Include method did what it did. At runtime, it actually went out and searched for other extension methods, and then deferred to those methods for the bulk of the operation.

That got me thinking about ways I might use the approach in situations where I couldn’t hard-code things. For instance, I had a notion of using configuration to drive the setup of an application’s standard services (things like logging, email, repositories, etc.). In that scenario, with everything driven from configuration, I knew I would never be able to call any downstream extension methods directly. The entire process would have to be completely data driven.

The only thing is, the method I saw in Microsoft’s code wasn’t particularly performant. I knew I would have to be smarter about how I went about finding extension methods, or the process would be too slow to do anyone any good.

Fast forward to today. I have the approach encapsulated in an extension method of my own. After trying various approaches, I finally decided to hang the logic off the AppDomain type. Today, I use this method in several of my NuGet packages, and the performance is better because I’ve learned a few tricks along the way.

I thought I might cover this method in today’s blog post.

I’ll start by showing the listing for the method, then I’ll walk through it, step by step:

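using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
// Guard, plus the IsMatch, ApplyWhiteList and ApplyBlackList extensions,
// come from the CG.Core package. Like any extension method, this one
// lives in a non-generic static class.

public static IEnumerable<MethodInfo> ExtensionMethods(
    this AppDomain appDomain,
    Type extensionType,
    string methodName,
    Type[] parameterTypes = null,
    string assemblyWhiteList = "",
    string assemblyBlackList = "Microsoft*, System*, mscorlib, netstandard"
    )
{
    Guard.Instance().ThrowIfNull(appDomain, nameof(appDomain))
        .ThrowIfNull(extensionType, nameof(extensionType))
        .ThrowIfNullOrEmpty(methodName, nameof(methodName));

    if (parameterTypes == null)
    {
        parameterTypes = Type.EmptyTypes;
    }

    // Search the AppDomain we were called on.
    var assemblies = appDomain.GetAssemblies().ToList();

    if (!string.IsNullOrEmpty(assemblyWhiteList))
    {
        // White list entries that match NO loaded assembly still need
        // loading. Entries are trimmed, since lists read like "A, B, C".
        var toLoad = assemblyWhiteList.Split(',')
            .Select(x => x.Trim())
            .Where(x => x.Length > 0 &&
                !assemblies.Any(y => y.GetName().Name.IsMatch(x)))
            .ToList();

        foreach (var x in toLoad)
        {
            var files = Directory.GetFiles(
                appDomain.BaseDirectory,
                x.EndsWith(".dll", StringComparison.OrdinalIgnoreCase)
                    ? x : $"{x}.dll"
                );
            foreach (var file in files)
            {
                try
                {
                    Assembly.LoadFile(file);
                }
                catch (Exception)
                {
                    // Don't care, just won't search this file.
                }
            }
        }

        if (toLoad.Any())
        {
            // Loading changed the AppDomain's list, so read it again.
            assemblies = appDomain.GetAssemblies().ToList();
        }

        // Apply the white list even if nothing new had to be loaded.
        assemblies = assemblies.ApplyWhiteList(
            x => x.GetName().Name, assemblyWhiteList
            ).ToList();
    }

    if (!string.IsNullOrEmpty(assemblyBlackList))
    {
        assemblies = assemblies.ApplyBlackList(
            x => x.GetName().Name, assemblyBlackList
            ).ToList();
    }

    // List<T> isn't thread-safe, so the parallel loop below adds its
    // results to a ConcurrentBag instead.
    var methods = new ConcurrentBag<MethodInfo>();
    var options = new ParallelOptions()
    {
#if DEBUG
        MaxDegreeOfParallelism = 1 // <-- to make debugging easier.
#else
        MaxDegreeOfParallelism = Environment.ProcessorCount
#endif
    };

    // Expected signature: the extended type, then any extra parameters.
    // It's the same for every candidate, so build it once.
    var rhs = new[] { extensionType }.Concat(parameterTypes).ToArray();

    Parallel.ForEach(assemblies, options, (assembly) =>
    {
        Type[] allTypes;
        try
        {
            allTypes = assembly.GetTypes();
        }
        catch (ReflectionTypeLoadException ex)
        {
            // Some types couldn't load; search the ones that did.
            allTypes = ex.Types.Where(t => t != null).ToArray();
        }

        // Static classes compile to abstract sealed types, so this keeps
        // every public top-level class that could hold extension methods.
        var types = allTypes.Where(x =>
            x.IsClass && x.IsPublic &&
            x.IsSealed && !x.IsNested &&
            !x.IsGenericType
            );

        foreach (var type in types)
        {
            var typeMethods = type.GetMethods(
                BindingFlags.Static |
                BindingFlags.Public
                ).Where(x =>
                    x.Name == methodName &&
                    x.IsDefined(typeof(ExtensionAttribute), false) &&
                    !x.ContainsGenericParameters
                    );

            foreach (var method in typeMethods)
            {
                var lhs = method.GetParameters()
                    .Select(x => x.ParameterType)
                    .ToArray();

                if (lhs.Length != rhs.Length)
                {
                    continue;
                }

                var shouldAdd = true;
                for (int z = 0; z < lhs.Length; z++)
                {
                    // Each declared parameter must accept the supplied type.
                    if (!lhs[z].IsAssignableFrom(rhs[z]))
                    {
                        shouldAdd = false;
                        break;
                    }
                }
                if (shouldAdd)
                {
                    methods.Add(method);
                    break; // Keep only the first matching overload per type.
                }
            }
        }
    });
    return methods.ToList();
}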

The method starts by validating the incoming parameters. Next, I get a list of all the assemblies that are loaded into the AppDomain. For a typical .NET application, this list can grow to quite a few assemblies. If I were to iterate through all of them, it would slow everything down. And, really, there’s no reason to do that, since most of the assemblies in the list are standard .NET assemblies that will likely never contain whatever extension method(s) I’m looking for.

That was one of the tricks I learned early on: filter the list before I start iterating through it. I thought of various ways to perform that filtering and finally settled on a white list and a black list. The method applies the white list first, assuming one was passed in through the assemblyWhiteList parameter.

Assuming a white list was supplied, I first split it into an array, which just makes the information easier to work with. After splitting, I look for any entries in the white list that don’t match an assembly already in the AppDomain’s list. That gives me a list of assemblies to load. Assuming that list isn’t empty, I then iterate through it and do my best to load each assembly from the application’s base directory, silently skipping any file that won’t load.
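
Here’s a small standalone sketch of that "not already loaded" test, assuming comma-separated entries that may use * or ? wildcards. WhiteListSketch, EntriesToLoad and WildcardMatch are hypothetical names; WildcardMatch is only a stand-in for CG.Core’s IsMatch, whose exact matching rules may differ. The key detail is that the negation wraps Any: an entry needs loading only when no loaded assembly name matches it.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

public static class WhiteListSketch
{
    // Hypothetical stand-in for CG.Core's IsMatch: '*' matches any run of
    // characters, '?' matches one character, and case is ignored.
    public static bool WildcardMatch(string name, string pattern)
    {
        var regex = "^" + Regex.Escape(pattern)
            .Replace("\\*", ".*")
            .Replace("\\?", ".") + "$";
        return Regex.IsMatch(name, regex, RegexOptions.IgnoreCase);
    }

    // Returns the white list entries that match none of the loaded
    // assembly names; those are the ones worth looking for on disk.
    public static List<string> EntriesToLoad(
        IEnumerable<string> loadedNames, string whiteList)
    {
        var loaded = loadedNames.ToList();
        return whiteList.Split(',')
            .Select(x => x.Trim())
            .Where(x => x.Length > 0)
            .Where(x => !loaded.Any(name => WildcardMatch(name, x)))
            .ToList();
    }
}
```

For example, with MyApp.Core already loaded, the white list "MyApp.Core, MyPlugins*" produces just MyPlugins*, which the listing would then turn into a MyPlugins*.dll file search.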

After that, the AppDomain’s list of assemblies has changed, so I read it again and then call the ApplyWhiteList extension method. That method is part of the CG.Core NuGet package. For now, just know that it applies the white list, which essentially ensures that the resulting list includes anything on the white list, as well as everything that was on the AppDomain’s list.

The next step is to apply the black list. Assuming one was specified, I call the ApplyBlackList extension method to pare away anything on the black list from the AppDomain’s assembly list. ApplyBlackList is also part of the CG.Core NuGet package. For now, just know that it trims the assembly list by removing anything that matches an entry on the black list.

The result, after those two steps, is that the assembly list contains everything on the white list (if one was specified) and nothing on the black list (if one was specified). This filtering is effective enough that I gave the assemblyBlackList parameter a default value that excludes any assembly whose name starts with ‘Microsoft’ or ‘System’, along with mscorlib and netstandard. My thinking is that it’s pretty unlikely I’ll ever need to look in any of those assemblies for an extension method. If I do, of course, I can always change the black list. My point is, with the white list / black list approach, I can quickly narrow a long list of loaded assemblies down to a small handful.
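
To see how much the default black list trims, here’s a sketch that filters assembly names against it. BlackListSketch, FilterBlackList and MeasureCurrentDomain are hypothetical names, and the wildcard helper is only an approximation of CG.Core’s matching, not its actual implementation.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;

public static class BlackListSketch
{
    // Same value as the assemblyBlackList default in the listing.
    public const string DefaultBlackList =
        "Microsoft*, System*, mscorlib, netstandard";

    // Hypothetical wildcard matcher: '*' and '?' wildcards, whole-name
    // match, case-insensitive.
    static bool WildcardMatch(string name, string pattern)
    {
        var regex = "^" + Regex.Escape(pattern)
            .Replace("\\*", ".*")
            .Replace("\\?", ".") + "$";
        return Regex.IsMatch(name, regex, RegexOptions.IgnoreCase);
    }

    // Keeps only the names that match none of the black list entries.
    public static List<string> FilterBlackList(
        IEnumerable<string> names, string blackList)
    {
        var patterns = blackList.Split(',')
            .Select(x => x.Trim())
            .Where(x => x.Length > 0)
            .ToList();
        return names
            .Where(name => !patterns.Any(p => WildcardMatch(name, p)))
            .ToList();
    }

    // Counts the current AppDomain's assemblies before and after filtering.
    public static (int Before, int After) MeasureCurrentDomain()
    {
        var names = AppDomain.CurrentDomain.GetAssemblies()
            .Select(a => a.GetName().Name)
            .ToList();
        return (names.Count, FilterBlackList(names, DefaultBlackList).Count);
    }
}
```

Because most of what a typical app loads is System.* or Microsoft.*, After is usually much smaller than Before, though the exact numbers depend on the runtime and on what has been loaded so far.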

For the next part, I decided to search the filtered list of assemblies in parallel, to improve the overall runtime. I do that using the Parallel.ForEach method. For each assembly, I first get a list of all the public, sealed, non-nested, non-generic class types (the compiler emits a static class as an abstract, sealed type, so static classes pass this filter). Then I iterate through that list of types. For each type, I look for methods that are static, public, decorated with the ExtensionAttribute attribute (the C# compiler adds that attribute to extension methods), and, of course, have a name that matches whatever method I’m looking for. If I find any methods that way, I iterate through them and try to match the parameters: the method must have exactly one more parameter than the parameterTypes array, its first parameter must accept extensionType, and each remaining parameter must accept the corresponding entry in parameterTypes. If a method matches all those things, I add it to the list and move on to the next type.
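
Here’s a trimmed-down, self-contained sketch of that search loop, including the parameter check. SampleExtensions, SearchSketch, Find and ParametersMatch are hypothetical names for illustration. Two details are worth calling out: results from Parallel.ForEach go into a ConcurrentBag, because List<T> isn’t safe to add to from several threads at once, and a ReflectionTypeLoadException from GetTypes falls back to the types that did load. Unlike the listing, this sketch keeps every matching overload instead of stopping at the first one per type.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

// Hypothetical extension methods to search for.
public static class SampleExtensions
{
    public static string Shout(this string value) => value.ToUpper() + "!";
    public static string Shout(this string value, int times) =>
        string.Concat(Enumerable.Repeat(value.ToUpper() + "!", times));
}

public static class SearchSketch
{
    // True when each declared parameter can accept the supplied type at the
    // same position: the extended type first, then the extra arguments.
    public static bool ParametersMatch(
        MethodInfo method, Type extensionType, Type[] parameterTypes)
    {
        var declared = method.GetParameters()
            .Select(p => p.ParameterType).ToArray();
        var supplied = new[] { extensionType }.Concat(parameterTypes).ToArray();
        if (declared.Length != supplied.Length) return false;
        for (var i = 0; i < declared.Length; i++)
        {
            if (!declared[i].IsAssignableFrom(supplied[i])) return false;
        }
        return true;
    }

    public static List<MethodInfo> Find(IEnumerable<Assembly> assemblies,
        Type extensionType, string methodName, Type[] parameterTypes)
    {
        // ConcurrentBag is safe to add to from several threads at once.
        var found = new ConcurrentBag<MethodInfo>();
        Parallel.ForEach(assemblies, assembly =>
        {
            Type[] types;
            try { types = assembly.GetTypes(); }
            catch (ReflectionTypeLoadException ex)
            {
                types = ex.Types.Where(t => t != null).ToArray();
            }
            foreach (var type in types.Where(t => t.IsClass && t.IsPublic &&
                t.IsSealed && !t.IsNested && !t.IsGenericType))
            {
                foreach (var method in type.GetMethods(
                    BindingFlags.Static | BindingFlags.Public))
                {
                    if (method.Name == methodName &&
                        method.IsDefined(typeof(ExtensionAttribute), false) &&
                        !method.ContainsGenericParameters &&
                        ParametersMatch(method, extensionType, parameterTypes))
                    {
                        found.Add(method);
                    }
                }
            }
        });
        return found.ToList();
    }
}
```

With an empty parameterTypes array, Find returns only the one-parameter Shout; passing new[] { typeof(int) } returns only the two-parameter overload.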

Once the parallel search is done, I end up with a list of MethodInfo objects, each representing an extension method that is a pretty darned good candidate for what I’m looking for. Since this method only searches for the methods, and doesn’t need to actually call them, I’m done at this point. The method ends by returning the list of MethodInfo objects.

So how might someone use this to find an extension method? Here’s a quick example:

using System;
using System.Linq;

public static class X 
{
   public static string MyMethod(this string a)
   {
      return a + " called from MyMethod";
   }
}

public class Test
{
   public static void PerformTest()
   {
      var a = "hi there";
      var methods = AppDomain.CurrentDomain.ExtensionMethods(
         typeof(string), // extension method hangs off the string type.
         "MyMethod"      // looking for something called MyMethod.
      );
      // Extension methods are static, so there's no target instance (null);
      // the extended value is passed as the first argument.
      var b = (string)methods.First().Invoke(null, new object[] { a });
      // b should contain: "hi there called from MyMethod"
   }
}

So, in some assembly, add your X class, with the MyMethod extension method. Then, in your calling application or library, call the ExtensionMethods method on the current AppDomain instance, supplying whatever parameters make sense for you. (If the assembly holding X isn’t already loaded, name it in the white list so the method can load it from the application’s base directory.) Take the resulting list of MethodInfo objects, choose one, and invoke it. The end result is exactly as if you had called the MyMethod extension method directly, but without creating any compile-time link between your calling application and whatever assembly you put the MyMethod extension method into.
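
MethodInfo.Invoke goes through reflection on every call. If the chosen method will be called many times, one option (not something the listing above does) is to convert the MethodInfo into a strongly typed delegate once, with MethodInfo.CreateDelegate. In this sketch the MethodInfo comes from GetMethod only to keep it self-contained; in practice it would come from ExtensionMethods. DelegateSketch and its methods are hypothetical names.

```csharp
using System;
using System.Reflection;

public static class X
{
    public static string MyMethod(this string a)
    {
        return a + " called from MyMethod";
    }
}

public static class DelegateSketch
{
    // Turns a static (extension) method into a typed delegate. The extended
    // value simply becomes the delegate's first argument.
    public static Func<string, string> ToFunc(MethodInfo method)
    {
        return (Func<string, string>)method.CreateDelegate(
            typeof(Func<string, string>));
    }

    public static string Run()
    {
        // Stand-in for picking one result from ExtensionMethods(...).
        MethodInfo method = typeof(X).GetMethod(nameof(X.MyMethod));
        var myMethod = ToFunc(method);
        return myMethod("hi there");
    }
}
```

CreateDelegate throws an ArgumentException if the delegate type doesn’t match the method’s signature, so the delegate type has to agree with the parameterTypes used in the search.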

I also added two generic overloads of ExtensionMethods: one that accepts a single generic parameter, and another that accepts two. Those overloads allow us to search for matching generic extension methods; otherwise, they work pretty much exactly like this method does. I wrote them to support the generic loader classes, for repository and strategy types, in my CG.Business NuGet package. I’ll blog about those one of these days, soon.
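
The generic overloads aren’t shown in this post, so the following is only one plausible way to handle a generic extension method, not necessarily how those overloads work. The main listing skips methods with ContainsGenericParameters because an open generic method can’t be invoked until its type arguments are supplied. This sketch finds the generic method definition and closes it with MethodInfo.MakeGenericMethod. GenericSamples, Describe, GenericSketch and CloseGeneric are hypothetical names.

```csharp
using System;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;

// A hypothetical generic extension method to search for.
public static class GenericSamples
{
    public static string Describe<T>(this T value) =>
        $"{typeof(T).Name}: {value}";
}

public static class GenericSketch
{
    // Finds a one-type-parameter extension method named methodName on
    // sourceType and closes it over typeArgument.
    public static MethodInfo CloseGeneric(
        Type sourceType, string methodName, Type typeArgument)
    {
        var definition = sourceType
            .GetMethods(BindingFlags.Static | BindingFlags.Public)
            .First(m =>
                m.Name == methodName &&
                m.IsGenericMethodDefinition &&
                m.GetGenericArguments().Length == 1 &&
                m.IsDefined(typeof(ExtensionAttribute), false));

        // Substitutes typeArgument for T; the closed method can be invoked.
        return definition.MakeGenericMethod(typeArgument);
    }
}
```

Once closed, the MethodInfo has no remaining generic parameters, so it can be invoked just like the non-generic results from ExtensionMethods.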


All the code for the ExtensionMethods method is available on GitHub HERE. The CG.Reflection package itself is available HERE.

Thanks for reading!

Photo by JOSHUA COLEMAN on Unsplash