Wednesday, 30 January 2008

Linq Data Context Extension Methods

I've been working on a project that uses a lot of LINQ lately and came across a problem: the objects we were creating had unique integer IDs, and each new object in the system was given a decrementing ID (-1, -2, -3 and so on) for each new instance of its type.
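To make the ID scheme concrete, here is a minimal sketch of how such temporary IDs could be handed out. The TemporaryId<T> class and the Customer and Order types are hypothetical names invented for this illustration, not code from the project; the sketch relies on the fact that a static field in a generic class is kept separately for each type argument, which gives one counter per object type.

```csharp
using System;
using System.Threading;

// Hypothetical entity types, used only to show one counter per type.
public class Customer { }
public class Order { }

// Hypothetical helper (not from the original project): issues temporary
// negative ids for objects of type T that have not been saved yet.
public static class TemporaryId<T>
{
    // One copy of this field exists per T; it starts at 0,
    // so the first id issued for each type is -1.
    private static int _last;

    // Returns -1, -2, -3, ... and is safe to call from several threads.
    public static int Next()
    {
        return Interlocked.Decrement(ref _last);
    }

    // A negative id marks an object that exists only in memory.
    public static bool IsNew(int id)
    {
        return id < 0;
    }
}

public static class Program
{
    public static void Main()
    {
        Console.WriteLine(TemporaryId<Customer>.Next());    // -1
        Console.WriteLine(TemporaryId<Customer>.Next());    // -2
        Console.WriteLine(TemporaryId<Order>.Next());       // -1 (separate counter)
        Console.WriteLine(TemporaryId<Customer>.IsNew(42)); // False
    }
}
```

The sign of the ID is what lets you tell a new object from a saved one without asking the database.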

After working with LINQ for a while it became apparent that when you call the InsertOnSubmit method (which was called Add in the Beta 2 version, hence the confusion) the new object does not show up when you query the data context's table. It is only held as a pending insert, which you can see in the ChangeSet returned by the data context's GetChangeSet method.

We therefore needed a way to easily find objects within the data context whether they were new, modified or unchanged. I created a data context extension helper which allows the user to specify the type of object to be found and a search function (using the Func<T,TResult> delegate) to indicate which objects to search for.

The code for the DataContextExtension class is shown below:

using System;
using System.Collections.Generic;
using System.Data.Linq;
using System.Linq;

public static class DataContextExtension
{
    /// <summary>
    /// Finds a linq object using the specified search function.
    /// </summary>
    /// <typeparam name="TLinq">The type of linq object to be returned.</typeparam>
    /// <param name="context">The data context in which to search for the entity.</param>
    /// <param name="searchFunction">The search condition to meet in order to find the linq object.</param>
    /// <returns>
    /// The first or default linq object found by the search; otherwise null.
    /// </returns>
    /// <remarks>
    /// This method should be used when it is not known whether the linq object you are
    /// looking for is new or existing. It searches the new list first; if that returns
    /// null, the existing list is searched.
    /// </remarks>
    public static TLinq Find<TLinq>(this DataContext context, Func<TLinq, bool> searchFunction)
        where TLinq : class
    {
        return Find<TLinq>(context, searchFunction, true) ??
               Find<TLinq>(context, searchFunction, false);
    }

    /// <summary>
    /// Finds a linq object using an entity id.
    /// </summary>
    /// <typeparam name="TLinq">The type of linq object to be returned.</typeparam>
    /// <param name="context">The data context in which to search for the entity.</param>
    /// <param name="searchFunction">The search condition to meet in order to find the linq object.</param>
    /// <param name="id">The entity id; a negative value means the object is new.</param>
    /// <returns>
    /// The first or default linq object found by the search.
    /// </returns>
    public static TLinq Find<TLinq>(this DataContext context, Func<TLinq, bool> searchFunction, int id)
        where TLinq : class
    {
        // New objects carry negative ids, so the sign tells us where to look.
        return Find<TLinq>(context, searchFunction, id < 0);
    }

    /// <summary>
    /// Finds a linq object of the specified type by first searching for any objects that have been
    /// added to the data context changeset or looking in the linq table for a single row that
    /// meets the specified search function.
    /// </summary>
    /// <typeparam name="TLinq">The type of linq object to be returned.</typeparam>
    /// <param name="context">The data context in which to search for the entity.</param>
    /// <param name="searchFunction">The search condition to meet in order to find the linq object.</param>
    /// <param name="isNew">if set to <c>true</c> the object to find is new.</param>
    /// <returns>
    /// The first or default linq object found by the search.
    /// </returns>
    public static TLinq Find<TLinq>(this DataContext context, Func<TLinq, bool> searchFunction, bool isNew)
        where TLinq : class
    {
        ChangeSet changeSet = context.GetChangeSet();

        if (isNew)
        {
            return Get<TLinq>(changeSet.Inserts, searchFunction);
        }

        // Existing objects: check pending updates, then pending deletes, then the table.
        // Because searchFunction is a Func rather than an Expression, the table search
        // uses Enumerable.FirstOrDefault and is evaluated in memory, not translated to SQL.
        return Get<TLinq>(changeSet.Updates, searchFunction) ??
               Get<TLinq>(changeSet.Deletes, searchFunction) ??
               context.GetTable<TLinq>().FirstOrDefault<TLinq>(searchFunction);
    }

    /// <summary>
    /// Counts the number of linq objects of the given type within the specified context.
    /// </summary>
    /// <typeparam name="TLinq">The type of linq object to be counted.</typeparam>
    /// <param name="context">The data context in which to count the objects.</param>
    /// <param name="searchFunction">The search function used to match the objects.</param>
    /// <returns>
    /// The number of matching pending inserts plus matching table rows, minus matching pending deletes.
    /// </returns>
    public static int Count<TLinq>(this DataContext context, Func<TLinq, bool> searchFunction)
        where TLinq : class
    {
        int count = 0;

        ChangeSet changeSet = context.GetChangeSet();

        // Pending inserts are not in the table yet, so add them.
        count += changeSet.Inserts.OfType<TLinq>().Count<TLinq>(searchFunction);
        count += context.GetTable<TLinq>().Count<TLinq>(searchFunction);
        // Pending deletes are still in the table, so take them off.
        count -= changeSet.Deletes.OfType<TLinq>().Count<TLinq>(searchFunction);

        return count;
    }

    #region Private methods

    /// <summary>
    /// Gets the first or default linq object from the collection of objects using the specified
    /// search function.
    /// </summary>
    /// <param name="objects">The object collection to search.</param>
    /// <param name="searchFunction">The function to use to search.</param>
    /// <returns>A linq object matching the condition.</returns>
    private static TLinq Get<TLinq>(IEnumerable<object> objects, Func<TLinq, bool> searchFunction)
    {
        return objects.OfType<TLinq>().FirstOrDefault<TLinq>(searchFunction);
    }

    #endregion
}
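If you want to see the filtering and counting logic on its own, the sketch below reproduces it with plain in-memory lists, so it runs without a database. The ChangeSetSketch class, its method names and the Customer and Order types are all hypothetical stand-ins invented for this illustration; the lists only imitate the Inserts and Deletes collections of a real ChangeSet.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical entity types used only for this illustration.
public class Customer { public int Id; }
public class Order { public int Id; }

public static class ChangeSetSketch
{
    // Same idea as the private Get helper: keep only the objects of the
    // requested type, then return the first match (or null).
    public static T FirstOfType<T>(IEnumerable<object> objects, Func<T, bool> match)
        where T : class
    {
        return objects.OfType<T>().FirstOrDefault(match);
    }

    // Same arithmetic as Count: pending inserts + stored rows - pending deletes.
    public static int EffectiveCount<T>(
        IEnumerable<object> inserts, IEnumerable<T> stored,
        IEnumerable<object> deletes, Func<T, bool> match)
    {
        return inserts.OfType<T>().Count(match)
             + stored.Count(match)
             - deletes.OfType<T>().Count(match);
    }
}

public static class Program
{
    public static void Main()
    {
        var saved1 = new Customer { Id = 1 };
        var saved2 = new Customer { Id = 2 };
        var stored = new List<Customer> { saved1, saved2 };

        // Stand-ins for the change set's Inserts and Deletes lists.
        var inserts = new List<object> { new Order { Id = -1 }, new Customer { Id = -1 } };
        var deletes = new List<object> { saved2 };

        Customer pending = ChangeSetSketch.FirstOfType<Customer>(inserts, c => c.Id == -1);
        Console.WriteLine(pending.Id);  // -1 (the Order with the same id is skipped)

        int total = ChangeSetSketch.EffectiveCount<Customer>(inserts, stored, deletes, c => true);
        Console.WriteLine(total);       // 2: one insert + two stored - one delete
    }
}
```

The type filter matters because the change set lists hold objects of every entity type, so an Order and a Customer can share the same temporary id.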


Here is an example of how it can be used:



private MyLinqObject BuildMyLinqObject(int id, MyDataContext context)
{
    MyLinqObject myLinq = context.Find<MyLinqObject>(m => m.Id == id);
    if (myLinq == null)
    {
        myLinq = new MyLinqObject();
    }

    /* populate details as necessary */

    return myLinq;
}


So with a single method call you can find any instance of the type you require in the LINQ data context, whether it has already been saved or is still waiting to be inserted.



Hope this helps



Enjoy!
