From 5a9b583d2a28d12bc8523b204e75a0da6c8b04ff Mon Sep 17 00:00:00 2001 From: Jan Berkel Date: Sun, 24 Sep 2017 22:53:16 +0200 Subject: [PATCH] Introduce FailableIterator --- Sources/SQLite/Core/Statement.swift | 40 +++++++++++++++++++++-------- Sources/SQLite/Typed/Query.swift | 38 ++++++++++----------------- Tests/SQLiteTests/QueryTests.swift | 10 ++++---- 3 files changed, 48 insertions(+), 40 deletions(-) diff --git a/Sources/SQLite/Core/Statement.swift b/Sources/SQLite/Core/Statement.swift index c817e887..5df00791 100644 --- a/Sources/SQLite/Core/Statement.swift +++ b/Sources/SQLite/Core/Statement.swift @@ -191,14 +191,6 @@ public final class Statement { } -extension Statement { - - func rowCursorNext() throws -> [Binding?]? { - return try step() ? Array(row) : nil - } - -} - extension Statement : Sequence { public func makeIterator() -> Statement { @@ -208,12 +200,38 @@ extension Statement : Sequence { } -extension Statement : IteratorProtocol { +public protocol FailableIterator : IteratorProtocol { + func failableNext() throws -> Self.Element? +} - public func next() -> [Binding?]? { - return try! step() ? Array(row) : nil +extension FailableIterator { + public func next() -> Element? { + return try! failableNext() } + public func map(_ transform: (Element) throws -> T) throws -> [T] { + var elements = [T]() + while let row = try failableNext() { + elements.append(try transform(row)) + } + return elements + } +} + +extension Array { + public init(_ failableIterator: I) throws where I.Element == Element { + self.init() + while let row = try failableIterator.failableNext() { + append(row) + } + } +} + +extension Statement : FailableIterator { + public typealias Element = [Binding?] + public func failableNext() throws -> [Binding?]? { + return try step() ? Array(row) : nil + } } extension Statement : CustomStringConvertible { diff --git a/Sources/SQLite/Typed/Query.swift b/Sources/SQLite/Typed/Query.swift index e9fa941c..cc54649e 100644 --- a/Sources/SQLite/Typed/Query.swift +++ b/Sources/SQLite/Typed/Query.swift @@ -894,35 +894,19 @@ public struct Delete : ExpressionType { } -public struct RowCursor { + +public struct RowCursor: FailableIterator { + public typealias Element = Row let statement: Statement let columnNames: [String: Int] - - public func next() throws -> Row? { - return try statement.rowCursorNext().flatMap { Row(columnNames, $0) } - } - - public func map(_ transform: (Row) throws -> T) throws -> [T] { - var elements = [T]() - while true { - if let row = try next() { - elements.append(try transform(row)) - } else { - break - } - } - - return elements + + public func failableNext() throws -> Row? { + return try statement.failableNext().flatMap { Row(columnNames, $0) } } } + extension Connection { - - public func prepareCursor(_ query: QueryType) throws -> RowCursor { - let expression = query.expression - let statement = try prepare(expression.template, expression.bindings) - return RowCursor(statement: statement, columnNames: try columnNamesForQuery(query)) - } public func prepare(_ query: QueryType) throws -> AnySequence { let expression = query.expression @@ -935,6 +919,12 @@ extension Connection { } } + public func prepareRowCursor(_ query: QueryType) throws -> RowCursor { + let expression = query.expression + let statement = try prepare(expression.template, expression.bindings) + return RowCursor(statement: statement, columnNames: try columnNamesForQuery(query)) + } + private func columnNamesForQuery(_ query: QueryType) throws -> [String: Int] { var (columnNames, idx) = ([String: Int](), 0) column: for each in query.clauses.select.columns { @@ -1002,7 +992,7 @@ extension Connection { } public func pluck(_ query: QueryType) throws -> Row? { - return try prepareCursor(query.limit(1, query.clauses.limit?.offset)).next() + return try prepareRowCursor(query.limit(1, query.clauses.limit?.offset)).failableNext() } /// Runs an `Insert` query. diff --git a/Tests/SQLiteTests/QueryTests.swift b/Tests/SQLiteTests/QueryTests.swift index feb0596d..2c026b9b 100644 --- a/Tests/SQLiteTests/QueryTests.swift +++ b/Tests/SQLiteTests/QueryTests.swift @@ -343,14 +343,14 @@ class QueryIntegrationTests : SQLiteTestCase { _ = user[users[managerId]] } } - - func test_prepareCursor() { + + func test_prepareRowCursor() { let names = ["a", "b", "c"] try! InsertUsers(names) - + let emailColumn = Expression("email") - let emails = try! db.prepareCursor(users).map { $0[emailColumn] } - + let emails = try! db.prepareRowCursor(users).map { $0[emailColumn] } + XCTAssertEqual(names.map({ "\($0)@example.com" }), emails.sorted()) }