Skip to content

Commit

Permalink
Foreign Key support improved
Browse files Browse the repository at this point in the history
  • Loading branch information
ackava committed Nov 3, 2022
1 parent 562b1d8 commit 3f4a699
Show file tree
Hide file tree
Showing 12 changed files with 342 additions and 42 deletions.
2 changes: 0 additions & 2 deletions NeuroSpeech.EntityAccessControl.Tests/BaseTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,9 @@ public void Seed(AppDbContext db)
Banned = true
}, new Account
{
// banned user
AccountID = 4
},new Account
{
// banned user
AccountID = 5
});

Expand Down
57 changes: 57 additions & 0 deletions NeuroSpeech.EntityAccessControl.Tests/Insert/InsertTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,63 @@ public async Task InsertAuthorsAsync()

}

[TestMethod]
public async Task InsertLikeActivityAsync()
{
using var scope = CreateScope();
using var db = scope.GetRequiredService<AppDbContext>();
db.UserID = 2;
var sdb = db;

var post = new Post
{
Name = "a",
Tags = new List<PostTag> {
new PostTag {
Name = "funny"
},
new PostTag
{
Name = "public"
}
},
Contents = new List<PostContent> {
new PostContent {
Name = "b",
Tags = new List<PostContentTag> {
new PostContentTag {
Name = "funny"
},
new PostContentTag
{
Name = "public"
}
}
}
}
};

sdb.Add(post);

await sdb.SaveChangesAsync();

using var scope2 = CreateScope();

var db2 = scope2.GetRequiredService<AppDbContext>();
db2.UserID = 4;

db2.PostActivities.Add(new PostActivity {
PostID = post.PostID,
AccountID = db2.UserID
});

db2.RaiseEvents = true;
db2.EnforceSecurity = true;

await db2.SaveChangesAsync();
}


[TestMethod]
public async Task UnauthorizedInsertAuthorsAsync()
{
Expand Down
28 changes: 26 additions & 2 deletions NeuroSpeech.EntityAccessControl.Tests/Model/AppDbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ public AppDbContext(

public DbSet<Post> Posts { get; set; }

public DbSet<PostActivity> PostActivities { get; set; }

public DbSet<Campaign> Campaigns { get; set; }

public DbSet<Account> Accounts { get; set; }
Expand Down Expand Up @@ -54,8 +56,30 @@ protected override void OnModelCreating(ModelBuilder modelBuilder)
x.Keyword
});
}
}

}

[Table("PostActivities")]
public class PostActivity
{

[Key,DatabaseGenerated(DatabaseGeneratedOption.Identity)]
public long ActivityID { get; set; }

public long AccountID { get; set; }

public long PostID { get; set; }

[MaxLength(50)]
public string ActivityType { get; set; }

[ForeignKey(nameof(PostID))]
public Post Post { get; set; }

[ForeignKey(nameof(AccountID))]
public Account Account { get; set; }

}

[Table("Accounts")]
public class Account
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ private void SetupPostEvents()
{
Register<AccountEvents>();
Register<PostEvents>();
Register<PostActivityEvents>();
Register<PostTagEvents>();
Register<PostContentEvents>();
Register<PostAuthorEvents>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
namespace NeuroSpeech.EntityAccessControl.Tests.Model.Events
{
internal class AppEntityEvents<T>: DbEntityEvents<T>
where T : class
{
protected readonly AppDbContext db;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
using Microsoft.EntityFrameworkCore.ChangeTracking;
using NeuroSpeech.EntityAccessControl.Security;
using System.Reflection;

namespace NeuroSpeech.EntityAccessControl.Tests.Model.Events
{
internal class PostActivityEvents : AppEntityEvents<PostActivity>
{
public PostActivityEvents(AppDbContext db) : base(db)
{
}

public override IQueryContext<PostActivity> Filter(IQueryContext<PostActivity> q)
{
return q.Where(x => x.AccountID == db.UserID);
}

protected override IQueryContext ForeignKeyFilter(
ForeignKeyInfo<PostActivity> fk)
{
if(fk.Is(x => x.PostID))
{
return null;
}
return base.ForeignKeyFilter(fk);
}
}
}
4 changes: 4 additions & 0 deletions NeuroSpeech.EntityAccessControl/BaseDbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ public IQueryable<T> FilteredQuery<T>()
public IEntityEvents? GetEntityEvents(Type type)
{
var eh = events.GetEvents(services, type);
if (eh != null)
{
eh.EnforceSecurity = this.EnforceSecurity;
}
return eh;
}

Expand Down
3 changes: 2 additions & 1 deletion NeuroSpeech.EntityAccessControl/DbContextEvents.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ internal abstract class EntityHandler
}

public void Register<T1, TE>()
where TE: DbEntityEvents<T1>
where T1 : class
where TE : DbEntityEvents<T1>
{
registrations[typeof(T1)] = typeof(TE);
}
Expand Down
35 changes: 26 additions & 9 deletions NeuroSpeech.EntityAccessControl/DbEntityEvents.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using NeuroSpeech.EntityAccessControl.Internal;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using NeuroSpeech.EntityAccessControl.Internal;
using NeuroSpeech.EntityAccessControl.Security;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
Expand All @@ -15,6 +17,7 @@ public delegate void IgnoreDelegate<T>(Expression<Func<T, object>> expression,
JsonIgnoreCondition condition = JsonIgnoreCondition.Always);

public class DbEntityEvents<T> : IEntityEvents
where T: class
{
public bool EnforceSecurity { get; set; }

Expand Down Expand Up @@ -144,15 +147,15 @@ public virtual IQueryContext<T> IncludeFilter(IQueryContext<T> q)
return Filter(q);
}

public virtual IQueryContext<T> ReferenceFilter(IQueryContext<T> q, FilterContext fc)
{
return ModifyFilter(q);
}
//public virtual IQueryContext<T> ReferenceFilter(IQueryContext<T> q, FilterContext fc)
//{
// return ModifyFilter(q);
//}

IQueryContext IEntityEvents.ReferenceFilter(IQueryContext q, FilterContext fc)
{
return ReferenceFilter((IQueryContext<T>)q, fc);
}
//IQueryContext IEntityEvents.ReferenceFilter(IQueryContext q, FilterContext fc)
//{
// return ReferenceFilter((IQueryContext<T>)q, fc);
//}


IQueryContext IEntityEvents.Filter(IQueryContext q)
Expand Down Expand Up @@ -225,6 +228,20 @@ public virtual Task UpdatedAsync(T entity)
return Task.CompletedTask;
}

IQueryContext? IEntityEvents.ForeignKeyFilter(EntityEntry entity, PropertyInfo key, object value, FilterFactory fs)
{
if(!EnforceSecurity)
{
return null;
}
return ForeignKeyFilter(new ForeignKeyInfo<T>((entity as EntityEntry<T>)!, key, value, fs));
}

protected virtual IQueryContext? ForeignKeyFilter(ForeignKeyInfo<T> fk)
{
return fk.Filtered();
}

Task IEntityEvents.UpdatedAsync(object entity)
{
return UpdatedAsync((T)entity);
Expand Down
7 changes: 5 additions & 2 deletions NeuroSpeech.EntityAccessControl/IEntityEvents.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
using System.Collections.Generic;
using Microsoft.EntityFrameworkCore.ChangeTracking;
using NeuroSpeech.EntityAccessControl.Security;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text.Json.Serialization;
Expand Down Expand Up @@ -34,6 +37,6 @@ public interface IEntityEvents

Task DeletedAsync(object entity);

IQueryContext ReferenceFilter(IQueryContext qec, FilterContext fc);
IQueryContext? ForeignKeyFilter(EntityEntry entity, PropertyInfo key, object value, FilterFactory fs);
}
}
124 changes: 124 additions & 0 deletions NeuroSpeech.EntityAccessControl/Security/FilterFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
using Microsoft.EntityFrameworkCore.ChangeTracking;
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Text;

namespace NeuroSpeech.EntityAccessControl.Security
{
public readonly ref struct ForeignKeyInfo<T>
where T : class
{
public readonly EntityEntry<T> Entry;
public readonly PropertyInfo Property;
public readonly object Value;
private readonly FilterFactory factory;
public readonly string Name;

public ForeignKeyInfo(
EntityEntry<T> entry,
PropertyInfo property,
object value,
FilterFactory factory)
{
this.Entry = entry;
this.Property = property;
this.Value = value;
this.factory = factory;
this.Name = property.Name;
}

public bool Is(string name)
{
return this.Name == name;
}

public bool Is<TR>(Expression<Func<T,TR>> exp)
{
if (exp.Body is MemberExpression me)
{
if(me.Member is PropertyInfo mp)
{
if (mp == this.Property)
{
return true;
}
}
}

return false;
}

public IQueryContext Set()
{
return factory.Set();
}

public IQueryContext Filtered()
{
return factory.Filtered();
}


public IQueryContext<TEntity> Set<TEntity>()
{
return factory.Set<TEntity>();
}

public IQueryContext<TEntity> Filtered<TEntity>()
{
return factory.Filtered<TEntity>();
}
}

public readonly struct FilterFactory
{
private readonly Func<IQueryContext> filteredSet;
private readonly Func<IQueryContext> set;

internal static FilterFactory From<T>(ISecureQueryProvider db, IQueryContext<T> feqc)
where T: class
{
var qc = () => feqc ?? new QueryContext<T>(db, db.Set<T>());
var fqc = () =>
{
feqc ??= new QueryContext<T>(db, db.Set<T>());
var eh = db.GetEntityEvents(typeof(T));
if (eh == null)
{
throw new EntityAccessException($"Access to {typeof(T).Name} denied");
}
return eh.ModifyFilter(feqc);
};
return new FilterFactory(fqc, qc);
}

internal FilterFactory(Func<IQueryContext> filteredSet, Func<IQueryContext> set)
{
this.filteredSet = filteredSet;
this.set = set;
}


public IQueryContext Filtered()
{
return this.filteredSet();
}

public IQueryContext Set()
{
return this.set();
}

public IQueryContext<T> Filtered<T>()
{
return (this.filteredSet() as IQueryContext<T>)!;
}

public IQueryContext<T> Set<T>()
{
return (this.set() as IQueryContext<T>)!;
}
}
}
Loading

0 comments on commit 3f4a699

Please sign in to comment.