Skip to content

Commit

Permalink
Merge pull request #797 from varvet/kbs/pundit-context
Browse files Browse the repository at this point in the history
First pass of Pundit::Context
  • Loading branch information
Burgestrand authored May 8, 2024
2 parents aabb344 + 6f04482 commit 176cabb
Show file tree
Hide file tree
Showing 8 changed files with 334 additions and 131 deletions.
108 changes: 21 additions & 87 deletions lib/pundit.rb
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
require "active_support/core_ext/module/introspection"
require "active_support/dependencies/autoload"
require "pundit/authorization"
require "pundit/context"
require "pundit/cache_store/null_store"
require "pundit/cache_store/legacy_store"

# @api private
# To avoid name clashes with common Error naming when mixing in Pundit,
Expand Down Expand Up @@ -64,104 +67,35 @@ def self.included(base)
end

class << self
# Retrieves the policy for the given record, initializing it with the
# record and user and finally throwing an error if the user is not
# authorized to perform the given action.
#
# @param user [Object] the user that initiated the action
# @param possibly_namespaced_record [Object, Array] the object we're checking permissions of
# @param query [Symbol, String] the predicate method to check on the policy (e.g. `:show?`)
# @param policy_class [Class] the policy class we want to force use of
# @param cache [#[], #[]=] a Hash-like object to cache the found policy instance in
# @raise [NotAuthorizedError] if the given query method returned false
# @return [Object] Always returns the passed object record
def authorize(user, possibly_namespaced_record, query, policy_class: nil, cache: {})
record = pundit_model(possibly_namespaced_record)
policy = if policy_class
policy_class.new(user, record)
# @see [Pundit::Context#authorize]
def authorize(user, record, query, policy_class: nil, cache: nil)
context = if cache
Context.new(user: user, policy_cache: cache)
else
cache[possibly_namespaced_record] ||= policy!(user, possibly_namespaced_record)
Context.new(user: user)
end

raise NotAuthorizedError, query: query, record: record, policy: policy unless policy.public_send(query)

record
end

# Retrieves the policy scope for the given record.
#
# @see https://github.com/varvet/pundit#scopes
# @param user [Object] the user that initiated the action
# @param scope [Object] the object we're retrieving the policy scope for
# @raise [InvalidConstructorError] if the policy constructor called incorrectly
# @return [Scope{#resolve}, nil] instance of scope class which can resolve to a scope
def policy_scope(user, scope)
policy_scope_class = PolicyFinder.new(scope).scope
return unless policy_scope_class

begin
policy_scope = policy_scope_class.new(user, pundit_model(scope))
rescue ArgumentError
raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called"
end

policy_scope.resolve
context.authorize(record, query: query, policy_class: policy_class)
end

# Retrieves the policy scope for the given record.
#
# @see https://github.com/varvet/pundit#scopes
# @param user [Object] the user that initiated the action
# @param scope [Object] the object we're retrieving the policy scope for
# @raise [NotDefinedError] if the policy scope cannot be found
# @raise [InvalidConstructorError] if the policy constructor called incorrectly
# @return [Scope{#resolve}] instance of scope class which can resolve to a scope
def policy_scope!(user, scope)
policy_scope_class = PolicyFinder.new(scope).scope!
return unless policy_scope_class

begin
policy_scope = policy_scope_class.new(user, pundit_model(scope))
rescue ArgumentError
raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called"
end

policy_scope.resolve
# @see [Pundit::Context#policy_scope]
def policy_scope(user, *args, **kwargs, &block)
Context.new(user: user).policy_scope(*args, **kwargs, &block)
end

# Retrieves the policy for the given record.
#
# @see https://github.com/varvet/pundit#policies
# @param user [Object] the user that initiated the action
# @param record [Object] the object we're retrieving the policy for
# @raise [InvalidConstructorError] if the policy constructor called incorrectly
# @return [Object, nil] instance of policy class with query methods
def policy(user, record)
policy = PolicyFinder.new(record).policy
policy&.new(user, pundit_model(record))
rescue ArgumentError
raise InvalidConstructorError, "Invalid #<#{policy}> constructor is called"
# @see [Pundit::Context#policy_scope!]
def policy_scope!(user, *args, **kwargs, &block)
Context.new(user: user).policy_scope!(*args, **kwargs, &block)
end

# Retrieves the policy for the given record.
#
# @see https://github.com/varvet/pundit#policies
# @param user [Object] the user that initiated the action
# @param record [Object] the object we're retrieving the policy for
# @raise [NotDefinedError] if the policy cannot be found
# @raise [InvalidConstructorError] if the policy constructor called incorrectly
# @return [Object] instance of policy class with query methods
def policy!(user, record)
policy = PolicyFinder.new(record).policy!
policy.new(user, pundit_model(record))
rescue ArgumentError
raise InvalidConstructorError, "Invalid #<#{policy}> constructor is called"
# @see [Pundit::Context#policy]
def policy(user, *args, **kwargs, &block)
Context.new(user: user).policy(*args, **kwargs, &block)
end

private

def pundit_model(record)
record.is_a?(Array) ? record.last : record
# @see [Pundit::Context#policy!]
def policy!(user, *args, **kwargs, &block)
Context.new(user: user).policy!(*args, **kwargs, &block)
end
end

Expand Down
16 changes: 12 additions & 4 deletions lib/pundit/authorization.rb
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ module Authorization

protected

# @return [Pundit::Context] a new instance of {Pundit::Context} with the current user
def pundit
@pundit ||= Pundit::Context.new(
user: pundit_user,
policy_cache: Pundit::CacheStore::LegacyStore.new(policies)
)
end

# @return [Boolean] whether authorization has been performed, i.e. whether
# one {#authorize} or {#skip_authorization} has been called
def pundit_policy_authorized?
Expand Down Expand Up @@ -64,7 +72,7 @@ def authorize(record, query = nil, policy_class: nil)

@_pundit_policy_authorized = true

Pundit.authorize(pundit_user, record, query, policy_class: policy_class, cache: policies)
pundit.authorize(record, query: query, policy_class: policy_class)
end

# Allow this action not to perform authorization.
Expand Down Expand Up @@ -98,9 +106,9 @@ def policy_scope(scope, policy_scope_class: nil)
#
# @see https://github.com/varvet/pundit#policies
# @param record [Object] the object we're retrieving the policy for
# @return [Object, nil] instance of policy class with query methods
# @return [Object] instance of policy class with query methods
def policy(record)
policies[record] ||= Pundit.policy!(pundit_user, record)
pundit.policy!(record)
end

# Retrieves a set of permitted attributes from the policy by instantiating
Expand Down Expand Up @@ -162,7 +170,7 @@ def pundit_user
private

def pundit_policy_scope(scope)
policy_scopes[scope] ||= Pundit.policy_scope!(pundit_user, scope)
policy_scopes[scope] ||= pundit.policy_scope!(scope)
end
end
end
17 changes: 17 additions & 0 deletions lib/pundit/cache_store/legacy_store.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# frozen_string_literal: true

module Pundit
module CacheStore
# @api private
class LegacyStore
def initialize(hash = {})
@store = hash
end

def fetch(user:, record:)
_ = user
@store[record] ||= yield
end
end
end
end
18 changes: 18 additions & 0 deletions lib/pundit/cache_store/null_store.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# frozen_string_literal: true

module Pundit
module CacheStore
# @api private
class NullStore
@instance = new

class << self
attr_reader :instance
end

def fetch(*, **)
yield
end
end
end
end
127 changes: 127 additions & 0 deletions lib/pundit/context.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# frozen_string_literal: true

module Pundit
class Context
def initialize(user:, policy_cache: CacheStore::NullStore.instance)
@user = user
@policy_cache = policy_cache
end

attr_reader :user

# @api private
attr_reader :policy_cache

# Retrieves the policy for the given record, initializing it with the
# record and user and finally throwing an error if the user is not
# authorized to perform the given action.
#
# @param user [Object] the user that initiated the action
# @param possibly_namespaced_record [Object, Array] the object we're checking permissions of
# @param query [Symbol, String] the predicate method to check on the policy (e.g. `:show?`)
# @param policy_class [Class] the policy class we want to force use of
# @raise [NotAuthorizedError] if the given query method returned false
# @return [Object] Always returns the passed object record
def authorize(possibly_namespaced_record, query:, policy_class:)
record = pundit_model(possibly_namespaced_record)
policy = if policy_class
policy_class.new(user, record)
else
policy!(possibly_namespaced_record)
end

raise NotAuthorizedError, query: query, record: record, policy: policy unless policy.public_send(query)

record
end

# Retrieves the policy scope for the given record.
#
# @see https://github.com/varvet/pundit#scopes
# @param user [Object] the user that initiated the action
# @param scope [Object] the object we're retrieving the policy scope for
# @raise [InvalidConstructorError] if the policy constructor called incorrectly
# @return [Scope{#resolve}, nil] instance of scope class which can resolve to a scope
def policy_scope(scope)
policy_scope_class = policy_finder(scope).scope
return unless policy_scope_class

begin
policy_scope = policy_scope_class.new(user, pundit_model(scope))
rescue ArgumentError
raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called"
end

policy_scope.resolve
end

# Retrieves the policy scope for the given record. Raises if not found.
#
# @see https://github.com/varvet/pundit#scopes
# @param user [Object] the user that initiated the action
# @param scope [Object] the object we're retrieving the policy scope for
# @raise [NotDefinedError] if the policy scope cannot be found
# @raise [InvalidConstructorError] if the policy constructor called incorrectly
# @return [Scope{#resolve}] instance of scope class which can resolve to a scope
def policy_scope!(scope)
policy_scope_class = policy_finder(scope).scope!
return unless policy_scope_class

begin
policy_scope = policy_scope_class.new(user, pundit_model(scope))
rescue ArgumentError
raise InvalidConstructorError, "Invalid #<#{policy_scope_class}> constructor is called"
end

policy_scope.resolve
end

# Retrieves the policy for the given record.
#
# @see https://github.com/varvet/pundit#policies
# @param user [Object] the user that initiated the action
# @param record [Object] the object we're retrieving the policy for
# @raise [InvalidConstructorError] if the policy constructor called incorrectly
# @return [Object, nil] instance of policy class with query methods
def policy(record)
cached_find(record, &:policy)
end

# Retrieves the policy for the given record. Raises if not found.
#
# @see https://github.com/varvet/pundit#policies
# @param user [Object] the user that initiated the action
# @param record [Object] the object we're retrieving the policy for
# @raise [NotDefinedError] if the policy cannot be found
# @raise [InvalidConstructorError] if the policy constructor called incorrectly
# @return [Object] instance of policy class with query methods
def policy!(record)
cached_find(record, &:policy!)
end

private

def cached_find(record)
policy_cache.fetch(user: user, record: record) do
klass = yield policy_finder(record)
next unless klass

model = pundit_model(record)

begin
klass.new(user, model)
rescue ArgumentError
raise InvalidConstructorError, "Invalid #<#{klass}> constructor is called"
end
end
end

def policy_finder(record)
PolicyFinder.new(record)
end

def pundit_model(record)
record.is_a?(Array) ? record.last : record
end
end
end
Loading

0 comments on commit 176cabb

Please sign in to comment.