diff --git a/Sources/PSQLKit/Expressions/CountExpression.swift b/Sources/PSQLKit/Expressions/CountExpression.swift index 909fc40..c9ae293 100644 --- a/Sources/PSQLKit/Expressions/CountExpression.swift +++ b/Sources/PSQLKit/Expressions/CountExpression.swift @@ -6,9 +6,16 @@ import SQLKit public struct CountExpression: AggregateExpression { let content: Content + let isDistinct: Bool + + init(_ content: Content, distinct: Bool) { + self.content = content + self.isDistinct = distinct + } public init(_ content: Content) { self.content = content + self.isDistinct = false } } @@ -16,15 +23,20 @@ extension CountExpression: SelectSQLExpression where Content: SelectSQLExpression { public var selectSqlExpression: SQLExpression { - _Select(content: self.content) + _Select(content: self.content, distinct: self.isDistinct) } private struct _Select: SQLExpression { let content: Content + let distinct: Bool func serialize(to serializer: inout SQLSerializer) { serializer.write("COUNT") serializer.write("(") + if self.distinct { + serializer.write("DISTINCT") + serializer.writeSpace() + } self.content.selectSqlExpression.serialize(to: &serializer) serializer.write(")") } @@ -58,4 +70,8 @@ extension CountExpression { public func `as`(_ alias: String) -> ExpressionAlias> { ExpressionAlias(expression: self, alias: alias) } + + public func distinct(_ isDistinct: Bool = true) -> Self { + .init(self.content, distinct: isDistinct) + } } diff --git a/Sources/PSQLKit/PSQLExpression.swift b/Sources/PSQLKit/PSQLExpression.swift index 44b87a0..980e14b 100644 --- a/Sources/PSQLKit/PSQLExpression.swift +++ b/Sources/PSQLKit/PSQLExpression.swift @@ -19,3 +19,9 @@ extension PSQLExpression where Self: Encodable { .init(self) } } + +extension PSQLExpression where Self: RawRepresentable, RawValue: PSQLExpression { + public static var postgresColumnType: PostgresColumnType { + RawValue.postgresColumnType + } +} diff --git a/Sources/PSQLKit/TypeEquatable.swift b/Sources/PSQLKit/TypeEquatable.swift index aa8eead..ca76989 100644 --- a/Sources/PSQLKit/TypeEquatable.swift +++ b/Sources/PSQLKit/TypeEquatable.swift @@ -6,3 +6,7 @@ import Foundation public protocol TypeEquatable { associatedtype CompareType } + +extension TypeEquatable where Self: RawRepresentable, RawValue: TypeEquatable { + public typealias CompareType = RawValue +} diff --git a/Tests/PSQLKitTests/ExpressionTests.swift b/Tests/PSQLKitTests/ExpressionTests.swift index 7425e0f..12ece34 100644 --- a/Tests/PSQLKitTests/ExpressionTests.swift +++ b/Tests/PSQLKitTests/ExpressionTests.swift @@ -63,6 +63,30 @@ final class ExpressionTests: PSQLTestCase { XCTAssertEqual(psqlkitSerializer.sql, compare) } + func testCountDistinct() { + SELECT { + COUNT(f.$name) + .distinct() + COUNT(f.$age) + .distinct() + .as("age") + } + .serialize(to: &fluentSerializer) + + SELECT { + COUNT(p.$name) + .distinct() + COUNT(p.$age) + .distinct() + .as("age") + } + .serialize(to: &psqlkitSerializer) + + let compare = #"SELECT COUNT(DISTINCT "x"."name"::TEXT), COUNT(DISTINCT "x"."age"::INTEGER) AS "age""# + XCTAssertEqual(fluentSerializer.sql, compare) + XCTAssertEqual(psqlkitSerializer.sql, compare) + } + func testSum() { SELECT { SUM(f.$name) diff --git a/Tests/PSQLKitTests/PSQLTests.swift b/Tests/PSQLKitTests/PSQLTests.swift index 48ed259..f47e361 100644 --- a/Tests/PSQLKitTests/PSQLTests.swift +++ b/Tests/PSQLKitTests/PSQLTests.swift @@ -22,6 +22,8 @@ final class FluentModel: Model, Table { var money: Double @Field(key: "birthday") var birthday: Date + @Field(key: "category") + var category: Category @Group(key: "pet") var pet: Pet @@ -44,6 +46,11 @@ final class FluentModel: Model, Table { init() {} } } + + enum Category: String, Codable, Equatable, TypeEquatable, PSQLExpression { + case yes + case no + } } struct PSQLModel: Table { @@ -61,6 +68,8 @@ struct PSQLModel: Table { var money: Double @Column(key: "birthday") var birthday: Date + @Column(key: "category") + var category: Category @NestedColumn(key: "pet") var pet: Pet @@ -83,6 +92,11 @@ struct PSQLModel: Table { init() {} } } + + enum Category: String, Codable, Equatable, TypeEquatable, PSQLExpression { + case yes + case no + } } class PSQLTestCase: XCTestCase { diff --git a/Tests/PSQLKitTests/WhereTests.swift b/Tests/PSQLKitTests/WhereTests.swift index 275d88c..e473d26 100644 --- a/Tests/PSQLKitTests/WhereTests.swift +++ b/Tests/PSQLKitTests/WhereTests.swift @@ -29,6 +29,24 @@ final class WhereTests: PSQLTestCase { XCTAssertEqual(fluentSerializer.sql, compare) } + func testEnum() { + WHERE { + FluentModel.$category != FluentModel.$category + FluentModel.$category == FluentModel.Category.yes.rawValue + } + .serialize(to: &fluentSerializer) + + WHERE { + PSQLModel.$category != PSQLModel.$category + PSQLModel.$category == PSQLModel.Category.yes.rawValue + } + .serialize(to: &psqlkitSerializer) + + let compare = #"WHERE ("my_model"."category" != "my_model"."category") AND ("my_model"."category" = 'yes')"# + XCTAssertEqual(fluentSerializer.sql, compare) + XCTAssertEqual(fluentSerializer.sql, compare) + } + func testMultiple() { WHERE { FluentModel.$name == f.$title