Skip to content

Commit

Permalink
feat(api): add support for GraphQL filter attributeExists (#3838)
Browse files Browse the repository at this point in the history
* add support for GraphQL filter attributeExists

* fix(datastore): deduplicate SQL statement of attributeExists and eq/ne

* fix(datastore): flacky unit test
  • Loading branch information
5d authored Sep 23, 2024
1 parent 57a05c3 commit 815bf2a
Show file tree
Hide file tree
Showing 21 changed files with 401 additions and 72 deletions.
36 changes: 36 additions & 0 deletions Amplify/Categories/DataStore/Model/Internal/Persistable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ struct PersistableHelper {
return lhs == rhs
case let (lhs, rhs) as (String, String):
return lhs == rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue == rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs == rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue == rhs.rawValue
default:
return false
}
Expand Down Expand Up @@ -94,6 +100,12 @@ struct PersistableHelper {
return lhs == Double(rhs)
case let (lhs, rhs) as (String, String):
return lhs == rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue == rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs == rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue == rhs.rawValue
default:
return false
}
Expand Down Expand Up @@ -122,6 +134,12 @@ struct PersistableHelper {
return lhs <= Double(rhs)
case let (lhs, rhs) as (String, String):
return lhs <= rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue <= rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs <= rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue <= rhs.rawValue
default:
return false
}
Expand Down Expand Up @@ -150,6 +168,12 @@ struct PersistableHelper {
return lhs < Double(rhs)
case let (lhs, rhs) as (String, String):
return lhs < rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue < rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs < rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue < rhs.rawValue
default:
return false
}
Expand Down Expand Up @@ -178,6 +202,12 @@ struct PersistableHelper {
return lhs >= Double(rhs)
case let (lhs, rhs) as (String, String):
return lhs >= rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue >= rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs >= rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue >= rhs.rawValue
default:
return false
}
Expand Down Expand Up @@ -206,6 +236,12 @@ struct PersistableHelper {
return Double(lhs) > rhs
case let (lhs, rhs) as (String, String):
return lhs > rhs
case let (lhs, rhs) as (any EnumPersistable, String):
return lhs.rawValue > rhs
case let (lhs, rhs) as (String, any EnumPersistable):
return lhs > rhs.rawValue
case let (lhs, rhs) as (any EnumPersistable, any EnumPersistable):
return lhs.rawValue > rhs.rawValue
default:
return false
}
Expand Down
5 changes: 5 additions & 0 deletions Amplify/Categories/DataStore/Query/ModelKey.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public protocol ModelKey: CodingKey, CaseIterable, QueryFieldOperation {}

extension CodingKey where Self: ModelKey {

// MARK: - attributeExists
public func attributeExists(_ value: Bool) -> QueryPredicateOperation {
return field(stringValue).attributeExists(value)
}

// MARK: - beginsWith
public func beginsWith(_ value: String) -> QueryPredicateOperation {
return field(stringValue).beginsWith(value)
Expand Down
7 changes: 6 additions & 1 deletion Amplify/Categories/DataStore/Query/QueryField.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ public func field(_ name: String) -> QueryField {
/// - seealso: `ModelKey`
public protocol QueryFieldOperation {
// MARK: - Functions

func attributeExists(_ value: Bool) -> QueryPredicateOperation
func beginsWith(_ value: String) -> QueryPredicateOperation
func between(start: Persistable, end: Persistable) -> QueryPredicateOperation
func contains(_ value: String) -> QueryPredicateOperation
Expand Down Expand Up @@ -61,6 +61,11 @@ public struct QueryField: QueryFieldOperation {

public let name: String

// MARK: - attributeExists
public func attributeExists(_ value: Bool) -> QueryPredicateOperation {
return QueryPredicateOperation(field: name, operator: .attributeExists(value))
}

// MARK: - beginsWith
public func beginsWith(_ value: String) -> QueryPredicateOperation {
return QueryPredicateOperation(field: name, operator: .beginsWith(value))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ extension QueryOperator: Equatable {
case let (.between(oneStart, oneEnd), .between(otherStart, otherEnd)):
return PersistableHelper.isEqual(oneStart, otherStart)
&& PersistableHelper.isEqual(oneEnd, otherEnd)
case let (.attributeExists(one), .attributeExists(other)):
return one == other
default:
return false
}
Expand Down
19 changes: 15 additions & 4 deletions Amplify/Categories/DataStore/Query/QueryOperator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ public enum QueryOperator: Encodable {
case notContains(_ value: String)
case between(start: Persistable, end: Persistable)
case beginsWith(_ value: String)
case attributeExists(_ value: Bool)

public func evaluate(target: Any) -> Bool {
public func evaluate(target: Any?) -> Bool {
switch self {
case .notEqual(let predicateValue):
return !PersistableHelper.isEqual(target, predicateValue)
Expand All @@ -34,20 +35,26 @@ public enum QueryOperator: Encodable {
case .greaterThan(let predicateValue):
return PersistableHelper.isGreaterThan(target, predicateValue)
case .contains(let predicateString):
if let targetString = target as? String {
if let targetString = target.flatMap({ $0 as? String }) {
return targetString.contains(predicateString)
}
return false
case .notContains(let predicateString):
if let targetString = target as? String {
if let targetString = target.flatMap({ $0 as? String }) {
return !targetString.contains(predicateString)
}
case .between(let start, let end):
return PersistableHelper.isBetween(start, end, target)
case .beginsWith(let predicateValue):
if let targetString = target as? String {
if let targetString = target.flatMap({ $0 as? String }) {
return targetString.starts(with: predicateValue)
}
case .attributeExists(let predicateValue):
if case .some = target {
return predicateValue == true
} else {
return predicateValue == false
}
}
return false
}
Expand Down Expand Up @@ -105,6 +112,10 @@ public enum QueryOperator: Encodable {
case .beginsWith(let value):
try container.encode("beginsWith", forKey: .type)
try container.encode(value, forKey: .value)

case .attributeExists(let value):
try container.encode("attributeExists", forKey: .type)
try container.encode(value, forKey: .value)
}
}
}
30 changes: 1 addition & 29 deletions Amplify/Categories/DataStore/Query/QueryPredicate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -155,34 +155,6 @@ public class QueryPredicateOperation: QueryPredicate, Encodable {
}

public func evaluate(target: Model) -> Bool {
guard let fieldValue = target[field] else {
return false
}

guard let value = fieldValue else {
return false
}

if let booleanValue = value as? Bool {
return self.operator.evaluate(target: booleanValue)
}

if let doubleValue = value as? Double {
return self.operator.evaluate(target: doubleValue)
}

if let intValue = value as? Int {
return self.operator.evaluate(target: intValue)
}

if let timeValue = value as? Temporal.Time {
return self.operator.evaluate(target: timeValue)
}

if let enumValue = value as? EnumPersistable {
return self.operator.evaluate(target: enumValue.rawValue)
}

return self.operator.evaluate(target: value)
return self.operator.evaluate(target: target[field]?.flatMap { $0 })
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,4 +145,56 @@ extension GraphQLModelBasedTests {
XCTAssertNotNil(error)
}
}

/**
- Given: API with Post schema and optional field 'draft'
- When:
- create a new post with optional field 'draft' value .none
- Then:
- query Posts with filter {eq : null} shouldn't include the post
*/
func test_listModelsWithNilOptionalField_failedWithEqFilter() async throws {
let post = Post(title: UUID().uuidString, content: UUID().uuidString, createdAt: .now())
_ = try await Amplify.API.mutate(request: .create(post))
let posts = try await list(.list(
Post.self,
where: Post.keys.draft == nil && Post.keys.createdAt >= post.createdAt
))

XCTAssertFalse(posts.map(\.id).contains(post.id))
}

/**
- Given: DataStore with Post schema and optional field 'draft'
- When:
- create a new post with optional field 'draft' value .none
- Then:
- query Posts with filter {attributeExists : false} should include the post
*/
func test_listModelsWithNilOptionalField_successWithAttributeExistsFilter() async throws {
let post = Post(title: UUID().uuidString, content: UUID().uuidString, createdAt: .now())
_ = try await Amplify.API.mutate(request: .create(post))
let listPosts = try await list(
.list(
Post.self,
where: Post.keys.draft.attributeExists(false)
&& Post.keys.createdAt >= post.createdAt
)
)

XCTAssertTrue(listPosts.map(\.id).contains(post.id))
}

func list<M: Model>(_ request: GraphQLRequest<List<M>>) async throws -> [M] {
func getAllPages(_ list: List<M>) async throws -> [M] {
if list.hasNextPage() {
return list.elements + (try await getAllPages(list.getNextPage()))
} else {
return list.elements
}
}

return try await getAllPages(try await Amplify.API.query(request: request).get())
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ extension QueryOperator {
return "beginsWith"
case .notContains:
return "notContains"
case .attributeExists:
return "attributeExists"
}
}

Expand All @@ -212,6 +214,8 @@ extension QueryOperator {
return value
case .notContains(let value):
return value
case .attributeExists(let value):
return value
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,88 @@ class GraphQLListQueryTests: XCTestCase {
XCTAssertEqual(variables["limit"] as? Int, 1_000)
XCTAssertNotNil(variables["filter"])
}

/**
- Given:
- A Post schema with optional field 'draft'
- When:
- Using list query to filter records that either don't have 'draft' field or have 'null' value
- Then:
- the query document as expected
- the filter is encoded correctly
*/
func test_listQuery_withAttributeExistsFilter_correctlyBuildGraphQLQueryStatement() {
let post = Post.keys
let predicate = post.id.eq("id")
&& (post.draft.attributeExists(false) || post.draft.eq(nil))

var documentBuilder = ModelBasedGraphQLDocumentBuilder(modelSchema: Post.schema, operationType: .query)
documentBuilder.add(decorator: DirectiveNameDecorator(type: .list))
documentBuilder.add(decorator: PaginationDecorator())
documentBuilder.add(decorator: FilterDecorator(filter: predicate.graphQLFilter(for: Post.schema)))
let document = documentBuilder.build()
let expectedQueryDocument = """
query ListPosts($filter: ModelPostFilterInput, $limit: Int) {
listPosts(filter: $filter, limit: $limit) {
items {
id
content
createdAt
draft
rating
status
title
updatedAt
__typename
}
nextToken
}
}
"""
XCTAssertEqual(document.name, "listPosts")
XCTAssertEqual(document.stringValue, expectedQueryDocument)
guard let variables = document.variables else {
XCTFail("The document doesn't contain variables")
return
}
XCTAssertNotNil(variables["limit"])
XCTAssertEqual(variables["limit"] as? Int, 1_000)

guard let filter = variables["filter"] as? GraphQLFilter else {
XCTFail("variables should contain a valid filter")
return
}

// Test filter for a valid JSON format
let filterJSON = try? JSONSerialization.data(withJSONObject: filter,
options: .prettyPrinted)
XCTAssertNotNil(filterJSON)

let expectedFilterJSON = """
{
"and" : [
{
"id" : {
"eq" : "id"
}
},
{
"or" : [
{
"draft" : {
"attributeExists" : false
}
},
{
"draft" : {
"eq" : null
}
}
]
}
]
}
"""
XCTAssertEqual(String(data: filterJSON!, encoding: .utf8), expectedFilterJSON)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class QueryPredicateEvaluateGeneratedBoolTests: XCTestCase {

let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance)

XCTAssertFalse(evaluation)
XCTAssertTrue(evaluation)
}

func testBoolfalsenotEqualBooltrue() throws {
Expand Down Expand Up @@ -70,7 +70,7 @@ class QueryPredicateEvaluateGeneratedBoolTests: XCTestCase {

let evaluation = try predicate.evaluate(target: instance.eraseToAnyModel().instance)

XCTAssertFalse(evaluation)
XCTAssertTrue(evaluation)
}

func testBooltrueequalsBooltrue() throws {
Expand Down
Loading

0 comments on commit 815bf2a

Please sign in to comment.