How to create a custom IDbSet<> that always filters on a value given in the constructor?

    Question

  • I have a multi-tenant database, so most tables have an "AccountNumber" column that identifies who owns the row. Almost every SQL query includes something like WHERE AccountNumber = currentAcctNumber.

    I'm looking at EF v4.1 and Code First, but I can't suss out how to get this embedded in almost every SQL query generated by EF.  I posted on StackOverflow (http://stackoverflow.com/questions/5676280/can-a-dbcontext-enforce-a-filter-policy/5680618#5680618) and it was suggested that I create my own DbSet.  I'm very close to getting this to work...but I need help with one last bit.

    I created a class that wraps DbSet<> and implements GetEnumerator so that it applies the filter, like this:

    class DbSetForAccount<T> : IDbSet<T> where T : class, IDbEntityForAccount
    {
     private readonly AccountNumber mAccountNumber;
     private readonly IDbSet<T> mDbSet;
    
     public DbSetForAccount(AccountNumber accountNumber, DbContext dbContext)
     {
      mAccountNumber = accountNumber;
      mDbSet = dbContext.Set<T>();
     }
    
     public IEnumerator<T> GetEnumerator()
     {
      return (from i in mDbSet 
        where i.AccountNumber.Value == mAccountNumber.Value
        select i).GetEnumerator();
     }
    
     // Remaining IDbSet<T> members (Add, Attach, Find, Expression, Provider, ...) omitted for brevity.
    }
    
     This works for code that enumerates the set:
    using (var context = new MyContext(accountNumber))
    {
     // This code only shows items WHERE AccountNumber = accountNumber
     // Also, the SQL WHERE clause is modified as expected.
     foreach (var item in context.FilteredSet<MyThing>())
      Console.WriteLine(item.Name);
    }
    
    Unfortunately, it fails for code that aggregates the set:
    using (var context = new MyContext(accountNumber))
    {
     // This code returns ALL the records in the database
     // The SQL WHERE clause is NOT modified
     Console.WriteLine(context.FilteredSet<MyThing>().Count());
    }
    
    How can I create an IDbSet that always includes the WHERE AccountNumber == accountNumber logic?
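    Why Count() escapes the filter: Queryable operators such as Count() never call GetEnumerator. They wrap the set's IQueryable.Expression in a new expression and execute it through IQueryable.Provider, so a wrapper whose Expression and Provider still come from the unfiltered DbSet (presumably what the omitted members of DbSetForAccount do) loses the filter. Below is a minimal sketch of that behavior, using LINQ to Objects (AsQueryable) as a stand-in for EF's query provider. EnumeratorOnlyFilter, ExpressionFilter and FilterDemo are hypothetical names, and each int plays the role of a row's AccountNumber.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical wrapper shaped like DbSetForAccount: only GetEnumerator filters.
public class EnumeratorOnlyFilter<T> : IQueryable<T>
{
    private readonly IQueryable<T> source;
    private readonly Func<T, bool> predicate;

    public EnumeratorOnlyFilter(IQueryable<T> source, Func<T, bool> predicate)
    {
        this.source = source;
        this.predicate = predicate;
    }

    // foreach uses this, so enumeration is filtered...
    public IEnumerator<T> GetEnumerator() => source.AsEnumerable().Where(predicate).GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    // ...but Queryable.Count() builds its query from these two members,
    // and they still describe the unfiltered source.
    public Type ElementType => typeof(T);
    public Expression Expression => source.Expression;
    public IQueryProvider Provider => source.Provider;
}

// Hypothetical wrapper that bakes the filter into the expression tree instead.
public class ExpressionFilter<T> : IQueryable<T>
{
    private readonly IQueryable<T> filtered;

    public ExpressionFilter(IQueryable<T> source, Expression<Func<T, bool>> filter)
    {
        filtered = source.Where(filter);
    }

    public IEnumerator<T> GetEnumerator() => filtered.GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
    public Type ElementType => typeof(T);
    public Expression Expression => filtered.Expression;
    public IQueryProvider Provider => filtered.Provider;
}

public static class FilterDemo
{
    // Each int is the AccountNumber of one row; the current account is 1.
    public static (int Enumerated, int NaiveCount, int FilteredCount) Run()
    {
        IQueryable<int> rows = new[] { 1, 1, 2, 1, 2 }.AsQueryable();
        var naive = new EnumeratorOnlyFilter<int>(rows, a => a == 1);
        var filtered = new ExpressionFilter<int>(rows, a => a == 1);

        int enumerated = 0;
        foreach (int row in naive) enumerated++; // goes through GetEnumerator

        // Queryable.Count ignores GetEnumerator and executes naive.Expression.
        return (enumerated, naive.Count(), filtered.Count());
    }
}
```

FilterDemo.Run() returns (3, 5, 3): enumeration is filtered, Count() on the enumerator-only wrapper sees all five rows, and Count() on the wrapper that exposes the filtered expression sees three. The expression-based constructors in the replies apply the same idea to a real DbSet.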
    Wednesday, April 20, 2011 6:05 PM

Answers

  • Miguel,

    Thanks for the code; it put me on the right track. A few things concerned me:

    1. Add/Attach/Remove allow items that do not match the filter
    2. Find can return items that do not match the filter
    3. Why did you seal the class?
    4. Why implement IListSource only to return false?

    I ended up with some code that is very close to yours.  It may just be a style difference, but I addressed the concerns I mentioned above.

    To use the class, you only need to pass in a DbContext, a filter expression, and (optionally) an Action that will be used to initialize added/created entities.

    Again, thanks for pointing me in the right direction.

    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Collections.ObjectModel;
    using System.ComponentModel;
    using System.Data.Entity;
    using System.Linq;
    using System.Linq.Expressions;
    
    
    namespace MakeMyPledge.Data
    {
      class FilteredDbSet<TEntity> : IDbSet<TEntity>, IOrderedQueryable<TEntity>, IOrderedQueryable, IQueryable<TEntity>, IQueryable, IEnumerable<TEntity>, IEnumerable, IListSource
        where TEntity : class
      {
        private readonly DbSet<TEntity> Set;
        private readonly IQueryable<TEntity> FilteredSet;
        private readonly Action<TEntity> InitializeEntity;
    
        public FilteredDbSet(DbContext context)
          : this(context.Set<TEntity>(), i => true, null)
        {
        }
    
        public FilteredDbSet(DbContext context, Expression<Func<TEntity, bool>> filter)
          : this(context.Set<TEntity>(), filter, null)
        {
        }
    
        public FilteredDbSet(DbContext context, Expression<Func<TEntity, bool>> filter, Action<TEntity> initializeEntity)
          : this(context.Set<TEntity>(), filter, initializeEntity)
        {
        }
    
        private FilteredDbSet(DbSet<TEntity> set, Expression<Func<TEntity, bool>> filter, Action<TEntity> initializeEntity)
        {
          Set = set;
          FilteredSet = set.Where(filter);
          MatchesFilter = filter.Compile();
          InitializeEntity = initializeEntity;
        }
    
        public Func<TEntity, bool> MatchesFilter { get; private set; }
    
        public void ThrowIfEntityDoesNotMatchFilter(TEntity entity)
        {
          if (!MatchesFilter(entity))
            throw new ArgumentOutOfRangeException("entity", "The entity does not match the filter.");
        }
    
        public TEntity Add(TEntity entity)
        {
          DoInitializeEntity(entity);
          ThrowIfEntityDoesNotMatchFilter(entity);
          return Set.Add(entity);
        }
    
        public TEntity Attach(TEntity entity)
        {
          ThrowIfEntityDoesNotMatchFilter(entity);
          return Set.Attach(entity);
        }
    
        public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, TEntity
        {
          var entity = Set.Create<TDerivedEntity>();
          DoInitializeEntity(entity);
          return (TDerivedEntity)entity;
        }
    
        public TEntity Create()
        {
          var entity = Set.Create();
          DoInitializeEntity(entity);
          return entity;
        }
    
        public TEntity Find(params object[] keyValues)
        {
          var entity = Set.Find(keyValues);
          if (entity == null)
            return null;
    
          // If the user queried an item outside the filter, then we throw an error.
          // If IDbSet had a Detach method we would use it...sadly, we have to be ok with the item being in the Set.
          ThrowIfEntityDoesNotMatchFilter(entity);
          return entity;
        }
    
        public TEntity Remove(TEntity entity)
        {
          ThrowIfEntityDoesNotMatchFilter(entity);
          return Set.Remove(entity);
        }
    
        /// <summary>
        /// Returns the items in the local cache
        /// </summary>
        /// <remarks>
        /// It is possible to add/remove entities via this property that do NOT match the filter.
        /// Use the <see cref="ThrowIfEntityDoesNotMatchFilter"/> method before adding/removing an item from this collection.
        /// </remarks>
        public ObservableCollection<TEntity> Local { get { return Set.Local; } }
    
        IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator() { return FilteredSet.GetEnumerator(); }
    
        IEnumerator IEnumerable.GetEnumerator() { return FilteredSet.GetEnumerator(); }
    
        Type IQueryable.ElementType { get { return typeof(TEntity); } }
    
        Expression IQueryable.Expression { get { return FilteredSet.Expression; } }
    
        IQueryProvider IQueryable.Provider { get { return FilteredSet.Provider; } }
    
        bool IListSource.ContainsListCollection { get { return false; } }
    
        IList IListSource.GetList() { throw new InvalidOperationException(); }
    
        void DoInitializeEntity(TEntity entity)
        {
          if (InitializeEntity != null)
            InitializeEntity(entity);
        }
      }
    }
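    FilteredDbSet uses the one filter expression in two ways: Set.Where(filter) keeps it in the expression tree so EF can translate it into the SQL WHERE clause, while filter.Compile() turns it into a delegate (MatchesFilter) for checking individual in-memory entities in Add, Attach, Remove and Find. A minimal EF-free sketch of that dual use follows; Thing and GuardDemo are hypothetical names, and LINQ to Objects stands in for EF.

```csharp
using System;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical tenant-owned entity.
public class Thing
{
    public int AccountNumber { get; set; }
}

public static class GuardDemo
{
    public static (int Queried, bool OwnAccepted, bool ForeignRejected) Run()
    {
        Expression<Func<Thing, bool>> filter = t => t.AccountNumber == 42;

        // 1) Expression tree: composed into the query (EF would translate it to SQL).
        IQueryable<Thing> rows = new[]
        {
            new Thing { AccountNumber = 42 },
            new Thing { AccountNumber = 7 },
        }.AsQueryable();
        int queried = rows.Where(filter).Count();

        // 2) Compiled delegate: checks one in-memory entity, like MatchesFilter.
        Func<Thing, bool> matches = filter.Compile();
        bool ownAccepted = matches(new Thing { AccountNumber = 42 });

        bool foreignRejected = false;
        try
        {
            ThrowIfNoMatch(matches, new Thing { AccountNumber = 7 });
        }
        catch (ArgumentOutOfRangeException)
        {
            foreignRejected = true;
        }
        return (queried, ownAccepted, foreignRejected);
    }

    // Mirrors FilteredDbSet.ThrowIfEntityDoesNotMatchFilter.
    private static void ThrowIfNoMatch(Func<Thing, bool> matches, Thing entity)
    {
        if (!matches(entity))
            throw new ArgumentOutOfRangeException(nameof(entity), "Entity does not match the filter.");
    }
}
```

GuardDemo.Run() returns (1, true, true). One caveat: the compiled delegate runs in .NET, so a filter that depends on database semantics (for example string comparison under a case-insensitive collation) can behave differently in memory than in SQL.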
    
    

    • Marked as answer by Doug Clutter Friday, April 22, 2011 1:04 PM
    Thursday, April 21, 2011 2:36 PM

All replies

  • Hello, you need to create a constructor that takes an expression for your filter:

    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Collections.ObjectModel;
    using System.Collections.Specialized;
    using System.ComponentModel;
    using System.Data.Entity;
    using System.Linq;
    using System.Linq.Expressions;
    
    namespace Delfin.Data
    {
      public sealed class CustomDbSet<TEntity> : IDbSet<TEntity>, IOrderedQueryable<TEntity>, IOrderedQueryable, IQueryable<TEntity>, IQueryable, IEnumerable<TEntity>, IEnumerable, IListSource where TEntity : class
      {
        private DbContext Context;
        private DbSet<TEntity> Set;
        private IQueryable<TEntity> Entities;
    
        public CustomDbSet(DbContext context)
          : this(context, context.Set<TEntity>(), context.Set<TEntity>())
        {
        }
    
        public CustomDbSet(DbContext context, Expression<Func<TEntity, bool>> filter)
          : this(context, context.Set<TEntity>(), context.Set<TEntity>().Where(filter))
        {
        }
    
        private CustomDbSet(DbContext context, DbSet<TEntity> set, IQueryable<TEntity> entities)
        {
          this.Context = context;
          this.Set = set;
          this.Entities = entities;
        }
    
        public TEntity Add(TEntity entity)
        {
          return this.Set.Add(entity);
        }
    
        public TEntity Attach(TEntity entity)
        {
          return this.Set.Attach(entity);
        }
    
        public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, TEntity
        {
          return this.Set.Create<TDerivedEntity>();
        }
    
        public TEntity Create()
        {
          return this.Set.Create();
        }
    
        public TEntity Find(params object[] keyValues)
        {
          return this.Set.Find(keyValues);
        }
    
        public ObservableCollection<TEntity> Local
        {
          get
          {
            return this.Set.Local;
          }
        }
    
        public TEntity Remove(TEntity entity)
        {
          return this.Set.Remove(entity);
        }
    
        IEnumerator<TEntity> IEnumerable<TEntity>.GetEnumerator()
        {
          return this.Entities.GetEnumerator();
        }
    
        IEnumerator IEnumerable.GetEnumerator()
        {
          return this.Entities.GetEnumerator();
        }
    
        Type IQueryable.ElementType
        {
          get
          {
            return this.Entities.ElementType;
          }
        }
    
        Expression IQueryable.Expression
        {
          get
          {
            return this.Entities.Expression;
          }
        }
    
        IQueryProvider IQueryable.Provider
        {
          get
          {
            return this.Entities.Provider;
          }
        }
    
        bool IListSource.ContainsListCollection
        {
          get
          {
            return false;
          }
        }
    
        IList IListSource.GetList()
        {
          throw new NotImplementedException();
        }
      }
    }
    
    

    Hope this helps,

    Miguel.

    • Marked as answer by Doug Clutter Friday, April 22, 2011 1:05 PM
    • Unmarked as answer by Doug Clutter Friday, April 22, 2011 2:44 PM
    Thursday, April 21, 2011 4:32 AM
  • Hello again, you are welcome. I spent time with the EF 4.0 CTP5 implementing filters for DbSets. I declared CustomDbSet sealed because many people develop against our library and I don't want them to extend the functionality of the DbSet. We provide a CustomDbSet, a CustomDbContext, and a DbEntity base class for all entities (implementing INotifyPropertyChanged, IDataErrorInfo, ... and custom validation). In our library all the constructors of CustomDbSet are internal, because the sets are created by the CustomDbContext (as you can see, sets are only created for entities derived from DbEntity):

    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.ComponentModel;
    using System.Data;
    using System.Data.Entity;
    using System.Data.Entity.Infrastructure;
    using System.Data.Entity.ModelConfiguration.Conventions;
    using System.Data.Entity.Validation;
    using System.Data.Objects;
    using System.Linq;
    using System.Reflection;
    
    namespace Delfin.Data
    {
     public abstract class CustomDbContext : System.Data.Entity.DbContext
     {
     private ObjectContext ObjectContext;
     private Hashtable DbSets = new Hashtable();
    
     protected CustomDbContext(string nameOrConnectionString)
      : base(nameOrConnectionString)
     {
      try
      {
      this.ObjectContext = (this as IObjectContextAdapter).ObjectContext;
      this.ObjectContext.ObjectStateManager.ObjectStateManagerChanged += (sender, e) => this.OnObjectStateManagerChanged(e.Element as DbEntity, e.Action);
      this.ObjectContext.ObjectMaterialized += (sender, e) => this.OnObjectMaterialized(e.Entity as DbEntity);
      this.ObjectContext.SavingChanges += (sender, e) => this.OnSavingChanges();
    
      this.OnModelCreated();
      }
      catch (Exception ex)
      {
      throw new Exception("Unable to create the context. See inner exception for details.", ex);
      }
     }
    
     private void OnModelCreated()
     {
      var type = this.GetType();
      var method = type.GetMethod("Set", new Type[0]);
      type.GetProperties().Where(property => property.PropertyType.IsGenericType && property.PropertyType.IsInterface).ToList().ForEach(property =>
      {
      var setType = property.PropertyType.GetGenericTypeDefinition();
      var entityType = property.PropertyType.GetGenericArguments().First();
      if (setType == typeof(IDbSet<>) &&
       typeof(DbEntity).IsAssignableFrom(entityType))
      {
       property.SetValue(this, method.MakeGenericMethod(property.PropertyType.GetGenericArguments()).Invoke(this, new object[0]), null);
      }
      });
     }
    
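     private void OnObjectMaterialized(DbEntity entity)
     {
      // entity is null when the materialized object does not derive from DbEntity.
      if (entity != null)
       entity.Context = this;
     }
    
     private void OnObjectStateManagerChanged(DbEntity entity, CollectionChangeAction action)
     {
      if (entity == null)
       return; // not a DbEntity (see the "as DbEntity" casts in the constructor)
      switch (action)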
      {
      case CollectionChangeAction.Add:
       {
       if (!object.ReferenceEquals(this, entity.Context))
       {
        entity.Context = this;
       }
       }
       break;
    
      case CollectionChangeAction.Remove:
       {
       if (object.ReferenceEquals(this, entity.Context) &&
        this.Entry(entity).State == EntityState.Detached)
       {
        entity.Context = null;
       }
       }
       break;
      }
     }
    
     public void Detach(DbEntity entity)
     {
      if (entity == null)
      {
      throw new ArgumentNullException("entity");
      }
    
      this.ObjectContext.Detach(entity);
     }
    
     new public CustomDbSet<TEntity> Set<TEntity>() where TEntity : DbEntity
     {
      return (CustomDbSet<TEntity>)(this.DbSets[typeof(TEntity)] ?? (this.DbSets[typeof(TEntity)] = new CustomDbSet<TEntity>(this)));
     }
    
     protected virtual void OnSavingChanges()
     {
     }
     }
    }
    

    The DbEntity base class is defined (basically) as follows:

     public abstract class DbEntity : INotifyPropertyChanging, INotifyPropertyChanged, IComparable, IComparable<DbEntity>, IEquatable<DbEntity>, IEditableObject, IDataErrorInfo
     {
     protected internal DbContext Context
     {
      get;
      internal set;
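     }

     // Interface members (INotifyPropertyChanged, IDataErrorInfo, ...) omitted here.
     }

    As you can see, this DbEntity together with the CustomDbContext lets the code inside each entity access the DbContext the entity belongs to. To use the code above, you can write the following, and all Db-related objects are wrapped transparently:

    using System;
    using System.ComponentModel.DataAnnotations; // [Key], and [Table] in EF 4.1
    using System.Data.Entity;
    using Delfin.Data;
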
    public class TestDbContext : CustomDbContext
     {
     public TestDbContext()
      : base("YourConnectionString")
     {
     }
    
     public IDbSet<YourEntity> YourEntities
     {
      get;
      private set;
     }
     }
    
     [Table("YourTable")]
     public class YourEntity : DbEntity
     {
     [Key]
     public Guid YourEntityId
     {
      get;
      set;
     }
    
     public string Nombre
     {
      get;
      set;
     }
     }
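    The OnModelCreated step is what fills get/private set properties such as YourEntities: it looks for IDbSet<> properties on the derived context and assigns each one the result of Set<TEntity>(), which caches one set per entity type in a Hashtable. The sketch below reproduces only that reflection and caching pattern without EF; IRepo<T>, Repo<T>, ContextBase and TestContext are hypothetical stand-ins for IDbSet<T>, CustomDbSet<T>, CustomDbContext and TestDbContext.

```csharp
using System;
using System.Collections.Generic;
using System.Reflection;

// Hypothetical stand-ins: IRepo<T> plays IDbSet<T>, Repo<T> plays CustomDbSet<T>.
public interface IRepo<T>
{
    List<T> Items { get; }
}

public class Repo<T> : IRepo<T>
{
    public List<T> Items { get; } = new List<T>();
}

public abstract class ContextBase
{
    // One cached set per entity type, like the Hashtable in CustomDbContext.
    private readonly Dictionary<Type, object> sets = new Dictionary<Type, object>();

    protected ContextBase()
    {
        // Like OnModelCreated: find IRepo<> properties on the derived context
        // and assign each one the result of Set<T>() for its entity type.
        MethodInfo setMethod = typeof(ContextBase).GetMethod("Set");
        foreach (PropertyInfo property in GetType().GetProperties())
        {
            Type propertyType = property.PropertyType;
            if (propertyType.IsGenericType &&
                propertyType.GetGenericTypeDefinition() == typeof(IRepo<>))
            {
                object repo = setMethod
                    .MakeGenericMethod(propertyType.GetGenericArguments())
                    .Invoke(this, null);
                property.SetValue(this, repo); // reflection can use the private setter
            }
        }
    }

    public Repo<T> Set<T>()
    {
        if (!sets.TryGetValue(typeof(T), out object set))
        {
            set = new Repo<T>();
            sets[typeof(T)] = set;
        }
        return (Repo<T>)set;
    }
}

// Hypothetical derived context, shaped like TestDbContext.
public class TestContext : ContextBase
{
    public IRepo<string> Names { get; private set; }
}
```

After new TestContext(), Names is non-null and is the same instance that Set<string>() returns, so code that goes through the property and code that calls Set<T>() share one wrapper.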
    

     

     

    Finally, the behavior of the IListSource implementation could be changed (and I did change it), but EF itself implements IListSource the way I wrote it. Alternatively, you can return the Local collection of the wrapped DbSet once it has been loaded.
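    A minimal sketch of that alternative, returning the already-loaded local collection from IListSource.GetList() instead of throwing. LocalListSource<T> is a hypothetical name, and the ObservableCollection<T> passed in stands for the wrapped DbSet's Local property; nothing here queries a database.

```csharp
using System.Collections;
using System.Collections.ObjectModel;
using System.ComponentModel;

// Hypothetical wrapper; 'local' stands for the wrapped DbSet<T>.Local collection.
public class LocalListSource<T> : IListSource
{
    private readonly ObservableCollection<T> local;

    public LocalListSource(ObservableCollection<T> local)
    {
        this.local = local;
    }

    // The returned list contains entities, not other lists.
    bool IListSource.ContainsListCollection => false;

    // Bind to the already-loaded local cache instead of throwing.
    IList IListSource.GetList() => local;
}
```

Because GetList() returns the collection itself rather than a copy, data-binding code that honors IListSource sees later additions to the local cache.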

    Hope this helps you,

    Miguel.

    Thursday, April 21, 2011 3:49 PM
  • Doug,

    What do you put in your DbContext-derived class? I tried:

        public class MyContext : DbContext
        {
            public DbSet<NormalThing> NormalThings { get; set; }
            public FilteredDbSet<FilteredThing> FilteredThings { get; set; }
        }
    

    But then when I try

        using (MyContext db = new MyContext()) 
        {
            var oneThing = db.FilteredThings.FirstOrDefault();
        }
    

    I get an error because FilteredThings is null. So then I tried:

    public class MyContext : DbContext
    {
         // ...
         public FilteredDbSet<FilteredThing> FilteredThings 
         { 
              get { return new FilteredDbSet<FilteredThing>(this); } 
         }
    }
    
    which seems to be working better, but I am worried about a number of aspects of this approach as well. Would you mind posting what you did? It would sure save me a lot of trial and error.
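    A likely reason FilteredThings is null: as far as I know, DbContext only initializes properties typed DbSet<T> or IDbSet<T>, so a property typed FilteredDbSet<T> is left alone. The getter version avoids the null but builds a new wrapper on every read, and Doug's constructor runs filter.Compile() each time. Below is a hypothetical, EF-free sketch of creating the wrapper once per context and reusing it; Order, FilteredView<T> and TenantContext are illustrative stand-ins, not EF types.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;

// Hypothetical entity standing in for FilteredThing.
public class Order
{
    public int AccountNumber { get; set; }
}

// Hypothetical stand-in for FilteredDbSet<T>: like Doug's class, it compiles
// the filter in its constructor, which is the cost we want to pay only once.
public class FilteredView<T>
{
    public static int Created; // counts constructions, for the demo only

    public FilteredView(IQueryable<T> source, Expression<Func<T, bool>> filter)
    {
        Query = source.Where(filter);
        MatchesFilter = filter.Compile();
        Created++;
    }

    public IQueryable<T> Query { get; }
    public Func<T, bool> MatchesFilter { get; }
}

// Hypothetical context: the filtered set is created on first access and
// reused for the lifetime of the context, instead of on every property read.
public class TenantContext
{
    private readonly IQueryable<Order> allOrders;
    private FilteredView<Order> orders;

    public TenantContext(IEnumerable<Order> allOrders, int accountNumber)
    {
        this.allOrders = allOrders.AsQueryable();
        AccountNumber = accountNumber;
    }

    public int AccountNumber { get; }

    public FilteredView<Order> Orders =>
        orders ??= new FilteredView<Order>(allOrders, o => o.AccountNumber == AccountNumber);
}
```

In a real DbContext the same lazy field would hold a FilteredDbSet<T> built from the context and the tenant filter; whether to expose it as IDbSet<T> or as the concrete type is a design choice this sketch leaves open.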

    Thanks,

    -Robert


    • Edited by Rnickel42 Wednesday, April 3, 2013 11:35 PM
    Wednesday, April 3, 2013 11:28 PM
  • Hi Doug,

    I know it's been a while since you answered this question; I hope you still remember it. I need some further help, if I may...

    I implemented your FilteredDBSet<> class above and it works like a mint. However, I have an additional problem: how do I default the value of a particular field when adding records to a FilteredDBSet? (I'm using the same field in my "default" query.)

    For example, say a few of my tables have an OrganisationId field. When I access the FilteredDBSet, I automatically provide a default expression that filters records on that field.

    Next, I would like to set this OrganisationId field automatically when I insert new records into the FilteredDBSet.

    I am trying to avoid changing the entity classes, since they are generated automatically by the Model First approach and may be regenerated if I update the model.
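    Doug's FilteredDbSet already has a hook for this: the optional initializeEntity action, which Add and Create run (Add then checks the filter). The sketch below mimics only that Add path in memory so it runs without EF; Invoice, InMemoryFilteredSet<T> and OrganisationDemo are hypothetical names, and the entity class itself is left untouched.

```csharp
using System;
using System.Collections.Generic;
using System.Linq.Expressions;

// Hypothetical generated entity; it is not modified.
public class Invoice
{
    public int OrganisationId { get; set; }
    public string Number { get; set; }
}

// Hypothetical in-memory analogue of FilteredDbSet's Add path.
public class InMemoryFilteredSet<T> where T : class
{
    private readonly List<T> store = new List<T>();
    private readonly Func<T, bool> matchesFilter;
    private readonly Action<T> initializeEntity;

    public InMemoryFilteredSet(Expression<Func<T, bool>> filter, Action<T> initializeEntity = null)
    {
        matchesFilter = filter.Compile();
        this.initializeEntity = initializeEntity;
    }

    public int Count => store.Count;

    public T Add(T entity)
    {
        // Same order as FilteredDbSet.Add: initialize first, then validate.
        initializeEntity?.Invoke(entity);
        if (!matchesFilter(entity))
            throw new ArgumentOutOfRangeException(nameof(entity), "Entity does not match the filter.");
        store.Add(entity);
        return entity;
    }
}

public static class OrganisationDemo
{
    public static Invoice AddForOrganisation(int organisationId)
    {
        var invoices = new InMemoryFilteredSet<Invoice>(
            i => i.OrganisationId == organisationId,  // the default query filter
            i => i.OrganisationId = organisationId);  // stamps new rows before the check
        // The caller never sets OrganisationId; the initializer does.
        return invoices.Add(new Invoice { Number = "A-1" });
    }
}
```

Against Doug's class the equivalent call would presumably be new FilteredDbSet<Invoice>(context, i => i.OrganisationId == orgId, i => i.OrganisationId = orgId). If one generic filter has to serve many entity types, one option (an assumption, not something from this thread) is to add a shared interface through a separate partial class file, since the generated entity classes are partial and a separate file survives regeneration.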

    Hope you can help. Thanks

    Friday, January 31, 2014 10:27 PM