如何在执行之前将Entity Framework包装起来拦截LINQ表达式?
我想在执行之前重写LINQ表达式的某些部分。 而且我在将重写器注入正确的位置时遇到了问题(实际上完全没有)。
查看entity framework源(在reflection器中)它最终归结为IQueryProvider.Execute
,它在EF中通过ObjectContext
耦合到表达式,提供internal IQueryProvider Provider { get; }
属性。
所以我创建了一个包装类(实现IQueryProvider
)来在调用Execute时执行Expression重写,然后将其传递给原始的Provider。
问题是, Provider
背后的字段是private ObjectQueryProvider _queryProvider;
。 这个ObjectQueryProvider
是一个内部密封类 ,这意味着不可能创建一个提供添加重写的子类。
因此,由于非常紧密耦合的ObjectContext,这种方法让我陷入了死胡同。
如何解决这个问题呢? 我看错了方向吗? 有没有办法让自己注入这个ObjectQueryProvider
?
更新 :虽然提供的解决方案在您使用存储库模式“包装”ObjectContext时都能正常工作,但是允许从ObjectContext直接使用生成的子类的解决方案将更可取。 因此保持与Dynamic Data脚手架兼容。
根据Arthur的回答,我创建了一个工作包装器。
提供的代码片段提供了使用您自己的QueryProvider和IQueryable根包装每个LINQ查询的方法。 这意味着您必须控制初始查询的开始(因为您大部分时间都使用任何类型的模式)。
这种方法的问题在于它不透明,更理想的情况是在构造函数级别的实体容器中注入一些东西。
我已经创建了一个可编译的实现,让它与entity framework一起工作,并添加了对ObjectQuery.Include方法的支持。 可以从MSDN复制表达式访客类。
public class QueryTranslator : IOrderedQueryable { private Expression expression = null; private QueryTranslatorProvider provider = null; public QueryTranslator(IQueryable source) { expression = Expression.Constant(this); provider = new QueryTranslatorProvider (source); } public QueryTranslator(IQueryable source, Expression e) { if (e == null) throw new ArgumentNullException("e"); expression = e; provider = new QueryTranslatorProvider (source); } public IEnumerator GetEnumerator() { return ((IEnumerable )provider.ExecuteEnumerable(this.expression)).GetEnumerator(); } System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { return provider.ExecuteEnumerable(this.expression).GetEnumerator(); } public QueryTranslator Include(String path) { ObjectQuery possibleObjectQuery = provider.source as ObjectQuery ; if (possibleObjectQuery != null) { return new QueryTranslator (possibleObjectQuery.Include(path)); } else { throw new InvalidOperationException("The Include should only happen at the beginning of a LINQ expression"); } } public Type ElementType { get { return typeof(T); } } public Expression Expression { get { return expression; } } public IQueryProvider Provider { get { return provider; } } } public class QueryTranslatorProvider : ExpressionVisitor, IQueryProvider { internal IQueryable source; public QueryTranslatorProvider(IQueryable source) { if (source == null) throw new ArgumentNullException("source"); this.source = source; } public IQueryable CreateQuery (Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); return new QueryTranslator (source, expression) as IQueryable ; } public IQueryable CreateQuery(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Type elementType = expression.Type.GetGenericArguments().First(); IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType), new object[] { source, expression }); return result; } public TResult Execute(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); object result = (this as IQueryProvider).Execute(expression); return (TResult)result; } public object Execute(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Expression translated = this.Visit(expression); return source.Provider.Execute(translated); } internal IEnumerable ExecuteEnumerable(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Expression translated = this.Visit(expression); return source.Provider.CreateQuery(translated); } #region Visitors protected override Expression VisitConstant(ConstantExpression c) { // fix up the Expression tree to work with EF again if (c.Type == typeof(QueryTranslator)) { return source.Expression; } else { return base.VisitConstant(c); } } #endregion }
存储库中的示例用法:
public IQueryable List() { return new QueryTranslator (entities.Users).Include("Department"); }
我有你需要的源代码 – 但不知道如何附加文件。
这里有一些片段(片段!我必须调整这段代码,所以它可能无法编译):
IQueryable的:
public class QueryTranslator : IOrderedQueryable { private Expression _expression = null; private QueryTranslatorProvider _provider = null; public QueryTranslator(IQueryable source) { _expression = Expression.Constant(this); _provider = new QueryTranslatorProvider (source); } public QueryTranslator(IQueryable source, Expression e) { if (e == null) throw new ArgumentNullException("e"); _expression = e; _provider = new QueryTranslatorProvider (source); } public IEnumerator GetEnumerator() { return ((IEnumerable )_provider.ExecuteEnumerable(this._expression)).GetEnumerator(); } IEnumerator System.Collections.IEnumerable.GetEnumerator() { return _provider.ExecuteEnumerable(this._expression).GetEnumerator(); } public Type ElementType { get { return typeof(T); } } public Expression Expression { get { return _expression; } } public IQueryProvider Provider { get { return _provider; } } }
IQueryProvider:
public class QueryTranslatorProvider : ExpressionTreeTranslator, IQueryProvider { IQueryable _source; public QueryTranslatorProvider(IQueryable source) { if (source == null) throw new ArgumentNullException("source"); _source = source; } #region IQueryProvider Members public IQueryable CreateQuery (Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); return new QueryTranslator (_source, expression) as IQueryable ; } public IQueryable CreateQuery(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Type elementType = expression.Type.FindElementTypes().First(); IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType), new object[] { _source, expression }); return result; } public TResult Execute(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); object result = (this as IQueryProvider).Execute(expression); return (TResult)result; } public object Execute(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Expression translated = this.Visit(expression); return _source.Provider.Execute(translated); } internal IEnumerable ExecuteEnumerable(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Expression translated = this.Visit(expression); return _source.Provider.CreateQuery(translated); } #endregion #region Visits protected override MethodCallExpression VisitMethodCall(MethodCallExpression m) { return m; } protected override Expression VisitUnary(UnaryExpression u) { return Expression.MakeUnary(u.NodeType, base.Visit(u.Operand), u.Type.ToImplementationType(), u.Method); } #endregion }
用法(警告:改编代码!可能无法编译):
private Dictionary _table = new Dictionary(); public override IQueryable GetObjectQuery () { if (!_table.ContainsKey(type)) { _table[type] = new QueryTranslator ( _ctx.CreateQuery ("[" + typeof(T).Name + "]")); } return (IQueryable )_table[type]; }
表达访客/翻译:
http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx
http://msdn.microsoft.com/en-us/library/bb882521.aspx
编辑:添加了FindElementTypes()。 希望所有方法现在都存在。
/// /// Finds all implemented IEnumerables of the given Type /// public static IQueryable FindIEnumerables(this Type seqType) { if (seqType == null || seqType == typeof(object) || seqType == typeof(string)) return new Type[] { }.AsQueryable(); if (seqType.IsArray || seqType == typeof(IEnumerable)) return new Type[] { typeof(IEnumerable) }.AsQueryable(); if (seqType.IsGenericType && seqType.GetGenericArguments().Length == 1 && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>)) { return new Type[] { seqType, typeof(IEnumerable) }.AsQueryable(); } var result = new List (); foreach (var iface in (seqType.GetInterfaces() ?? new Type[] { })) { result.AddRange(FindIEnumerables(iface)); } return FindIEnumerables(seqType.BaseType).Union(result); } /// /// Finds all element types provided by a specified sequence type. /// "Element types" are T for IEnumerable<T> and object for IEnumerable. /// public static IQueryable FindElementTypes(this Type seqType) { return seqType.FindIEnumerables().Select(t => t.IsGenericType ? t.GetGenericArguments().Single() : typeof(object)); }
只是想加入Arthur的例子。
正如Arthur警告他的GetObjectQuery()方法中存在一个错误。
它使用typeof(T).Name作为EntitySet的名称创建基本查询。
EntitySet名称与类型名称完全不同。
如果您使用EF 4,您应该这样做:
public override IQueryable GetObjectQuery () { if (!_table.ContainsKey(type)) { _table[type] = new QueryTranslator ( _ctx.CreateObjectSet (); } return (IQueryable )_table[type]; }
只要您没有每种类型的多个实体集( MEST ),这是非常罕见的。
如果您使用的是3.5,则可以使用我的技巧13中的代码获取EntitySet名称并将其输入为:
public override IQueryable GetObjectQuery () { if (!_table.ContainsKey(type)) { _table[type] = new QueryTranslator ( _ctx.CreateQuery ("[" + GetEntitySetName () + "]")); } return (IQueryable )_table[type]; }
希望这可以帮助
亚历克斯
entity framework提示