如何在执行之前将Entity Framework包装起来拦截LINQ表达式?

我想在执行之前重写LINQ表达式的某些部分。 而且我在将重写器注入正确的位置时遇到了问题(实际上完全没有)。

查看entity framework源(在reflection器中)它最终归结为IQueryProvider.Execute ,它在EF中通过ObjectContext耦合到表达式,提供internal IQueryProvider Provider { get; } 属性。

所以我创建了一个包装类(实现IQueryProvider )来在调用Execute时执行Expression重写,然后将其传递给原始的Provider。

问题是, Provider背后的字段是private ObjectQueryProvider _queryProvider; 。 这个ObjectQueryProvider是一个内部密封类 ,这意味着不可能创建一个提供添加重写的子类。

因此,由于非常紧密耦合的ObjectContext,这种方法让我陷入了死胡同。

如何解决这个问题呢? 我看错了方向吗? 有没有办法让自己注入这个ObjectQueryProvider

更新 :虽然提供的解决方案在您使用存储库模式“包装”ObjectContext时都能正常工作,但是允许从ObjectContext直接使用生成的子类的解决方案将更可取。 因此保持与Dynamic Data脚手架兼容。

根据Arthur的回答,我创建了一个工作包装器。

提供的代码片段提供了使用您自己的QueryProvider和IQueryable根包装每个LINQ查询的方法。 这意味着您必须控制初始查询的开始(因为您大部分时间都使用任何类型的模式)。

这种方法的问题在于它不透明,更理想的情况是在构造函数级别的实体容器中注入一些东西。

我已经创建了一个可编译的实现,让它与entity framework一起工作,并添加了对ObjectQuery.Include方法的支持。 可以从MSDN复制表达式访客类。

 public class QueryTranslator : IOrderedQueryable { private Expression expression = null; private QueryTranslatorProvider provider = null; public QueryTranslator(IQueryable source) { expression = Expression.Constant(this); provider = new QueryTranslatorProvider(source); } public QueryTranslator(IQueryable source, Expression e) { if (e == null) throw new ArgumentNullException("e"); expression = e; provider = new QueryTranslatorProvider(source); } public IEnumerator GetEnumerator() { return ((IEnumerable)provider.ExecuteEnumerable(this.expression)).GetEnumerator(); } System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator() { return provider.ExecuteEnumerable(this.expression).GetEnumerator(); } public QueryTranslator Include(String path) { ObjectQuery possibleObjectQuery = provider.source as ObjectQuery; if (possibleObjectQuery != null) { return new QueryTranslator(possibleObjectQuery.Include(path)); } else { throw new InvalidOperationException("The Include should only happen at the beginning of a LINQ expression"); } } public Type ElementType { get { return typeof(T); } } public Expression Expression { get { return expression; } } public IQueryProvider Provider { get { return provider; } } } public class QueryTranslatorProvider : ExpressionVisitor, IQueryProvider { internal IQueryable source; public QueryTranslatorProvider(IQueryable source) { if (source == null) throw new ArgumentNullException("source"); this.source = source; } public IQueryable CreateQuery(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); return new QueryTranslator(source, expression) as IQueryable; } public IQueryable CreateQuery(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Type elementType = expression.Type.GetGenericArguments().First(); IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType), new object[] { source, expression }); return result; } public TResult Execute(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); object result = (this as IQueryProvider).Execute(expression); return (TResult)result; } public object Execute(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Expression translated = this.Visit(expression); return source.Provider.Execute(translated); } internal IEnumerable ExecuteEnumerable(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Expression translated = this.Visit(expression); return source.Provider.CreateQuery(translated); } #region Visitors protected override Expression VisitConstant(ConstantExpression c) { // fix up the Expression tree to work with EF again if (c.Type == typeof(QueryTranslator)) { return source.Expression; } else { return base.VisitConstant(c); } } #endregion } 

存储库中的示例用法:

 public IQueryable List() { return new QueryTranslator(entities.Users).Include("Department"); } 

我有你需要的源代码 – 但不知道如何附加文件。

这里有一些片段(片段!我必须调整这段代码,所以它可能无法编译):

IQueryable的:

 public class QueryTranslator : IOrderedQueryable { private Expression _expression = null; private QueryTranslatorProvider _provider = null; public QueryTranslator(IQueryable source) { _expression = Expression.Constant(this); _provider = new QueryTranslatorProvider(source); } public QueryTranslator(IQueryable source, Expression e) { if (e == null) throw new ArgumentNullException("e"); _expression = e; _provider = new QueryTranslatorProvider(source); } public IEnumerator GetEnumerator() { return ((IEnumerable)_provider.ExecuteEnumerable(this._expression)).GetEnumerator(); } IEnumerator System.Collections.IEnumerable.GetEnumerator() { return _provider.ExecuteEnumerable(this._expression).GetEnumerator(); } public Type ElementType { get { return typeof(T); } } public Expression Expression { get { return _expression; } } public IQueryProvider Provider { get { return _provider; } } } 

IQueryProvider:

 public class QueryTranslatorProvider : ExpressionTreeTranslator, IQueryProvider { IQueryable _source; public QueryTranslatorProvider(IQueryable source) { if (source == null) throw new ArgumentNullException("source"); _source = source; } #region IQueryProvider Members public IQueryable CreateQuery(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); return new QueryTranslator(_source, expression) as IQueryable; } public IQueryable CreateQuery(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Type elementType = expression.Type.FindElementTypes().First(); IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType), new object[] { _source, expression }); return result; } public TResult Execute(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); object result = (this as IQueryProvider).Execute(expression); return (TResult)result; } public object Execute(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Expression translated = this.Visit(expression); return _source.Provider.Execute(translated); } internal IEnumerable ExecuteEnumerable(Expression expression) { if (expression == null) throw new ArgumentNullException("expression"); Expression translated = this.Visit(expression); return _source.Provider.CreateQuery(translated); } #endregion #region Visits protected override MethodCallExpression VisitMethodCall(MethodCallExpression m) { return m; } protected override Expression VisitUnary(UnaryExpression u) { return Expression.MakeUnary(u.NodeType, base.Visit(u.Operand), u.Type.ToImplementationType(), u.Method); } #endregion } 

用法(警告:改编代码!可能无法编译):

 private Dictionary _table = new Dictionary(); public override IQueryable GetObjectQuery() { if (!_table.ContainsKey(type)) { _table[type] = new QueryTranslator( _ctx.CreateQuery("[" + typeof(T).Name + "]")); } return (IQueryable)_table[type]; } 

表达访客/翻译:

http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx

http://msdn.microsoft.com/en-us/library/bb882521.aspx

编辑:添加了FindElementTypes()。 希望所有方法现在都存在。

  ///  /// Finds all implemented IEnumerables of the given Type ///  public static IQueryable FindIEnumerables(this Type seqType) { if (seqType == null || seqType == typeof(object) || seqType == typeof(string)) return new Type[] { }.AsQueryable(); if (seqType.IsArray || seqType == typeof(IEnumerable)) return new Type[] { typeof(IEnumerable) }.AsQueryable(); if (seqType.IsGenericType && seqType.GetGenericArguments().Length == 1 && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>)) { return new Type[] { seqType, typeof(IEnumerable) }.AsQueryable(); } var result = new List(); foreach (var iface in (seqType.GetInterfaces() ?? new Type[] { })) { result.AddRange(FindIEnumerables(iface)); } return FindIEnumerables(seqType.BaseType).Union(result); } ///  /// Finds all element types provided by a specified sequence type. /// "Element types" are T for IEnumerable<T> and object for IEnumerable. ///  public static IQueryable FindElementTypes(this Type seqType) { return seqType.FindIEnumerables().Select(t => t.IsGenericType ? t.GetGenericArguments().Single() : typeof(object)); } 

只是想加入Arthur的例子。

正如Arthur警告他的GetObjectQuery()方法中存在一个错误。

它使用typeof(T).Name作为EntitySet的名称创建基本查询。

EntitySet名称与类型名称完全不同。

如果您使用EF 4,您应该这样做:

 public override IQueryable GetObjectQuery() { if (!_table.ContainsKey(type)) { _table[type] = new QueryTranslator( _ctx.CreateObjectSet(); } return (IQueryable)_table[type]; } 

只要您没有每种类型的多个实体集( MEST ),这是非常罕见的。

如果您使用的是3.5,则可以使用我的技巧13中的代码获取EntitySet名称并将其输入为:

 public override IQueryable GetObjectQuery() { if (!_table.ContainsKey(type)) { _table[type] = new QueryTranslator( _ctx.CreateQuery("[" + GetEntitySetName() + "]")); } return (IQueryable)_table[type]; } 

希望这可以帮助

亚历克斯

entity framework提示