diff --git a/cedar-policy-core/src/ast/entity.rs b/cedar-policy-core/src/ast/entity.rs index 38a9ee402..0f4d7d88a 100644 --- a/cedar-policy-core/src/ast/entity.rs +++ b/cedar-policy-core/src/ast/entity.rs @@ -569,14 +569,7 @@ impl Entity { Ok(()) } - /// Mark the given `UID` as an ancestor of this `Entity`. - // When fuzzing, `add_ancestor()` is fully `pub`. - #[cfg(not(fuzzing))] - pub(crate) fn add_ancestor(&mut self, uid: EntityUID) { - self.ancestors.insert(uid); - } /// Mark the given `UID` as an ancestor of this `Entity` - #[cfg(fuzzing)] pub fn add_ancestor(&mut self, uid: EntityUID) { self.ancestors.insert(uid); } diff --git a/cedar-policy-validator/src/entity_manifest.rs b/cedar-policy-validator/src/entity_manifest.rs index 78b7fa1cf..213c351f8 100644 --- a/cedar-policy-validator/src/entity_manifest.rs +++ b/cedar-policy-validator/src/entity_manifest.rs @@ -20,20 +20,24 @@ use std::collections::HashMap; use std::fmt::{Display, Formatter}; use cedar_policy_core::ast::{ - BinaryOp, EntityUID, Expr, ExprKind, Literal, PolicyID, PolicySet, RequestType, UnaryOp, Var, + BinaryOp, EntityUID, Expr, ExprKind, Literal, PolicySet, RequestType, UnaryOp, Var, }; use cedar_policy_core::entities::err::EntitiesError; -use cedar_policy_core::impl_diagnostic_from_source_loc_opt_field; -use cedar_policy_core::parser::Loc; use miette::Diagnostic; use serde::{Deserialize, Serialize}; use serde_with::serde_as; use smol_str::SmolStr; use thiserror::Error; +mod analysis; +mod loader; +pub mod slicing; +mod type_annotations; + +use crate::entity_manifest::analysis::{EntityManifestAnalysisResult, WrappedAccessPaths}; use crate::{ typecheck::{PolicyCheck, Typechecker}, - types::{EntityRecordKind, Type}, + types::Type, ValidationMode, ValidatorSchema, }; use crate::{ValidationResult, Validator}; @@ -52,14 +56,10 @@ use crate::{ValidationResult, Validator}; #[serde_as] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct EntityManifest -where - T: Clone, -{ +pub struct EntityManifest { /// A map from request types to [`RootAccessTrie`]s. #[serde_as(as = "Vec<(_, _)>")] - #[serde(bound(deserialize = "T: Default"))] - per_action: HashMap>, + pub(crate) per_action: HashMap, } /// A map of data fields to [`AccessTrie`]s. @@ -68,8 +68,8 @@ where // CAUTION: this type is publicly exported in `cedar-policy`. // Don't make fields `pub`, don't make breaking changes, and use caution // when adding public methods. -#[doc = include_str!("../experimental_warning.md")] -pub type Fields = HashMap>>; +#[doc = include_str!("../../cedar-policy/experimental_warning.md")] +pub type Fields = HashMap>; /// The root of a data path or [`RootAccessTrie`]. // CAUTION: this type is publicly exported in `cedar-policy`. @@ -110,14 +110,10 @@ impl Display for EntityRoot { #[serde_as] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct RootAccessTrie -where - T: Clone, -{ +pub struct RootAccessTrie { /// The data that needs to be loaded, organized by root. #[serde_as(as = "Vec<(_, _)>")] - #[serde(bound(deserialize = "T: Default"))] - trie: HashMap>, + pub(crate) trie: HashMap, } /// A Trie representing a set of data paths to load, @@ -133,74 +129,38 @@ where #[serde_as] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] -pub struct AccessTrie { +pub struct AccessTrie { /// Child data of this entity slice. /// The keys are edges in the trie pointing to sub-trie values. #[serde_as(as = "Vec<(_, _)>")] - children: Fields, - /// For entity types, this boolean may be `true` - /// to signal that all the ancestors in the entity hierarchy - /// are required (transitively). - ancestors_required: bool, - /// Optional data annotation, usually used for type information. - #[serde(skip_serializing, skip_deserializing)] - #[serde(bound(deserialize = "T: Default"))] - data: T, + pub(crate) children: Fields, + /// `ancestors_trie` is another [`RootAccessTrie`] representing + /// all of the ancestors of this entity that are required. + /// The ancestors trie is a subset of the original [`RootAccessTrie`]. + /// See the [`RootAccessTrie::is_ancestor`] annotation. + pub(crate) ancestors_trie: RootAccessTrie, + /// When ancestors are required, each node marked `is_ancestor` + /// represents an ancestor or set of ancestors that are required. + /// An ancestor trie can be thought of as a set of pointers to + /// nodes in the original trie, one `is_ancestor`-marked node per pointer. + pub(crate) is_ancestor: bool, + /// The type of this node in the [`AccessTrie`]. + /// From the public API, this field should always be `Some`. + /// It is `None` after deserialization or after first being constructed, but it is type annotated right away. + #[serde(skip_serializing)] + #[serde(skip_deserializing)] + pub(crate) node_type: Option, } -/// A data path that may end with requesting the parents of -/// an entity. -#[derive(Debug, Clone, PartialEq, Eq)] -struct AccessPath { +/// An access path represents path of fields, starting with an [`EntityRoot`]. +/// Fields may be record fields or entity fields. +/// If an access path ends with an entity type, it may also require the ancestors of the entity. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct AccessPath { /// The root variable that begins the data path pub root: EntityRoot, /// The path of fields of entities or structs pub path: Vec, - /// Request all the parents in the entity hierarchy of this entity. - pub ancestors_required: bool, -} - -/// Entity manifest computation does not handle the full -/// cedar language. In particular, the policies must follow the -/// following grammar: -/// ```text -/// = -/// in -/// + -/// if { } { } -/// ... all other cedar operators not mentioned by datapath-expr - -/// = . -/// has -/// -/// -/// ``` -/// The `get_expr_path` function handles `datapath-expr` expressions. -/// This error message tells the user not to use certain operators -/// before accessing record or entity attributes, breaking this grammar. -// CAUTION: this type is publicly exported in `cedar-policy`. -// Don't make fields `pub`, don't make breaking changes, and use caution -// when adding public methods. -#[derive(Debug, Clone, Error, Hash, Eq, PartialEq)] -#[error("for policy `{policy_id}`, failed to analyze expression while computing entity manifest`")] -pub struct FailedAnalysisError { - /// Source location - source_loc: Option, - /// Policy ID where the error occurred - policy_id: PolicyID, - /// The kind of the expression that was unexpected - expr_kind: ExprKind>, -} - -impl Diagnostic for FailedAnalysisError { - impl_diagnostic_from_source_loc_opt_field!(source_loc); - - fn help<'a>(&'a self) -> Option> { - Some(Box::new(format!( - "failed to compute entity manifest: {} operators are not allowed before accessing record or entity attributes", - self.expr_kind.operator_description() - ))) - } } /// Error when expressions are partial during entity @@ -225,8 +185,6 @@ pub struct PartialRequestError {} impl Diagnostic for PartialRequestError {} /// An error generated by entity slicing. -/// See [`FailedAnalysisError`] for details on the fragment -/// of Cedar handled by entity slicing. #[derive(Debug, Error)] pub enum EntityManifestError { /// A validation error was encountered @@ -243,28 +201,109 @@ pub enum EntityManifestError { /// A policy was partial #[error(transparent)] PartialExpression(#[from] PartialExpressionError), + /// Unsupported feature + #[error(transparent)] + UnsupportedCedarFeature(#[from] UnsupportedCedarFeatureError), +} + +/// Error when entity manifest analysis cannot handle a Cedar feature +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Diagnostic)] +#[error("entity manifest analysis currently doesn't support Cedar feature: {feature}")] +pub struct UnsupportedCedarFeatureError { + pub(crate) feature: SmolStr, +} - /// A policy was not analyzable because it used unsupported operators - /// before a [`ExprKind::GetAttr`] - /// See [`FailedAnalysisError`] for more details. +/// Error when the manifest has an entity the schema lacks. +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Hash, Eq, PartialEq)] +#[error("entity manifest doesn't match schema. Schema is missing entity {entity}. Either you wrote an entity manifest by hand (not recommended) or you are using an out-of-date entity manifest with respect to the schema")] +pub struct MismatchedMissingEntityError { + pub(crate) entity: EntityUID, +} + +/// Error when the schema isn't valid in strict mode. +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Hash, Eq, PartialEq)] +#[error("entity manifests are only compatible with schemas that validate in strict mode. Tried to use an invalid schema with an entity manifest")] +pub struct MismatchedNotStrictSchemaError {} + +/// An error generated by entity manifest parsing. These happen +/// when the entity manifest doesn't conform to the schema. +/// Either the user wrote an entity manifest by hand (not reccomended) +/// or they used an out-of-date entity manifest (after updating the schema). +/// Warning: This error is not guaranteed to happen, even when an entity +/// manifest is out-of-date with respect to a schema! Users must ensure +/// that entity manifests are in-sync with the schema and policies. +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Hash, Eq, PartialEq)] +pub enum MismatchedEntityManifestError { + /// Mismatch between entity in manifest and schema + #[error(transparent)] + MismatchedMissingEntity(#[from] MismatchedMissingEntityError), + /// Found a schema that isn't valid in strict mode + #[error(transparent)] + MismatchedNotStrictSchema(#[from] MismatchedNotStrictSchemaError), +} + +/// An error generated when parsing entity manifests from json +#[derive(Debug, Error)] +pub enum EntityManifestFromJsonError { + /// A Serde error happened + #[error(transparent)] + SerdeJsonParseError(#[from] serde_json::Error), + /// A mismatched entity manifest error #[error(transparent)] - FailedAnalysis(#[from] FailedAnalysisError), + MismatchedEntityManifest(#[from] MismatchedEntityManifestError), } -impl EntityManifest { +impl EntityManifest { /// Get the contents of the entity manifest /// indexed by the type of the request. - pub fn per_action(&self) -> &HashMap> { + pub fn per_action(&self) -> &HashMap { &self.per_action } + + /// Convert a json string to an [`EntityManifest`]. + /// Requires the schema in order to add type annotations. + pub fn from_json_str( + json: &str, + schema: &ValidatorSchema, + ) -> Result { + match serde_json::from_str::(json) { + Ok(manifest) => manifest.to_typed(schema).map_err(|e| e.into()), + Err(e) => Err(e.into()), + } + } + + /// Convert a json value to an [`EntityManifest`]. + /// Requires the schema in order to add type annotations. + #[allow(dead_code)] + pub fn from_json_value( + value: serde_json::Value, + schema: &ValidatorSchema, + ) -> Result { + match serde_json::from_value::(value) { + Ok(manifest) => manifest.to_typed(schema).map_err(|e| e.into()), + Err(e) => Err(e.into()), + } + } } /// Union two tries by combining the fields. -fn union_fields(first: &Fields, second: &Fields) -> Fields { +fn union_fields(first: &Fields, second: &Fields) -> Fields { let mut res = first.clone(); for (key, value) in second { res.entry(key.clone()) - .and_modify(|existing| *existing = Box::new((*existing).union(value))) + .and_modify(|existing| existing.union_mut(value)) .or_insert(value.clone()); } res @@ -272,39 +311,43 @@ fn union_fields(first: &Fields, second: &Fields) -> Fields { impl AccessPath { /// Convert a [`AccessPath`] into corresponding [`RootAccessTrie`]. - fn to_root_access_trie(&self) -> RootAccessTrie { - self.to_root_access_trie_with_leaf(AccessTrie { - ancestors_required: true, - children: Default::default(), - data: (), - }) + pub fn to_root_access_trie(&self) -> RootAccessTrie { + self.to_root_access_trie_with_leaf(AccessTrie::default()) } /// Convert an [`AccessPath`] to a [`RootAccessTrie`], and also /// add a full trie as the leaf at the end. - fn to_root_access_trie_with_leaf(&self, leaf_trie: AccessTrie) -> RootAccessTrie { + pub(crate) fn to_root_access_trie_with_leaf(&self, leaf_trie: AccessTrie) -> RootAccessTrie { let mut current = leaf_trie; + // reverse the path, visiting the last access first for field in self.path.iter().rev() { let mut fields = HashMap::new(); fields.insert(field.clone(), Box::new(current)); + current = AccessTrie { - ancestors_required: false, + ancestors_trie: Default::default(), + is_ancestor: false, children: fields, - data: (), + node_type: None, }; } let mut primary_map = HashMap::new(); - primary_map.insert(self.root.clone(), current); + + // special case: if the path is completely empty, + // no need to insert anything + if current != AccessTrie::new() { + primary_map.insert(self.root.clone(), current); + } RootAccessTrie { trie: primary_map } } } -impl RootAccessTrie { +impl RootAccessTrie { /// Get the trie as a hash map from [`EntityRoot`] /// to sub-[`AccessTrie`]s. - pub fn trie(&self) -> &HashMap> { + pub fn trie(&self) -> &HashMap { &self.trie } } @@ -318,18 +361,22 @@ impl RootAccessTrie { } } -impl RootAccessTrie { +impl RootAccessTrie { /// Union two [`RootAccessTrie`]s together. /// The new trie requests the data from both of the original. - fn union(&self, other: &Self) -> Self { - let mut res = self.clone(); + pub fn union(mut self, other: &Self) -> Self { + self.union_mut(other); + self + } + + /// Like [`RootAccessTrie::union`], but modifies the current trie. + pub fn union_mut(&mut self, other: &Self) { for (key, value) in &other.trie { - res.trie + self.trie .entry(key.clone()) - .and_modify(|existing| *existing = (*existing).union(value)) + .and_modify(|existing| existing.union_mut(value)) .or_insert(value.clone()); } - res } } @@ -339,46 +386,51 @@ impl Default for RootAccessTrie { } } -impl AccessTrie { +impl AccessTrie { /// Union two [`AccessTrie`]s together. /// The new trie requests the data from both of the original. - fn union(&self, other: &Self) -> Self { - Self { - children: union_fields(&self.children, &other.children), - ancestors_required: self.ancestors_required || other.ancestors_required, - data: self.data.clone(), - } + pub fn union(mut self, other: &Self) -> Self { + self.union_mut(other); + self + } + + /// Like [`AccessTrie::union`], but modifies the current trie. + pub fn union_mut(&mut self, other: &Self) { + self.children = union_fields(&self.children, &other.children); + self.ancestors_trie.union_mut(&other.ancestors_trie); + self.is_ancestor = self.is_ancestor || other.is_ancestor; } /// Get the children of this [`AccessTrie`]. - pub fn children(&self) -> &Fields { + pub fn children(&self) -> &Fields { &self.children } /// Get a boolean which is true if this trie /// requires all ancestors of the entity to be loaded. - pub fn ancestors_required(&self) -> bool { - self.ancestors_required - } - - /// Get the data associated with this [`AccessTrie`]. - /// This is usually `()` unless it is annotated by a type. - pub fn data(&self) -> &T { - &self.data + pub fn ancestors_required(&self) -> &RootAccessTrie { + &self.ancestors_trie } } impl AccessTrie { /// A new trie that requests no data. - fn new() -> Self { + pub(crate) fn new() -> Self { Self { children: Default::default(), - ancestors_required: false, - data: (), + ancestors_trie: Default::default(), + is_ancestor: false, + node_type: None, } } } +impl Default for AccessTrie { + fn default() -> Self { + Self::new() + } +} + /// Computes an [`EntityManifest`] from the schema and policies. /// The policies must validate against the schema in strict mode, /// otherwise an error is returned. @@ -405,7 +457,7 @@ pub fn compute_entity_manifest( PolicyCheck::Success(typechecked_expr) => { // compute the trie from the typechecked expr // using static analysis - compute_root_trie(&typechecked_expr, policy.id()) + entity_manifest_from_expr(&typechecked_expr).map(|val| val.global_trie) } PolicyCheck::Irrelevant(_, _) => { // this policy is irrelevant, so we need no data @@ -422,259 +474,188 @@ pub fn compute_entity_manifest( let request_type = request_env .to_request_type() .ok_or(PartialRequestError {})?; - // Add to the manifest based on the request type. manifest .entry(request_type) - .and_modify(|existing| { - *existing = existing.union(&new_primary_slice); - }) + .and_modify(|existing| existing.union_mut(&new_primary_slice)) .or_insert(new_primary_slice); } } + // PANIC SAFETY: entity manifest cannot be out of date, since it was computed from the schema given + #[allow(clippy::unwrap_used)] Ok(EntityManifest { per_action: manifest, - }) + } + .to_typed(schema) + .unwrap()) } /// A static analysis on type-annotated cedar expressions. /// Computes the [`RootAccessTrie`] representing all the data required /// to evaluate the expression. -fn compute_root_trie( - expr: &Expr>, - policy_id: &PolicyID, -) -> Result { - let mut primary_slice = RootAccessTrie::new(); - add_to_root_trie(&mut primary_slice, expr, policy_id, false)?; - Ok(primary_slice) -} - -/// Add the expression's requested data to the [`RootAccessTrie`]. -/// This handles s from the grammar (see [`FailedAnalysisError`]) -/// while [`get_expr_path`] handles the s. -fn add_to_root_trie( - root_trie: &mut RootAccessTrie, +fn entity_manifest_from_expr( expr: &Expr>, - policy_id: &PolicyID, - should_load_all: bool, -) -> Result<(), EntityManifestError> { +) -> Result { match expr.expr_kind() { - // Literals, variables, and unkonwns without any GetAttr operations - // on them are okay, since no fields need to be loaded. - ExprKind::Lit(_) => Ok(()), - ExprKind::Var(_) => Ok(()), - ExprKind::Slot(_) => Ok(()), + ExprKind::Slot(slot_id) => { + if slot_id.is_principal() { + Ok(EntityManifestAnalysisResult::from_root(EntityRoot::Var( + Var::Principal, + ))) + } else { + assert!(slot_id.is_resource()); + Ok(EntityManifestAnalysisResult::from_root(EntityRoot::Var( + Var::Resource, + ))) + } + } + ExprKind::Var(var) => Ok(EntityManifestAnalysisResult::from_root(EntityRoot::Var( + *var, + ))), + ExprKind::Lit(Literal::EntityUID(literal)) => Ok(EntityManifestAnalysisResult::from_root( + EntityRoot::Literal((**literal).clone()), + )), ExprKind::Unknown(_) => Err(PartialExpressionError {})?, + + // Non-entity literals need no fields to be loaded. + ExprKind::Lit(_) => Ok(EntityManifestAnalysisResult::default()), ExprKind::If { test_expr, then_expr, else_expr, - } => { - add_to_root_trie(root_trie, test_expr, policy_id, should_load_all)?; - add_to_root_trie(root_trie, then_expr, policy_id, should_load_all)?; - add_to_root_trie(root_trie, else_expr, policy_id, should_load_all)?; - Ok(()) - } - ExprKind::And { left, right } => { - add_to_root_trie(root_trie, left, policy_id, should_load_all)?; - add_to_root_trie(root_trie, right, policy_id, should_load_all)?; - Ok(()) - } - ExprKind::Or { left, right } => { - add_to_root_trie(root_trie, left, policy_id, should_load_all)?; - add_to_root_trie(root_trie, right, policy_id, should_load_all)?; - Ok(()) - } + } => Ok(entity_manifest_from_expr(test_expr)? + .empty_paths() + .union(&entity_manifest_from_expr(then_expr)?) + .union(&entity_manifest_from_expr(else_expr)?)), + ExprKind::And { left, right } + | ExprKind::Or { left, right } + | ExprKind::BinaryApp { + op: BinaryOp::Less | BinaryOp::LessEq | BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul, + arg1: left, + arg2: right, + } => Ok(entity_manifest_from_expr(left)? + .empty_paths() + .union(&entity_manifest_from_expr(right)?.empty_paths())), ExprKind::UnaryApp { op, arg } => { match op { - UnaryOp::Not => add_to_root_trie(root_trie, arg, policy_id, should_load_all)?, - UnaryOp::Neg => add_to_root_trie(root_trie, arg, policy_id, should_load_all)?, - }; - Ok(()) - } - ExprKind::BinaryApp { op, arg1, arg2 } => match op { - // Special case! Equality between records requires - // that we load all fields. - // This could be made more precise if we check the type. - BinaryOp::Eq => { - add_to_root_trie(root_trie, arg1, policy_id, true)?; - add_to_root_trie(root_trie, arg2, policy_id, true)?; - Ok(()) - } - BinaryOp::In => { - // Recur normally on the rhs - add_to_root_trie(root_trie, arg2, policy_id, should_load_all)?; - - // The lhs should be a datapath expression. - let mut flat_slice = get_expr_path(arg1, policy_id)?; - flat_slice.ancestors_required = true; - *root_trie = root_trie.union(&flat_slice.to_root_access_trie()); - Ok(()) - } - BinaryOp::Contains | BinaryOp::ContainsAll | BinaryOp::ContainsAny => { - // Like equality, another special case for records. - add_to_root_trie(root_trie, arg1, policy_id, true)?; - add_to_root_trie(root_trie, arg2, policy_id, true)?; - Ok(()) - } - BinaryOp::Less | BinaryOp::LessEq | BinaryOp::Add | BinaryOp::Sub | BinaryOp::Mul => { - // These operators work on literals, so no special - // case is needed. - add_to_root_trie(root_trie, arg1, policy_id, should_load_all)?; - add_to_root_trie(root_trie, arg2, policy_id, should_load_all)?; - Ok(()) + // both unary ops are on booleans, so they are simple + UnaryOp::Not | UnaryOp::Neg => Ok(entity_manifest_from_expr(arg)?.empty_paths()), } - BinaryOp::GetTag | BinaryOp::HasTag => { - unimplemented!("interaction between RFCs 74 and 82") + } + ExprKind::BinaryApp { + op: + BinaryOp::Eq + | BinaryOp::In + | BinaryOp::Contains + | BinaryOp::ContainsAll + | BinaryOp::ContainsAny, + arg1, + arg2, + } => { + // TODO Is there more elegant way to bind op using rust pattern matching? + // PANIC SAFETY: Matched a binary app above, so expr must still be a binary app. + #[allow(clippy::panic)] + let ExprKind::BinaryApp { op, .. } = expr.expr_kind() else { + panic!("Matched above"); + }; + + // First, find the data paths for each argument + let mut arg1_res = entity_manifest_from_expr(arg1)?; + let arg2_res = entity_manifest_from_expr(arg2)?; + + // PANIC SAFETY: Typechecking succeeded, so type annotations are present. + #[allow(clippy::expect_used)] + let ty1 = arg1 + .data() + .as_ref() + .expect("Expected annotated types after typechecking"); + // PANIC SAFETY: Typechecking succeeded, so type annotations are present. + #[allow(clippy::expect_used)] + let ty2 = arg2 + .data() + .as_ref() + .expect("Expected annotated types after typechecking"); + + // For the `in` operator, we need the ancestors of entities. + if let BinaryOp::In = op { + arg1_res = arg1_res + .with_ancestors_required(&arg2_res.resulting_paths.to_ancestor_access_trie()); } - }, + + // Load all fields using `full_type_required`, since + // these operations do equality checks. + Ok(arg1_res + .full_type_required(ty1) + .union(&arg2_res.full_type_required(ty2)) + .empty_paths()) + } + ExprKind::BinaryApp { + op: BinaryOp::GetTag | BinaryOp::HasTag, + arg1: _, + arg2: _, + } => Err(UnsupportedCedarFeatureError { + feature: "entity tags".into(), + } + .into()), ExprKind::ExtensionFunctionApp { fn_name: _, args } => { // WARNING: this code assumes that extension functions - // don't take full structs as inputs. - // If they did, we would need to use logic similar to the Eq binary operator. + // all take primitives as inputs and produce + // primitives as outputs. + // If not, we would need to use logic similar to the Eq binary operator. + + let mut res = EntityManifestAnalysisResult::default(); + for arg in args.iter() { - add_to_root_trie(root_trie, arg, policy_id, should_load_all)?; + res = res.union(&entity_manifest_from_expr(arg)?); } - Ok(()) - } - ExprKind::Like { expr, pattern: _ } => { - add_to_root_trie(root_trie, expr, policy_id, should_load_all)?; - Ok(()) + Ok(res) } - ExprKind::Is { + ExprKind::Like { expr, pattern: _ } + | ExprKind::Is { expr, entity_type: _, } => { - add_to_root_trie(root_trie, expr, policy_id, should_load_all)?; - Ok(()) + // drop paths since boolean returned + Ok(entity_manifest_from_expr(expr)?.empty_paths()) } ExprKind::Set(contents) => { + let mut res = EntityManifestAnalysisResult::default(); + + // take union of all of the contents for expr in &**contents { - add_to_root_trie(root_trie, expr, policy_id, should_load_all)?; - } - Ok(()) - } - ExprKind::Record(content) => { - for expr in content.values() { - add_to_root_trie(root_trie, expr, policy_id, should_load_all)?; + let content = entity_manifest_from_expr(expr)?; + + res = res.union(&content); } - Ok(()) - } - ExprKind::HasAttr { expr, attr } => { - let mut flat_slice = get_expr_path(expr, policy_id)?; - flat_slice.path.push(attr.clone()); - *root_trie = root_trie.union(&flat_slice.to_root_access_trie()); - Ok(()) - } - ExprKind::GetAttr { .. } => { - let flat_slice = get_expr_path(expr, policy_id)?; - // PANIC SAFETY: Successfuly typechecked expressions should always have annotated types. - #[allow(clippy::expect_used)] - let leaf_field = if should_load_all { - type_to_access_trie( - expr.data() - .as_ref() - .expect("Typechecked expression missing type"), - ) - } else { - AccessTrie::new() - }; + // now, wrap result in a set + res.resulting_paths = WrappedAccessPaths::SetLiteral(Box::new(res.resulting_paths)); - *root_trie = root_trie.union(&flat_slice.to_root_access_trie_with_leaf(leaf_field)); - Ok(()) + Ok(res) } - } -} + ExprKind::Record(content) => { + let mut record_contents = HashMap::new(); + let mut global_trie = RootAccessTrie::default(); -/// Compute the full [`AccessTrie`] required for the type. -fn type_to_access_trie(ty: &Type) -> AccessTrie { - match ty { - // if it's not an entity or record, slice ends here - Type::ExtensionType { .. } - | Type::Never - | Type::True - | Type::False - | Type::Primitive { .. } - | Type::Set { .. } => AccessTrie::new(), - Type::EntityOrRecord(record_type) => entity_or_record_to_access_trie(record_type), - } -} + for (key, child_expr) in content.iter() { + let res = entity_manifest_from_expr(child_expr)?; + record_contents.insert(key.clone(), Box::new(res.resulting_paths)); -/// Compute the full [`AccessTrie`] for the given entity or record type. -fn entity_or_record_to_access_trie(ty: &EntityRecordKind) -> AccessTrie { - match ty { - EntityRecordKind::ActionEntity { attrs, .. } | EntityRecordKind::Record { attrs, .. } => { - let mut fields = HashMap::new(); - for (attr_name, attr_type) in attrs.iter() { - fields.insert( - attr_name.clone(), - Box::new(type_to_access_trie(&attr_type.attr_type)), - ); - } - AccessTrie { - children: fields, - ancestors_required: false, - data: (), + global_trie = global_trie.union(&res.global_trie); } - } - EntityRecordKind::Entity(_) | EntityRecordKind::AnyEntity => { - // no need to load data for entities, which are compared - // using ids - AccessTrie::new() + Ok(EntityManifestAnalysisResult { + resulting_paths: WrappedAccessPaths::RecordLiteral(record_contents), + global_trie, + }) } - } -} - -/// Given an expression, get the corresponding data path -/// starting with a variable. -/// If the expression is not a ``, return an error. -/// See [`FailedAnalysisError`] for more information. -fn get_expr_path( - expr: &Expr>, - policy_id: &PolicyID, -) -> Result { - Ok(match expr.expr_kind() { - ExprKind::Slot(slot_id) => { - if slot_id.is_principal() { - AccessPath { - root: EntityRoot::Var(Var::Principal), - path: vec![], - ancestors_required: false, - } - } else { - assert!(slot_id.is_resource()); - AccessPath { - root: EntityRoot::Var(Var::Resource), - path: vec![], - ancestors_required: false, - } - } - } - ExprKind::Var(var) => AccessPath { - root: EntityRoot::Var(*var), - path: vec![], - ancestors_required: false, - }, ExprKind::GetAttr { expr, attr } => { - let mut slice = get_expr_path(expr, policy_id)?; - slice.path.push(attr.clone()); - slice + Ok(entity_manifest_from_expr(expr)?.get_or_has_attr(attr)) } - ExprKind::Lit(Literal::EntityUID(literal)) => AccessPath { - root: EntityRoot::Literal((**literal).clone()), - path: vec![], - ancestors_required: false, - }, - ExprKind::Unknown(_) => Err(PartialExpressionError {})?, - // all other variants of expressions result in failure to analyze. - _ => Err(EntityManifestError::FailedAnalysis(FailedAnalysisError { - source_loc: expr.source_loc().cloned(), - policy_id: policy_id.clone(), - expr_kind: expr.expr_kind().clone(), - }))?, - }) + ExprKind::HasAttr { expr, attr } => Ok(entity_manifest_from_expr(expr)? + .get_or_has_attr(attr) + .empty_paths()), + } } #[cfg(test)] @@ -704,15 +685,38 @@ action Read appliesTo { .0 } + fn document_fields_schema() -> ValidatorSchema { + ValidatorSchema::from_cedarschema_str( + " +entity User = { +name: String, +}; + +entity Document = { +owner: User, +viewer: User, +}; + +action Read appliesTo { +principal: [User], +resource: [Document] +}; +", + Extensions::all_available(), + ) + .unwrap() + .0 + } + #[test] fn test_simple_entity_manifest() { let mut pset = PolicySet::new(); let policy = parse_policy( None, - "permit(principal, action, resource) + r#"permit(principal, action, resource) when { - principal.name == \"John\" -};", + principal.name == "John" +};"#, ) .expect("should succeed"); pset.add(policy.into()).expect("should succeed"); @@ -720,6 +724,40 @@ when { let schema = schema(); let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + let expected_rust = EntityManifest { + per_action: [( + RequestType { + principal: "User".parse().unwrap(), + resource: "Document".parse().unwrap(), + action: r#"Action::"Read""#.parse().unwrap(), + }, + RootAccessTrie { + trie: [( + EntityRoot::Var(Var::Principal), + AccessTrie { + children: [( + SmolStr::new("name"), + Box::new(AccessTrie { + children: HashMap::new(), + ancestors_trie: RootAccessTrie::new(), + is_ancestor: false, + node_type: Some(Type::primitive_string()), + }), + )] + .into_iter() + .collect(), + ancestors_trie: RootAccessTrie::new(), + is_ancestor: false, + node_type: Some(Type::named_entity_reference("User".parse().unwrap())), + }, + )] + .into_iter() + .collect(), + }, + )] + .into_iter() + .collect(), + }; let expected = serde_json::json! ({ "perAction": [ [ @@ -743,11 +781,13 @@ when { "name", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ] @@ -755,8 +795,9 @@ when { ] ] }); - let expected_manifest = serde_json::from_value(expected).unwrap(); + let expected_manifest = EntityManifest::from_json_value(expected, &schema).unwrap(); assert_eq!(entity_manifest, expected_manifest); + assert_eq!(entity_manifest, expected_rust); } #[test] @@ -788,7 +829,7 @@ when { ] ] }); - let expected_manifest = serde_json::from_value(expected).unwrap(); + let expected_manifest = EntityManifest::from_json_value(expected, &schema).unwrap(); assert_eq!(entity_manifest, expected_manifest); } @@ -847,11 +888,39 @@ action Read appliesTo { "manager", { "children": [], - "ancestorsRequired": true + "ancestorsTrie": { + "trie": [ + [ + { + "var": "resource", + }, + { + "children": [], + "isAncestor": true, + "ancestorsTrie": { "trie": [] } + } + ] + ] + }, + "isAncestor": false } ] ], - "ancestorsRequired": true + "ancestorsTrie": { + "trie": [ + [ + { + "var": "resource", + }, + { + "children": [], + "isAncestor": true, + "ancestorsTrie": { "trie": [] } + } + ] + ] + }, + "isAncestor": false } ] ] @@ -859,7 +928,7 @@ action Read appliesTo { ] ] }); - let expected_manifest = serde_json::from_value(expected).unwrap(); + let expected_manifest = EntityManifest::from_json_value(expected, &schema).unwrap(); assert_eq!(entity_manifest, expected_manifest); } @@ -868,10 +937,10 @@ action Read appliesTo { let mut pset = PolicySet::new(); let policy = parse_policy( None, - "permit(principal, action, resource) + r#"permit(principal, action, resource) when { - principal.name == \"John\" -};", + principal.name == "John" +};"#, ) .expect("should succeed"); pset.add(policy.into()).expect("should succeed"); @@ -924,11 +993,13 @@ action Read appliesTo { "name", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ] @@ -955,11 +1026,13 @@ action Read appliesTo { "name", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ] @@ -967,7 +1040,7 @@ action Read appliesTo { ] ] }); - let expected_manifest = serde_json::from_value(expected).unwrap(); + let expected_manifest = EntityManifest::from_json_value(expected, &schema).unwrap(); assert_eq!(entity_manifest, expected_manifest); } @@ -1057,30 +1130,34 @@ action Read appliesTo { "owner", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ], [ "readers", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } - ] + ], ] } ] ] }); - let expected_manifest = serde_json::from_value(expected).unwrap(); + let expected_manifest = EntityManifest::from_json_value(expected, &schema).unwrap(); assert_eq!(entity_manifest, expected_manifest); } @@ -1153,22 +1230,26 @@ action BeSad appliesTo { "nickname", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ], [ "friends", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ] @@ -1176,7 +1257,7 @@ action BeSad appliesTo { ] ] }); - let expected_manifest = serde_json::from_value(expected).unwrap(); + let expected_manifest = EntityManifest::from_json_value(expected, &schema).unwrap(); assert_eq!(entity_manifest, expected_manifest); } @@ -1246,22 +1327,26 @@ action Hello appliesTo { "friends", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ], [ "nickname", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ], [ @@ -1278,22 +1363,26 @@ action Hello appliesTo { "nickname", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ], [ "friends", { "children": [], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ], - "ancestorsRequired": false + "ancestorsTrie": { "trie": []}, + "isAncestor": false } ] ] @@ -1301,7 +1390,231 @@ action Hello appliesTo { ] ] }); - let expected_manifest = serde_json::from_value(expected).unwrap(); + let expected_manifest = EntityManifest::from_json_value(expected, &schema).unwrap(); + assert_eq!(entity_manifest, expected_manifest); + } + + #[test] + fn test_entity_manifest_with_if() { + let mut pset = PolicySet::new(); + + let schema = document_fields_schema(); + + let policy = parse_policy( + None, + r#"permit(principal, action, resource) +when { + if principal.name == "John" + then resource.owner.name == User::"oliver".name + else resource.viewer == User::"oliver" +};"#, + ) + .expect("should succeed"); + pset.add(policy.into()).expect("should succeed"); + + let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + let expected = serde_json::json! ( { + "perAction": [ + [ + { + "principal": "User", + "action": { + "ty": "Action", + "eid": "Read" + }, + "resource": "Document" + }, + { + "trie": [ + [ + { + "var": "principal" + }, + { + "children": [ + [ + "name", + { + "children": [], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ] + ], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ], + [ + { + "literal": { + "ty": "User", + "eid": "oliver" + } + }, + { + "children": [ + [ + "name", + { + "children": [], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ] + ], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ], + [ + { + "var": "resource" + }, + { + "children": [ + [ + "viewer", + { + "children": [], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ], + [ + "owner", + { + "children": [ + [ + "name", + { + "children": [], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ] + ], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ] + ], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ] + ] + } + ] + ] + } + ); + let expected_manifest = EntityManifest::from_json_value(expected, &schema).unwrap(); + assert_eq!(entity_manifest, expected_manifest); + } + + #[test] + fn test_entity_manifest_if_literal_record() { + let mut pset = PolicySet::new(); + + let schema = document_fields_schema(); + + let policy = parse_policy( + None, + r#"permit(principal, action, resource) +when { + { + "myfield": + { + "secondfield": + if principal.name == "yihong" + then principal + else resource.owner, + "ignored but still important due to errors": + resource.viewer + } + }["myfield"]["secondfield"].name == "pavel" +};"#, + ) + .expect("should succeed"); + pset.add(policy.into()).expect("should succeed"); + + let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + let expected = serde_json::json! ( { + "perAction": [ + [ + { + "principal": "User", + "action": { + "ty": "Action", + "eid": "Read" + }, + "resource": "Document" + }, + { + "trie": [ + [ + { + "var": "principal" + }, + { + "children": [ + [ + "name", + { + "children": [], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ] + ], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ], + [ + { + "var": "resource" + }, + { + "children": [ + [ + "viewer", + { + "children": [], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ], + [ + "owner", + { + "children": [ + [ + "name", + { + "children": [], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ] + ], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ] + ], + "ancestorsTrie": { "trie": []}, + "isAncestor": false + } + ] + ] + } + ] + ] + } + ); + let expected_manifest = EntityManifest::from_json_value(expected, &schema).unwrap(); assert_eq!(entity_manifest, expected_manifest); } } diff --git a/cedar-policy-validator/src/entity_manifest/analysis.rs b/cedar-policy-validator/src/entity_manifest/analysis.rs new file mode 100644 index 000000000..bbe1ef77e --- /dev/null +++ b/cedar-policy-validator/src/entity_manifest/analysis.rs @@ -0,0 +1,293 @@ +use std::collections::HashMap; + +use smol_str::SmolStr; + +use crate::{ + entity_manifest::{AccessPath, AccessTrie, EntityRoot, RootAccessTrie}, + types::{EntityRecordKind, Type}, +}; + +/// Represents [`AccessPath`]s possibly +/// wrapped in record or set literals. +/// +/// This allows the Entity Manifest to soundly handle +/// data that is wrapped in record or set literals, then used in equality +/// operators or dereferenced. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub(crate) enum WrappedAccessPaths { + /// No access paths are needed. + #[default] + Empty, + /// A single access path, starting with a cedar variable. + AccessPath(AccessPath), + /// The union of two [`WrappedAccessPaths`], denoting that + /// all access paths from both are required. + /// This is useful for join points in the analysis (`if`, set literals, ect) + Union(Box, Box), + /// A record literal, each field having access paths. + RecordLiteral(HashMap>), + /// A set literal containing access paths. + /// Used to note that this type is wrapped in a literal set. + SetLiteral(Box), +} + +/// During Entity Manifest analysis, each sub-expression +/// produces an [`EntityManifestAnalysisResult`]. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub(crate) struct EntityManifestAnalysisResult { + /// INVARIANT: The `global_trie` stores all of the data paths this sub-expression + /// could have accessed, including all those in `resulting_paths`. + pub(crate) global_trie: RootAccessTrie, + /// `resulting_paths` stores a list of `AccessPathRecord`, + /// Each representing a data path + /// (possibly wrapped in a record literal) + /// that could be accessed using the `.` operator. + pub(crate) resulting_paths: WrappedAccessPaths, +} + +impl EntityManifestAnalysisResult { + /// Drop the resulting paths part of the analysis. + /// This is necessary when the expression is a primitive value, so it + /// can't be dereferenced. + pub fn empty_paths(mut self) -> Self { + self.resulting_paths = Default::default(); + self + } + + /// Union two [`EntityManifestAnalysisResult`]s together, + /// keeping the paths from both global tries and concatenating + /// the resulting paths. + pub fn union(mut self, other: &Self) -> Self { + self.global_trie = self.global_trie.union(&other.global_trie); + self.resulting_paths = WrappedAccessPaths::Union( + Box::new(self.resulting_paths), + Box::new(other.resulting_paths.clone()), + ); + self + } + + /// Create an analysis result that starts with a cedar variable + pub fn from_root(root: EntityRoot) -> Self { + let path = AccessPath { root, path: vec![] }; + Self { + global_trie: path.to_root_access_trie(), + resulting_paths: WrappedAccessPaths::AccessPath(path), + } + } + + /// Extend all the access paths with this attr, + /// adding all the new paths to the global trie. + pub fn get_or_has_attr(mut self, attr: &SmolStr) -> Self { + self.resulting_paths = self.resulting_paths.get_or_has_attr(attr); + + self.restore_global_trie_invariant() + } + + /// Restores the `global_trie` invariant by adding all paths + /// in `resulting_paths` to the `global_trie`. + /// This is necessary after modifying the `resulting_paths`. + pub(crate) fn restore_global_trie_invariant(mut self) -> Self { + self.global_trie.add_wrapped_access_paths( + &self.resulting_paths, + false, + &Default::default(), + ); + self + } + + /// Add the ancestors required flag to all of the + /// resulting paths for this analysis result, but only set it + /// for entity types. + /// Add the ancestors required flag to all of the resulting + /// paths for this path record. + pub(crate) fn with_ancestors_required(mut self, ancestors_trie: &RootAccessTrie) -> Self { + self.global_trie + .add_wrapped_access_paths(&self.resulting_paths, false, ancestors_trie); + self + } + + /// For equality or containment checks, all paths in the type + /// are required. + /// This function extends the paths with the fields mentioned + /// by the type, adding these to the global trie. + /// + /// It also drops the resulting paths, since these checks result + /// in booleans. + pub(crate) fn full_type_required(mut self, ty: &Type) -> Self { + let mut paths = Default::default(); + std::mem::swap(&mut self.resulting_paths, &mut paths); + + self.global_trie = self.global_trie.union(&paths.full_type_required(ty)); + + self + } +} + +impl WrappedAccessPaths { + /// Add accessting this attribute to all access paths + fn get_or_has_attr(self, attr: &SmolStr) -> Self { + match self { + WrappedAccessPaths::AccessPath(mut access_path) => { + access_path.path.push(attr.clone()); + WrappedAccessPaths::AccessPath(access_path) + } + WrappedAccessPaths::RecordLiteral(mut record) => { + if let Some(field) = record.remove(attr) { + *field + } else { + // otherwise, this is a `has` expression + // but the record literal didn't have it. + // do nothing in this case + WrappedAccessPaths::RecordLiteral(record) + } + } + // PANIC SAFETY: Type checker should prevent using `.` operator on a set type. + #[allow(clippy::panic)] + WrappedAccessPaths::SetLiteral(_) => { + panic!("Attempted to dereference a set literal.") + } + WrappedAccessPaths::Empty => WrappedAccessPaths::Empty, + WrappedAccessPaths::Union(left, right) => WrappedAccessPaths::Union( + Box::new(left.get_or_has_attr(attr)), + Box::new(right.get_or_has_attr(attr)), + ), + } + } + + fn full_type_required(self, ty: &Type) -> RootAccessTrie { + match self { + WrappedAccessPaths::AccessPath(path) => { + let leaf_trie = type_to_access_trie(ty); + path.to_root_access_trie_with_leaf(leaf_trie.clone()) + } + WrappedAccessPaths::RecordLiteral(mut literal_fields) => match ty { + Type::EntityOrRecord(EntityRecordKind::Record { + attrs: record_attrs, + .. + }) => { + let mut res = RootAccessTrie::new(); + for (attr, attr_ty) in &record_attrs.attrs { + // PANIC SAFETY: Record literals should have attributes that match the type. + #[allow(clippy::panic)] + let field = literal_fields + .remove(attr) + .unwrap_or_else(|| panic!("Missing field {attr} in record literal")); + + res = res.union(&field.full_type_required(&attr_ty.attr_type)); + } + + res + } + // PANIC SAFETY: Typechecking should identify record literals as record types. + #[allow(clippy::panic)] + _ => { + panic!("Found record literal when expected {} type", ty); + } + }, + WrappedAccessPaths::SetLiteral(elements) => match ty { + Type::Set { element_type } => { + // PANIC SAFETY: Typechecking should give concrete types for set elements. + #[allow(clippy::expect_used)] + let ele_type = element_type + .as_ref() + .expect("Expected concrete set type after typechecking"); + elements.full_type_required(ele_type) + } + // PANIC SAFETY: Typechecking should identify set literals as set types. + #[allow(clippy::panic)] + _ => { + panic!("Found set literal when expected {} type", ty); + } + }, + WrappedAccessPaths::Empty => RootAccessTrie::new(), + WrappedAccessPaths::Union(left, right) => left + .full_type_required(ty) + .union(&right.full_type_required(ty)), + } + } + + pub(crate) fn to_ancestor_access_trie(&self) -> RootAccessTrie { + let mut trie = RootAccessTrie::default(); + trie.add_wrapped_access_paths(self, true, &Default::default()); + trie + } +} + +impl RootAccessTrie { + pub(crate) fn add_wrapped_access_paths( + &mut self, + path: &WrappedAccessPaths, + is_ancestor: bool, + ancestors_trie: &RootAccessTrie, + ) { + match path { + WrappedAccessPaths::AccessPath(access_path) => { + let mut leaf = AccessTrie::new(); + leaf.is_ancestor = is_ancestor; + leaf.ancestors_trie = ancestors_trie.clone(); + self.add_access_path(access_path, &leaf); + } + WrappedAccessPaths::RecordLiteral(record) => { + for field in record.values() { + self.add_wrapped_access_paths(field, is_ancestor, ancestors_trie); + } + } + WrappedAccessPaths::SetLiteral(elements) => { + self.add_wrapped_access_paths(elements, is_ancestor, ancestors_trie) + } + WrappedAccessPaths::Empty => (), + WrappedAccessPaths::Union(left, right) => { + self.add_wrapped_access_paths(left, is_ancestor, ancestors_trie); + self.add_wrapped_access_paths(right, is_ancestor, ancestors_trie); + } + } + } + + pub(crate) fn add_access_path(&mut self, access_path: &AccessPath, leaf_trie: &AccessTrie) { + // could be more efficient by mutating self + // instead we use the existing union function. + let other_trie = access_path.to_root_access_trie_with_leaf(leaf_trie.clone()); + self.union_mut(&other_trie) + } +} + +/// Compute the full [`AccessTrie`] required for the type. +fn type_to_access_trie(ty: &Type) -> AccessTrie { + match ty { + // if it's not an entity or record, slice ends here + Type::ExtensionType { .. } + | Type::Never + | Type::True + | Type::False + | Type::Primitive { .. } + | Type::Set { .. } => AccessTrie::new(), + Type::EntityOrRecord(record_type) => entity_or_record_to_access_trie(record_type), + } +} + +/// Compute the full [`AccessTrie`] for the given entity or record type. +fn entity_or_record_to_access_trie(ty: &EntityRecordKind) -> AccessTrie { + match ty { + EntityRecordKind::ActionEntity { attrs, .. } | EntityRecordKind::Record { attrs, .. } => { + let mut fields = HashMap::new(); + for (attr_name, attr_type) in attrs.iter() { + fields.insert( + attr_name.clone(), + Box::new(type_to_access_trie(&attr_type.attr_type)), + ); + } + AccessTrie { + children: fields, + ancestors_trie: Default::default(), + is_ancestor: false, + node_type: None, + } + } + + EntityRecordKind::Entity(_) | EntityRecordKind::AnyEntity => { + // no need to load data for entities, which are compared + // using ids + AccessTrie::new() + } + } +} diff --git a/cedar-policy-validator/src/entity_manifest/loader.rs b/cedar-policy-validator/src/entity_manifest/loader.rs new file mode 100644 index 000000000..902f80c3c --- /dev/null +++ b/cedar-policy-validator/src/entity_manifest/loader.rs @@ -0,0 +1,460 @@ +/* + * Copyright Cedar Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Entity Loader API implementation +//! Loads entities based on the entity manifest. + +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + sync::Arc, +}; + +use cedar_policy_core::{ + ast::{Context, Entity, EntityUID, Literal, PartialValue, Request, Value, ValueKind, Var}, + entities::{Entities, NoEntitiesSchema, TCComputation}, + extensions::Extensions, +}; +use smol_str::SmolStr; + +use crate::entity_manifest::{ + slicing::{ + EntitySliceError, PartialContextError, PartialEntityError, WrongNumberOfEntitiesError, + }, + AccessTrie, EntityManifest, EntityRoot, PartialRequestError, RootAccessTrie, +}; + +/// A request that an entity be loaded. +/// Optionally, instead of loading the full entity the `access_trie` +/// may be used to load only some fields of the entity. +#[derive(Debug)] +pub(crate) struct EntityRequest { + /// The id of the entity requested + pub(crate) entity_id: EntityUID, + /// The fieds of the entity requested + pub(crate) access_trie: AccessTrie, +} + +/// An entity request may be an entity or `None` when +/// the entity is not present. +pub(crate) type EntityAnswer = Option; + +/// The entity request before sub-entitity tries have been +/// pruned using `prune_child_entity_dereferences`. +pub(crate) struct EntityRequestRef<'a> { + entity_id: EntityUID, + access_trie: &'a AccessTrie, +} + +impl<'a> EntityRequestRef<'a> { + fn to_request(&self) -> EntityRequest { + EntityRequest { + entity_id: self.entity_id.clone(), + access_trie: self.access_trie.prune_child_entity_dereferences(), + } + } +} + +/// A request that the ancestors of an entity be loaded. +/// Optionally, the `ancestors` set may be used to just load ancestors in the set. +#[derive(Debug)] +pub(crate) struct AncestorsRequest { + /// The id of the entity whose ancestors are requested + pub(crate) entity_id: EntityUID, + /// The ancestors that are requested, if present + pub(crate) ancestors: HashSet, +} + +/// Implement [`EntityLoader`] to easily load entities using their ids +/// into a Cedar [`Entities`] store. +/// The most basic implementation loads full entities (including all ancestors) in the `load_entities` method and loads the context in the `load_context` method. +/// More advanced implementations make use of the [`AccessTrie`]s provided to load partial entities and context, as well as the `load_ancestors` method to load particular ancestors. +/// +/// Warning: `load_entities` is called multiple times. If database +/// consistency is required, this API should not be used. Instead, use the entity manifest directly. +pub(crate) trait EntityLoader { + /// `load_entities` is called multiple times to load entities based on their ids. + /// For each entity request in the `to_load` vector, expects one loaded entity in the resulting vector. + /// Each [`EntityRequest`] comes with an [`AccessTrie`], which can optionally be used. + /// Only fields mentioned in the entity's [`AccessTrie`] are needed, but it is sound to provide other fields as well. + /// Note that the same entity may be requested multiple times, with different [`AccessTrie`]s. + /// + /// Either `load_entities` must load all the ancestors of each entity, unless `load_ancestors` is implemented. + fn load_entities( + &mut self, + to_load: &[EntityRequest], + ) -> Result, EntitySliceError>; + + /// Optionally, `load_entities` can forgo loading ancestors in the entity hierarchy. + /// Instead, `load_ancestors` implements loading them. + /// For each entity, `load_ancestors` produces a set of ancestors entities in the resulting vector. + /// + /// Each [`AncestorsRequest`] should result in one set of ancestors in the resulting vector. + /// Only ancestors in the request are required, but it is sound to provide other ancestors as well. + fn load_ancestors( + &mut self, + entities: &[AncestorsRequest], + ) -> Result>, EntitySliceError>; +} + +fn initial_entities_to_load<'a>( + root_access_trie: &'a RootAccessTrie, + context: &Context, + request: &Request, + required_ancestors: &mut HashSet, +) -> Result>, EntitySliceError> { + let Context::Value(context_value) = &context else { + return Err(PartialContextError {}.into()); + }; + + let mut to_load = match root_access_trie.trie.get(&EntityRoot::Var(Var::Context)) { + Some(access_trie) => { + find_remaining_entities_context(context_value, access_trie, required_ancestors)? + } + _ => vec![], + }; + + for (key, access_trie) in &root_access_trie.trie { + to_load.push(EntityRequestRef { + entity_id: match key { + EntityRoot::Var(Var::Principal) => request + .principal() + .uid() + .ok_or(PartialRequestError {})? + .clone(), + EntityRoot::Var(Var::Action) => request + .action() + .uid() + .ok_or(PartialRequestError {})? + .clone(), + EntityRoot::Var(Var::Resource) => request + .resource() + .uid() + .ok_or(PartialRequestError {})? + .clone(), + EntityRoot::Literal(lit) => lit.clone(), + EntityRoot::Var(Var::Context) => continue, + }, + access_trie, + }); + } + + Ok(to_load) +} + +impl AccessTrie { + /// Removes any entity dereferences in the children of this trie, + /// recursively. + /// These can be included in [`EntityRequest`]s, which don't include + /// referenced entities. + pub(crate) fn prune_child_entity_dereferences(&self) -> AccessTrie { + let children = self + .children + .iter() + .map(|(k, v)| (k.clone(), Box::new(v.prune_entity_dereferences()))) + .collect(); + + AccessTrie { + children, + ancestors_trie: self.ancestors_trie.clone(), + is_ancestor: self.is_ancestor, + node_type: self.node_type.clone(), + } + } + + pub(crate) fn prune_entity_dereferences(&self) -> AccessTrie { + // PANIC SAFETY: Node types should always be present on entity manifests after creation. + #[allow(clippy::unwrap_used)] + let children = if self.node_type.as_ref().unwrap().is_entity_type() { + HashMap::new() + } else { + self.children + .iter() + .map(|(k, v)| (k.clone(), Box::new(v.prune_entity_dereferences()))) + .collect() + }; + + AccessTrie { + children, + ancestors_trie: self.ancestors_trie.clone(), + is_ancestor: self.is_ancestor, + node_type: self.node_type.clone(), + } + } +} + +/// Loads entities based on the entity manifest, request, and +/// the implemented [`EntityLoader`]. +pub(crate) fn load_entities( + manifest: &EntityManifest, + request: &Request, + loader: &mut dyn EntityLoader, +) -> Result { + let Some(root_access_trie) = manifest + .per_action + .get(&request.to_request_type().ok_or(PartialRequestError {})?) + else { + match Entities::from_entities( + vec![], + None::<&NoEntitiesSchema>, + TCComputation::AssumeAlreadyComputed, + Extensions::all_available(), + ) { + Ok(entities) => return Ok(entities), + Err(err) => return Err(err.into()), + }; + }; + + let context = request.context().ok_or(PartialRequestError {})?; + + let mut entities: HashMap = Default::default(); + // entity requests in progress + let mut to_load: Vec> = + initial_entities_to_load(root_access_trie, context, request, &mut Default::default())?; + // later, find the ancestors of these entities using their ancestor tries + let mut to_find_ancestors = vec![]; + + // Main loop of loading entities, one batch at a time + while !to_load.is_empty() { + // first, record the entities in `to_find_ancestors` + for entity_request in &to_load { + to_find_ancestors.push(( + entity_request.entity_id.clone(), + &entity_request.access_trie.ancestors_trie, + )); + } + + let new_entities = loader.load_entities( + &to_load + .iter() + .map(|entity_ref| entity_ref.to_request()) + .collect::>(), + )?; + if new_entities.len() != to_load.len() { + return Err(WrongNumberOfEntitiesError { + expected: to_load.len(), + got: new_entities.len(), + } + .into()); + } + + let mut next_to_load = vec![]; + for (entity_request, loaded_maybe) in to_load.drain(..).zip(new_entities) { + if let Some(loaded) = loaded_maybe { + next_to_load.extend(find_remaining_entities( + &loaded, + entity_request.access_trie, + &mut Default::default(), + )?); + entities.insert(entity_request.entity_id, loaded); + } + } + + to_load = next_to_load; + } + + // now that all the entities are loaded + // we need to load their ancestors + let mut ancestors_requests = vec![]; + for (entity_id, ancestors_trie) in to_find_ancestors { + ancestors_requests.push(compute_ancestors_request( + entity_id, + ancestors_trie, + &entities, + context, + request, + )?); + } + + let loaded_ancestors = loader.load_ancestors(&ancestors_requests)?; + for (request, ancestors) in ancestors_requests.into_iter().zip(loaded_ancestors) { + if let Some(entity) = entities.get_mut(&request.entity_id) { + ancestors + .into_iter() + .for_each(|ancestor| entity.add_ancestor(ancestor)); + } + } + + // finally, convert the loaded entities into a Cedar Entities store + match Entities::from_entities( + entities.values().cloned(), + None::<&NoEntitiesSchema>, + TCComputation::AssumeAlreadyComputed, + Extensions::all_available(), + ) { + Ok(entities) => Ok(entities), + Err(e) => Err(e.into()), + } +} + +/// Given a context value and an access trie, find all of the remaining +/// entities in the context. +/// Also keep track of required ancestors when encountering the `is_ancestor` flag. +fn find_remaining_entities_context<'a>( + context_value: &Arc>, + fields: &'a AccessTrie, + required_ancestors: &mut HashSet, +) -> Result>, EntitySliceError> { + let mut remaining = vec![]; + for (field, slice) in &fields.children { + if let Some(value) = context_value.get(field) { + find_remaining_entities_value(&mut remaining, value, slice, required_ancestors)?; + } + // the attribute may not be present, since the schema can define + // attributes that are optional + } + Ok(remaining) +} + +/// This helper function finds all entity references that need to be +/// loaded given an already-loaded [`Entity`] and corresponding [`Fields`]. +/// Returns pairs of entity and slices that need to be loaded. +/// Also, any sets marked `is_ancestor` are added to the `required_ancestors` set. +fn find_remaining_entities<'a>( + entity: &Entity, + fields: &'a AccessTrie, + required_ancestors: &mut HashSet, +) -> Result>, EntitySliceError> { + let mut remaining = vec![]; + for (field, slice) in &fields.children { + if let Some(pvalue) = entity.get(field) { + let PartialValue::Value(value) = pvalue else { + return Err(PartialEntityError {}.into()); + }; + find_remaining_entities_value(&mut remaining, value, slice, required_ancestors)?; + } + // the attribute may not be present, since the schema can define + // attributes that are optional + } + + Ok(remaining) +} + +/// Like `find_remaining_entities`, but for values. +/// Any sets that are marked `is_ancestor` are added to the `required_ancestors` set. +fn find_remaining_entities_value<'a>( + remaining: &mut Vec>, + value: &Value, + trie: &'a AccessTrie, + required_ancestors: &mut HashSet, +) -> Result<(), EntitySliceError> { + // unless this is an entity id, ancestors should not be required + assert!( + trie.ancestors_trie == Default::default() + || matches!(value.value_kind(), ValueKind::Lit(Literal::EntityUID(_))) + ); + + // unless this is an entity id or set, it should not be an + // ancestor + assert!( + !trie.is_ancestor + || matches!( + value.value_kind(), + ValueKind::Lit(Literal::EntityUID(_)) | ValueKind::Set(_) + ) + ); + + match value.value_kind() { + ValueKind::Lit(literal) => { + if let Literal::EntityUID(entity_id) = literal { + // no need to add to ancestors set here because + // we are creating an entity request. + + remaining.push(EntityRequestRef { + entity_id: (**entity_id).clone(), + access_trie: trie, + }); + } + } + ValueKind::Set(set) => { + // when this is an ancestor, request all of the entities + // in this set + if trie.is_ancestor { + for val in set.iter() { + match val.value_kind() { + ValueKind::Lit(Literal::EntityUID(id)) => { + required_ancestors.insert((**id).clone()); + } + // PANIC SAFETY: see assert above- ancestor annotation is only valid on sets of entities or entities + #[allow(clippy::panic)] + _ => { + panic!( + "Found is_ancestor on set of non-entity-type {}", + val.value_kind() + ); + } + } + } + } + } + ValueKind::ExtensionValue(_) => (), + ValueKind::Record(record) => { + for (field, child_slice) in &trie.children { + // only need to slice if field is present + if let Some(value) = record.get(field) { + find_remaining_entities_value( + remaining, + value, + child_slice, + required_ancestors, + )?; + } + } + } + }; + Ok(()) +} + +/// Traverse the already-loaded entities using the ancestors trie +/// to find the entity ids that are required. +fn compute_ancestors_request( + entity_id: EntityUID, + ancestors_trie: &RootAccessTrie, + entities: &HashMap, + context: &Context, + request: &Request, +) -> Result { + // similar to load_entities, we traverse the access trie + // this time using the already-loaded entities and looking for + // is_ancestor tags. + let mut ancestors = HashSet::new(); + + let mut to_visit = initial_entities_to_load(ancestors_trie, context, request, &mut ancestors)?; + + while !to_visit.is_empty() { + let mut next_to_visit = vec![]; + for entity_request in to_visit.drain(..) { + // check the is_ancestor flag for entities + // the is_ancestor flag on sets of entities is handled by find_remaining_entities + if entity_request.access_trie.is_ancestor { + ancestors.insert(entity_request.entity_id.clone()); + } + + if let Some(entity) = entities.get(&entity_request.entity_id) { + next_to_visit.extend(find_remaining_entities( + entity, + entity_request.access_trie, + &mut ancestors, + )?); + } + } + to_visit = next_to_visit; + } + + Ok(AncestorsRequest { + ancestors, + entity_id, + }) +} diff --git a/cedar-policy-validator/src/entity_manifest/slicing.rs b/cedar-policy-validator/src/entity_manifest/slicing.rs new file mode 100644 index 000000000..999f756ee --- /dev/null +++ b/cedar-policy-validator/src/entity_manifest/slicing.rs @@ -0,0 +1,927 @@ +//! Entity Slicing + +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::fmt::Display; + +use cedar_policy_core::entities::err::EntitiesError; +use cedar_policy_core::entities::Dereference; +use cedar_policy_core::{ + ast::{Entity, EntityUID, Literal, PartialValue, Request, Value, ValueKind}, + entities::Entities, +}; +use miette::Diagnostic; +use smol_str::SmolStr; +use thiserror::Error; + +use crate::entity_manifest::loader::{ + load_entities, AncestorsRequest, EntityAnswer, EntityLoader, EntityRequest, +}; +use crate::entity_manifest::{AccessTrie, EntityManifest, PartialRequestError}; + +/// Error when expressions are partial during entity +/// slicing. +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Eq, PartialEq)] +#[error("entity slicing requires fully concrete policies. Got a policy with an unknown expression")] +pub struct PartialExpressionError {} + +impl Diagnostic for PartialExpressionError {} + +/// Error when expressions are partial during entity +/// slicing. +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Eq, PartialEq)] +#[error("entity slicing requires fully concrete policies. Got a policy with an unknown expression")] +pub struct IncompatibleEntityManifestError { + non_record_entity_value: Value, +} + +impl Diagnostic for IncompatibleEntityManifestError { + fn help<'a>(&'a self) -> Option> { + Some(Box::new(format!( + "expected entity or record during entity loading. Got value: {}", + self.non_record_entity_value + ))) + } +} + +/// Error when entities are partial during entity manifest computation. +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Eq, PartialEq)] +#[error("entity slicing requires fully concrete entities. Got a partial entity")] +pub struct PartialEntityError {} + +impl Diagnostic for PartialEntityError {} + +/// Error when an entity loader returns the wrong number of entities. +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Eq, PartialEq)] +#[error("entity loader returned the wrong number of entities. Expected {expected} but got {got} entities")] +pub struct WrongNumberOfEntitiesError { + pub(crate) expected: usize, + pub(crate) got: usize, +} + +/// Error when an entity loader returns a value missing an attribute. +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Eq, PartialEq)] +#[error("entity loader produced entity with value {value}. Expected value to be a record with attribute {attribute}")] +pub struct NonRecordValueError { + pub(crate) value: Value, + pub(crate) attribute: SmolStr, +} + +/// Context was partial during entity loading +// CAUTION: this type is publicly exported in `cedar-policy`. +// Don't make fields `pub`, don't make breaking changes, and use caution +// when adding public methods. +#[derive(Debug, Clone, Error, Eq, PartialEq)] +#[error("entity loader produced a partial context. Expected a concrete value")] +pub struct PartialContextError {} + +/// An error generated by entity slicing. +/// TODO make public API wrapper +#[derive(Debug, Error, Diagnostic)] +pub enum EntitySliceError { + /// An entities error was encountered + #[error(transparent)] + #[diagnostic(transparent)] + Entities(#[from] EntitiesError), + + /// The request was partial + #[error(transparent)] + PartialRequest(#[from] PartialRequestError), + /// A policy was partial + #[error(transparent)] + PartialExpression(#[from] PartialExpressionError), + + /// During entity loading, attempted to load from + /// a type without fields. + #[error(transparent)] + IncompatibleEntityManifest(#[from] IncompatibleEntityManifestError), + + /// Found a partial entity during entity loading. + #[error(transparent)] + PartialEntity(#[from] PartialEntityError), + + /// The entity loader returned a partial context. + #[error(transparent)] + PartialContext(#[from] PartialContextError), + + /// The entity loader produced the wrong number of entities. + #[error(transparent)] + WrongNumberOfEntities(#[from] WrongNumberOfEntitiesError), +} + +impl EntityManifest { + /// Use this entity manifest to + /// find an entity slice using an existing [`Entities`] store. + pub fn slice_entities( + &self, + entities: &Entities, + request: &Request, + ) -> Result { + let mut slicer = EntitySlicer { entities }; + load_entities(self, request, &mut slicer) + } +} + +struct EntitySlicer<'a> { + entities: &'a Entities, +} + +impl<'a> EntityLoader for EntitySlicer<'a> { + fn load_entities( + &mut self, + to_load: &[EntityRequest], + ) -> Result, EntitySliceError> { + let mut res = vec![]; + for request in to_load { + if let Dereference::Data(entity) = self.entities.entity(&request.entity_id) { + // filter down the entity fields to those requested + res.push(Some(request.access_trie.slice_entity(entity)?)); + } else { + res.push(None); + } + } + + Ok(res) + } + + fn load_ancestors( + &mut self, + entities: &[AncestorsRequest], + ) -> Result>, EntitySliceError> { + let mut res = vec![]; + + for request in entities { + if let Dereference::Data(entity) = self.entities.entity(&request.entity_id) { + let mut ancestors = HashSet::new(); + + for required_ancestor in &request.ancestors { + if entity.is_descendant_of(required_ancestor) { + ancestors.insert(required_ancestor.clone()); + } + } + + res.push(ancestors); + } else { + // if the entity isn't there, we don't need any ancestors + res.push(HashSet::new()); + } + } + + Ok(res) + } +} + +impl AccessTrie { + /// Given an entities store, an entity id, and a resulting store + /// Slice the entities and put them in the resulting store. + fn slice_entity(&self, entity: &Entity) -> Result { + let mut new_entity = HashMap::::new(); + for (field, slice) in &self.children { + // only slice when field is available + if let Some(pval) = entity.get(field).cloned() { + let PartialValue::Value(val) = pval else { + return Err(PartialEntityError {}.into()); + }; + let sliced = slice.slice_val(&val)?; + + new_entity.insert(field.clone(), PartialValue::Value(sliced)); + } + } + + Ok(Entity::new_with_attr_partial_value( + entity.uid().clone(), + new_entity, + Default::default(), + )) + } + + fn slice_val(&self, val: &Value) -> Result { + Ok(match val.value_kind() { + ValueKind::Lit(Literal::EntityUID(_)) => { + // entities shouldn't need to be dereferenced + assert!(self.children.is_empty()); + val.clone() + } + ValueKind::Set(_) | ValueKind::ExtensionValue(_) | ValueKind::Lit(_) => { + if !self.children.is_empty() { + return Err(IncompatibleEntityManifestError { + non_record_entity_value: val.clone(), + } + .into()); + } + + val.clone() + } + ValueKind::Record(record) => { + let mut new_map = BTreeMap::::new(); + for (field, slice) in &self.children { + // only slice when field is available + if let Some(v) = record.get(field) { + new_map.insert(field.clone(), slice.slice_val(v)?); + } + } + + Value::new(ValueKind::record(new_map), None) + } + }) + } +} + +#[cfg(test)] +mod entity_slice_tests { + use std::collections::BTreeSet; + + use cedar_policy_core::{ + ast::{Context, PolicyID, PolicySet}, + entities::{EntityJsonParser, TCComputation}, + extensions::Extensions, + parser::parse_policy, + }; + + use crate::{entity_manifest::compute_entity_manifest, CoreSchema, ValidatorSchema}; + + use super::*; + + /// The implementation of [`Eq`] and [`PartialEq`] for + /// entities just compares entity ids. + /// This implementation does a more traditional, deep equality + /// check comparing attributes, ancestors, and the id. + fn entity_deep_equal(this: &Entity, other: &Entity) -> bool { + this.uid() == other.uid() + && BTreeMap::from_iter(this.attrs()) == BTreeMap::from_iter(other.attrs()) + && BTreeSet::from_iter(this.ancestors()) == BTreeSet::from_iter(other.ancestors()) + } + + /// The implementation of [`Eq`] and [`PartialEq`] on [`Entities`] + /// only checks equality by id for entities in the store. + /// This method checks that the entities are equal deeply, + /// using `[Entity::deep_equal]` to check equality. + /// Note that it ignores mode + fn entities_deep_equal(this: &Entities, other: &Entities) -> bool { + for this_entity in this.iter() { + let key = this_entity.uid(); + if let Dereference::Data(other_value) = other.entity(key) { + if !entity_deep_equal(this_entity, other_value) { + return false; + } + } else { + return false; + } + } + + for key in other.iter() { + if !matches!(this.entity(key.uid()), Dereference::Data(_)) { + return false; + } + } + + true + } + + // Schema for testing in this module + fn schema() -> ValidatorSchema { + ValidatorSchema::from_cedarschema_str( + " +entity User = { + name: String, +}; + +entity Document; + +action Read appliesTo { + principal: [User], + resource: [Document] +}; + ", + Extensions::all_available(), + ) + .unwrap() + .0 + } + + fn schema_with_hierarchy() -> ValidatorSchema { + ValidatorSchema::from_cedarschema_str( + " +entity User in [Document] = { + name: String, + manager: User, + personaldoc: Document, +}; + +entity Document; + +action Read appliesTo { + principal: [User], + resource: [Document] +}; + ", + Extensions::all_available(), + ) + .unwrap() + .0 + } + + fn expect_entity_slice_to( + original: serde_json::Value, + expected: serde_json::Value, + schema: &ValidatorSchema, + manifest: &EntityManifest, + ) { + let request = Request::new( + ( + EntityUID::with_eid_and_type("User", "oliver").unwrap(), + None, + ), + ( + EntityUID::with_eid_and_type("Action", "Read").unwrap(), + None, + ), + ( + EntityUID::with_eid_and_type("Document", "dummy").unwrap(), + None, + ), + Context::empty(), + Some(schema), + Extensions::all_available(), + ) + .unwrap(); + + let schema = CoreSchema::new(schema); + let parser: EntityJsonParser<'_, '_, CoreSchema<'_>> = EntityJsonParser::new( + Some(&schema), + Extensions::all_available(), + TCComputation::AssumeAlreadyComputed, + ); + let original_entities = parser.from_json_value(original).unwrap(); + + // Entity slicing results in invalid entity stores + // since attributes may be missing. + let parser_without_validation: EntityJsonParser<'_, '_> = EntityJsonParser::new( + None, + Extensions::all_available(), + TCComputation::AssumeAlreadyComputed, + ); + let expected_entities = parser_without_validation.from_json_value(expected).unwrap(); + + let sliced_entities = manifest + .slice_entities(&original_entities, &request) + .unwrap(); + + // PANIC SAFETY: panic in testing when test fails + #[allow(clippy::panic)] + if !entities_deep_equal(&sliced_entities, &expected_entities) { + panic!( + "Sliced entities differed from expected. Expected:\n{}\nGot:\n{}", + expected_entities.to_json_value().unwrap(), + sliced_entities.to_json_value().unwrap() + ); + } + } + + #[test] + fn test_simple_entity_manifest() { + let mut pset = PolicySet::new(); + let policy = parse_policy( + None, + r#"permit(principal, action, resource) +when { + principal.name == "John" +};"#, + ) + .expect("should succeed"); + pset.add(policy.into()).expect("should succeed"); + + let schema = schema(); + + let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + + let entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "name" : "Oliver" + }, + "parents" : [] + }, + { + "uid" : { "type" : "User", "id" : "oliver2"}, + "attrs" : { + "name" : "Oliver2" + }, + "parents" : [] + }, + ] + ); + + let expected_entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "name" : "Oliver" + }, + "parents" : [] + }, + ] + ); + + expect_entity_slice_to( + entities_json, + expected_entities_json, + &schema, + &entity_manifest, + ); + } + + #[test] + #[should_panic(expected = "Sliced entities differed")] + fn sanity_test_empty_entity_manifest() { + let mut pset = PolicySet::new(); + let policy = + parse_policy(None, "permit(principal, action, resource);").expect("should succeed"); + pset.add(policy.into()).expect("should succeed"); + + let schema = schema(); + + let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + + let entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "name" : "Oliver" + }, + "parents" : [] + }, + { + "uid" : { "type" : "User", "id" : "oliver2"}, + "attrs" : { + "name" : "Oliver2" + }, + "parents" : [] + }, + ] + ); + + let expected_entities_json = serde_json::json!([ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "name" : "Oliver" + }, + "parents" : [] + }, + { + "uid" : { "type" : "User", "id" : "oliver2"}, + "attrs" : { + "name" : "Oliver2" + }, + "parents" : [] + }, + ]); + + expect_entity_slice_to( + entities_json, + expected_entities_json, + &schema, + &entity_manifest, + ); + } + + #[test] + fn test_empty_entity_manifest() { + let mut pset = PolicySet::new(); + let policy = + parse_policy(None, "permit(principal, action, resource);").expect("should succeed"); + pset.add(policy.into()).expect("should succeed"); + + let schema = schema(); + + let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + + let entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "name" : "Oliver" + }, + "parents" : [] + }, + { + "uid" : { "type" : "User", "id" : "oliver2"}, + "attrs" : { + "name" : "Oliver2" + }, + "parents" : [] + }, + ] + ); + + let expected_entities_json = serde_json::json!([]); + + expect_entity_slice_to( + entities_json, + expected_entities_json, + &schema, + &entity_manifest, + ); + } + + #[test] + fn test_entity_manifest_ancestors_skipped() { + let mut pset = PolicySet::new(); + let policy = parse_policy( + None, + "permit(principal, action, resource) +when { + principal in resource || principal.manager in resource +};", + ) + .expect("should succeed"); + pset.add(policy.into()).expect("should succeed"); + + let schema = schema_with_hierarchy(); + + let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + + let entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "name" : "Oliver", + "manager": { "type" : "User", "id" : "george"}, + "personaldoc": { "type" : "Document", "id" : "oliverdocument"} + }, + "parents" : [ + { "type" : "Document", "id" : "oliverdocument"}, + { "type" : "Document", "id" : "dummy"} + ] + }, + { + "uid" : { "type" : "User", "id" : "george"}, + "attrs" : { + "name" : "George", + "manager": { "type" : "User", "id" : "george"}, + "personaldoc": { "type" : "Document", "id" : "georgedocument"} + }, + "parents" : [ + ] + }, + ] + ); + + let expected_entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "manager": { "__entity": { "type" : "User", "id" : "george"} } + }, + "parents" : [ + { "type" : "Document", "id" : "dummy"} + ] + }, + { + "uid" : { "type" : "User", "id" : "george"}, + "attrs" : { + }, + "parents" : [ + ] + }, + ] + ); + + expect_entity_slice_to( + entities_json, + expected_entities_json, + &schema, + &entity_manifest, + ); + } + + #[test] + fn test_entity_manifest_possible_ancestors() { + let mut pset = PolicySet::new(); + let policy = parse_policy( + None, + r#"permit(principal, action, resource) +when { + principal in (if 2 > 3 + then Document::"dummy" + else principal.personaldoc) +};"#, + ) + .expect("should succeed"); + pset.add(policy.into()).expect("should succeed"); + + let schema = schema_with_hierarchy(); + + let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + + let entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "name" : "Oliver", + "manager": { "type" : "User", "id" : "george"}, + "personaldoc": { "type" : "Document", "id" : "oliverdocument"} + }, + "parents" : [ + { "type" : "Document", "id" : "oliverdocument"}, + { "type" : "Document", "id" : "georgedocument"}, + { "type" : "Document", "id" : "dummy"} + ] + }, + ] + ); + + let expected_entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "personaldoc":{"__entity":{"type":"Document","id":"oliverdocument"}}, + }, + "parents" : [ + { "type" : "Document", "id" : "dummy"}, + { "type" : "Document", "id" : "oliverdocument"} + ] + } + ] + ); + + expect_entity_slice_to( + entities_json, + expected_entities_json, + &schema, + &entity_manifest, + ); + } + + #[test] + fn test_entity_manifest_set_of_ancestors() { + let mut pset = PolicySet::new(); + let policy = parse_policy( + None, + "permit(principal, action, resource) +when { + principal in principal.managers +};", + ) + .expect("should succeed"); + pset.add(policy.into()).expect("should succeed"); + + let schema = ValidatorSchema::from_cedarschema_str( + " +entity User in [User] = { + name: String, + managers: Set +}; + +entity Document; + +action Read appliesTo { + principal: [User], + resource: [Document] +}; + ", + Extensions::all_available(), + ) + .unwrap() + .0; + + let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + + let entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "name" : "Oliver", + "managers": [ + { "type" : "User", "id" : "george"}, + { "type" : "User", "id" : "yihong"}, + { "type" : "User", "id" : "ignored"}, + ] + }, + "parents" : [ + { "type" : "User", "id" : "dummy"}, + { "type" : "User", "id" : "george"}, + { "type" : "User", "id" : "yihong"}, + ] + }, + ] + ); + + let expected_entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + "managers": [ + { "__entity": { "type" : "User", "id" : "george"}}, + { "__entity": { "type" : "User", "id" : "yihong"}}, + { "__entity": { "type" : "User", "id" : "ignored"}}, + ] + }, + "parents" : [ + { "type" : "User", "id" : "george"}, + { "type" : "User", "id" : "yihong"}, + ] + }, + ] + ); + + expect_entity_slice_to( + entities_json, + expected_entities_json, + &schema, + &entity_manifest, + ); + } + + #[test] + fn test_entity_manifest_multiple_branches() { + let mut pset = PolicySet::new(); + let policy1 = parse_policy( + None, + r#" +permit( + principal, + action == Action::"Read", + resource +) +when +{ + resource.readers.contains(principal) +};"#, + ) + .unwrap(); + let policy2 = parse_policy( + Some(PolicyID::from_string("Policy2")), + r#"permit( + principal, + action == Action::"Read", + resource +) +when +{ + resource.metadata.owner == principal +};"#, + ) + .unwrap(); + pset.add(policy1.into()).expect("should succeed"); + pset.add(policy2.into()).expect("should succeed"); + + let schema = ValidatorSchema::from_cedarschema_str( + " +entity User; + +entity Metadata = { + owner: User, + time: String, +}; + +entity Document = { + metadata: Metadata, + readers: Set, +}; + +action Read appliesTo { + principal: [User], + resource: [Document] +}; + ", + Extensions::all_available(), + ) + .unwrap() + .0; + + let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + + let entities_json = serde_json::json!( + [ + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + }, + "parents" : [ + ] + }, + { + "uid": { "type": "Document", "id": "dummy"}, + "attrs": { + "metadata": { "type": "Metadata", "id": "olivermetadata"}, + "readers": [{"type": "User", "id": "oliver"}] + }, + "parents": [], + }, + { + "uid": { "type": "Metadata", "id": "olivermetadata"}, + "attrs": { + "owner": { "type": "User", "id": "oliver"}, + "time": "now" + }, + "parents": [], + }, + ] + ); + + let expected_entities_json = serde_json::json!( + [ + { + "uid": { "type": "Document", "id": "dummy"}, + "attrs": { + "metadata": {"__entity": { "type": "Metadata", "id": "olivermetadata"}}, + "readers": [{ "__entity": {"type": "User", "id": "oliver"}}] + }, + "parents": [], + }, + { + "uid": { "type": "Metadata", "id": "olivermetadata"}, + "attrs": { + "owner": {"__entity": { "type": "User", "id": "oliver"}}, + }, + "parents": [], + }, + { + "uid" : { "type" : "User", "id" : "oliver"}, + "attrs" : { + }, + "parents" : [ + ] + }, + ] + ); + + expect_entity_slice_to( + entities_json, + expected_entities_json, + &schema, + &entity_manifest, + ); + } + + #[test] + fn test_entity_manifest_struct_equality() { + let mut pset = PolicySet::new(); + // we need to load all of the metadata, not just nickname + // no need to load actual name + let policy = parse_policy( + None, + r#"permit(principal, action, resource) +when { + principal.metadata.nickname == "timmy" && principal.metadata == { + "friends": [ "oliver" ], + "nickname": "timmy" + } +};"#, + ) + .expect("should succeed"); + pset.add(policy.into()).expect("should succeed"); + + let schema = ValidatorSchema::from_cedarschema_str( + " +entity User = { + name: String, + metadata: { + friends: Set, + nickname: String, + }, +}; + +entity Document; + +action BeSad appliesTo { + principal: [User], + resource: [Document] +}; + ", + Extensions::all_available(), + ) + .unwrap() + .0; + + let entity_manifest = compute_entity_manifest(&schema, &pset).expect("Should succeed"); + assert_eq!(entity_manifest, entity_manifest); + } +} diff --git a/cedar-policy-validator/src/entity_manifest/type_annotations.rs b/cedar-policy-validator/src/entity_manifest/type_annotations.rs new file mode 100644 index 000000000..6ae2fde80 --- /dev/null +++ b/cedar-policy-validator/src/entity_manifest/type_annotations.rs @@ -0,0 +1,177 @@ +/* + * Copyright Cedar Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +//! Annotate entity manifest with type information. + +use std::collections::HashMap; + +use cedar_policy_core::ast::{RequestType, Var}; + +use crate::{ + entity_manifest::{ + AccessTrie, EntityManifest, EntityRoot, Fields, MismatchedEntityManifestError, + MismatchedMissingEntityError, MismatchedNotStrictSchemaError, RootAccessTrie, + }, + types::{Attributes, EntityRecordKind, Type}, + ValidatorSchema, +}; + +impl EntityManifest { + /// Given an untyped entity manifest and the schema that produced it, + /// return a newly typed entity manifest. + pub(crate) fn to_typed( + &self, + schema: &ValidatorSchema, + ) -> Result { + Ok( + EntityManifest { + per_action: + self.per_action + .iter() + .map(|(key, val)| Ok((key.clone(), val.to_typed(key, schema)?))) + .collect::, + MismatchedEntityManifestError, + >>()?, + }, + ) + } +} + +impl RootAccessTrie { + /// Type-annotate this primary slice, given the type of + /// the request and the schema. + pub(crate) fn to_typed( + &self, + request_type: &RequestType, + schema: &ValidatorSchema, + ) -> Result { + Ok(RootAccessTrie { + trie: self + .trie + .iter() + .map(|(key, slice)| { + Ok(( + key.clone(), + match key { + EntityRoot::Literal(lit) => slice.to_typed( + request_type, + &Type::euid_literal(lit.clone(), schema).ok_or( + MismatchedMissingEntityError { + entity: lit.clone(), + }, + )?, + schema, + )?, + EntityRoot::Var(Var::Action) => { + let ty = Type::euid_literal(request_type.action.clone(), schema) + .ok_or(MismatchedMissingEntityError { + entity: request_type.action.clone(), + })?; + slice.to_typed(request_type, &ty, schema)? + } + EntityRoot::Var(Var::Principal) => slice.to_typed( + request_type, + &Type::named_entity_reference(request_type.principal.clone()), + schema, + )?, + EntityRoot::Var(Var::Resource) => slice.to_typed( + request_type, + &Type::named_entity_reference(request_type.resource.clone()), + schema, + )?, + EntityRoot::Var(Var::Context) => { + let ty = &schema + .get_action_id(&request_type.action.clone()) + .ok_or(MismatchedMissingEntityError { + entity: request_type.action.clone(), + })? + .context; + slice.to_typed(request_type, ty, schema)? + } + }, + )) + }) + .collect::, MismatchedEntityManifestError>>( + )?, + }) + } +} + +impl AccessTrie { + pub(crate) fn to_typed( + &self, + request_type: &RequestType, + ty: &Type, + schema: &ValidatorSchema, + ) -> Result { + let children: Fields = match ty { + Type::Never + | Type::True + | Type::False + | Type::Primitive { .. } + | Type::Set { .. } + | Type::ExtensionType { .. } => { + assert!(self.children.is_empty()); + HashMap::default() + } + Type::EntityOrRecord(entity_or_record_ty) => { + let attributes: &Attributes = match entity_or_record_ty { + EntityRecordKind::Record { + attrs, + open_attributes: _, + } => attrs, + EntityRecordKind::AnyEntity => Err(MismatchedNotStrictSchemaError {})?, + // PANIC SAFETY: entity LUB should succeed after strict validation, and so should looking up the resulting type + #[allow(clippy::unwrap_used)] + EntityRecordKind::Entity(entitylub) => { + let entity_ty = schema + .get_entity_type( + entitylub + .get_single_entity() + .ok_or(MismatchedNotStrictSchemaError {})?, + ) + .ok_or(MismatchedNotStrictSchemaError {})?; + &entity_ty.attributes + } + EntityRecordKind::ActionEntity { name: _, attrs } => attrs, + }; + + let mut new_children = HashMap::new(); + for (field, child) in self.children.iter() { + // if the schema doesn't mention an attribute, + // it's safe to drop it. + // this can come up with the `has` operator + // on a type that doesn't have the attribute + if let Some(ty) = attributes.attrs.get(field) { + new_children.insert( + field.clone(), + Box::new(child.to_typed(request_type, &ty.attr_type, schema)?), + ); + } + } + new_children + } + }; + + Ok(AccessTrie { + children, + node_type: Some(ty.clone()), + ancestors_trie: self.ancestors_trie.to_typed(request_type, schema)?, + is_ancestor: self.is_ancestor, + }) + } +} diff --git a/cedar-policy-validator/src/types.rs b/cedar-policy-validator/src/types.rs index e9c689da2..81e204651 100644 --- a/cedar-policy-validator/src/types.rs +++ b/cedar-policy-validator/src/types.rs @@ -661,6 +661,17 @@ impl Type { } } + /// Returns `true` when the type is a type of an entity + #[cfg(feature = "entity-manifest")] + pub(crate) fn is_entity_type(&self) -> bool { + matches!( + self, + Type::EntityOrRecord(EntityRecordKind::Entity(_)) + | Type::EntityOrRecord(EntityRecordKind::AnyEntity) + | Type::EntityOrRecord(EntityRecordKind::ActionEntity { .. }) + ) + } + pub(crate) fn support_operator_overloading(&self) -> bool { match self { Self::ExtensionType { name } => { diff --git a/cedar-policy/src/api.rs b/cedar-policy/src/api.rs index 0844387e4..0a77583e8 100644 --- a/cedar-policy/src/api.rs +++ b/cedar-policy/src/api.rs @@ -4457,5 +4457,5 @@ pub fn compute_entity_manifest( schema: &Schema, pset: &PolicySet, ) -> Result { - entity_manifest::compute_entity_manifest(&schema.0, &pset.ast).map_err(Into::into) + entity_manifest::compute_entity_manifest(&schema.0, &pset.ast).map_err(std::convert::Into::into) } diff --git a/cedar-policy/src/api/err.rs b/cedar-policy/src/api/err.rs index 58f818b96..d81dac72b 100644 --- a/cedar-policy/src/api/err.rs +++ b/cedar-policy/src/api/err.rs @@ -30,8 +30,10 @@ pub use cedar_policy_core::extensions::{ use cedar_policy_core::{ast, authorizer, est}; pub use cedar_policy_validator::cedar_schema::{schema_warnings, SchemaWarning}; #[cfg(feature = "entity-manifest")] +pub use cedar_policy_validator::entity_manifest::slicing::EntitySliceError; +#[cfg(feature = "entity-manifest")] use cedar_policy_validator::entity_manifest::{ - self, FailedAnalysisError, PartialExpressionError, PartialRequestError, + self, PartialExpressionError, PartialRequestError, UnsupportedCedarFeatureError, }; pub use cedar_policy_validator::{schema_errors, SchemaError}; use miette::Diagnostic; @@ -1212,11 +1214,6 @@ pub mod request_validation_errors { } /// An error generated by entity slicing. -/// See [`FailedAnalysisError`] for details on the fragment -/// of Cedar handled by entity slicing. -// CAUTION: this type is publicly exported in `cedar-policy`. -// Don't make fields `pub`, don't make breaking changes, and use caution -// when adding public methods. #[derive(Debug, Error, Diagnostic)] #[non_exhaustive] #[cfg(feature = "entity-manifest")] @@ -1238,12 +1235,10 @@ pub enum EntityManifestError { #[error(transparent)] #[diagnostic(transparent)] PartialExpression(#[from] PartialExpressionError), - - /// A policy was not analyzable because it used unsupported operators. - /// See [`FailedAnalysisError`] for more details. + /// Encounters unsupported Cedar feature #[error(transparent)] #[diagnostic(transparent)] - FailedAnalysis(#[from] FailedAnalysisError), + UnsupportedCedarFeature(#[from] UnsupportedCedarFeatureError), } #[cfg(feature = "entity-manifest")] @@ -1256,7 +1251,9 @@ impl From for EntityManifestError { entity_manifest::EntityManifestError::PartialExpression(e) => { Self::PartialExpression(e) } - entity_manifest::EntityManifestError::FailedAnalysis(e) => Self::FailedAnalysis(e), + entity_manifest::EntityManifestError::UnsupportedCedarFeature(e) => { + Self::UnsupportedCedarFeature(e) + } } } } diff --git a/cedar-testing/Cargo.toml b/cedar-testing/Cargo.toml index 76648b9ae..9f7b6e239 100644 --- a/cedar-testing/Cargo.toml +++ b/cedar-testing/Cargo.toml @@ -19,6 +19,7 @@ default = ["ipaddr", "decimal"] decimal = ["cedar-policy/decimal"] ipaddr = ["cedar-policy/ipaddr"] integration-testing = [] +entity-manifest = ["cedar-policy/entity-manifest"] [dev-dependencies] assert_cmd = "2.0" diff --git a/cedar-testing/src/integration_testing.rs b/cedar-testing/src/integration_testing.rs index 54275aa53..21d89bd8a 100644 --- a/cedar-testing/src/integration_testing.rs +++ b/cedar-testing/src/integration_testing.rs @@ -27,6 +27,8 @@ use cedar_policy_core::ast::{EntityUID, PolicySet, Request}; use cedar_policy_core::entities::{self, json::err::JsonDeserializationErrorContext, Entities}; use cedar_policy_core::extensions::Extensions; use cedar_policy_core::{jsonvalue::JsonValueWithNoDuplicateKeys, parser}; +#[cfg(feature = "entity-manifest")] +use cedar_policy_validator::entity_manifest::compute_entity_manifest; use cedar_policy_validator::ValidatorSchema; use serde::{Deserialize, Serialize}; use std::{ @@ -252,6 +254,47 @@ pub fn parse_request_from_test( }) } +/// Asserts that the test response matches the json request, +/// including errors when the error comparison mode is enabled. +fn check_matches_json( + response: TestResponse, + json_request: &JsonRequest, + error_comparison_mode: ErrorComparisonMode, + test_name: &str, +) { + // check decision + assert_eq!( + response.response.decision(), + json_request.decision, + "test {test_name} failed for request \"{}\": unexpected decision", + &json_request.description + ); + // check reason + let reason: HashSet = response.response.diagnostics().reason().cloned().collect(); + assert_eq!( + reason, + json_request.reason.iter().cloned().collect(), + "test {test_name} failed for request \"{}\": unexpected reason", + &json_request.description + ); + // check errors, if applicable + // for now, the integration tests only support the `PolicyIds` comparison mode + if matches!(error_comparison_mode, ErrorComparisonMode::PolicyIds) { + let errors: HashSet = response + .response + .diagnostics() + .errors() + .map(|err| err.policy_id.clone()) + .collect(); + assert_eq!( + errors, + json_request.errors.iter().cloned().collect(), + "test {test_name} failed for request \"{}\": unexpected errors", + &json_request.description + ); + } +} + /// Run an integration test starting from a pre-parsed `JsonTest`. /// /// # Panics @@ -294,38 +337,28 @@ pub fn perform_integration_test( let response = test_impl .is_authorized(&request, &policies, &entities) .expect("Authorization failed"); - // check decision - assert_eq!( - response.response.decision(), - json_request.decision, - "test {test_name} failed for request \"{}\": unexpected decision", - &json_request.description - ); - // check reason - let reason: HashSet = response.response.diagnostics().reason().cloned().collect(); - assert_eq!( - reason, - json_request.reason.into_iter().collect(), - "test {test_name} failed for request \"{}\": unexpected reason", - &json_request.description - ); - // check errors, if applicable - // for now, the integration tests only support the `PolicyIds` comparison mode - if matches!( + check_matches_json( + response, + &json_request, test_impl.error_comparison_mode(), - ErrorComparisonMode::PolicyIds - ) { - let errors: HashSet = response - .response - .diagnostics() - .errors() - .map(|err| err.policy_id.clone()) - .collect(); - assert_eq!( - errors, - json_request.errors.into_iter().collect(), - "test {test_name} failed for request \"{}\": unexpected errors", - &json_request.description + test_name, + ); + + // now check that entity slicing arrives at the same decision + #[cfg(feature = "entity-manifest")] + if should_validate { + let entity_manifest = compute_entity_manifest(&schema, &policies).expect("test failed"); + let entity_slice = entity_manifest + .slice_entities(&entities, &request) + .expect("test failed"); + let slice_response = test_impl + .is_authorized(&request, &policies, &entity_slice) + .expect("Authorization failed"); + check_matches_json( + slice_response, + &json_request, + test_impl.error_comparison_mode(), + test_name, ); } }