From b1512e88e06fb327c922d891a9006032f916bca5 Mon Sep 17 00:00:00 2001 From: Joakim Reinert Date: Tue, 28 Aug 2018 01:31:11 +0200 Subject: [PATCH] add all read methods to LiveTransaction and have them accept tx in Repo --- spec/transactions_spec.cr | 48 +++++++++++ src/crecto/live_transaction.cr | 92 ++++++++++++++++++++- src/crecto/repo.cr | 145 ++++++++++++++++----------------- 3 files changed, 211 insertions(+), 74 deletions(-) diff --git a/spec/transactions_spec.cr b/spec/transactions_spec.cr index c20f7d0..015b2dc 100644 --- a/spec/transactions_spec.cr +++ b/spec/transactions_spec.cr @@ -336,6 +336,54 @@ describe Crecto do Repo.all(User, Query.where(name: "perform_all")).size.should eq 2 Repo.all(User, Query.where(name: "perform_all_io2oj999")).size.should eq 0 end + + it "allows reading records inserted inside the transaction" do + insert_user = User.new + insert_user.name = "insert_user" + + Repo.transaction! do |tx| + id = tx.insert!(insert_user).instance.id + tx.get(User, id).should_not eq(nil) + tx.get!(User, id).should_not eq(nil) + tx.get(User, id, Query.new).should_not eq(nil) + tx.get!(User, id, Query.new).should_not eq(nil) + tx.get_by(User, id: id).should_not eq(nil) + tx.get_by!(User, id: id).should_not eq(nil) + tx.get_by(User, id: id).should_not eq(nil) + tx.get_by!(User, id: id).should_not eq(nil) + tx.get_by(User, Query.where(id: id)).should_not eq(nil) + tx.get_by!(User, Query.where(id: id)).should_not eq(nil) + tx.all(User, Query.where(id: id)).first.should_not eq(nil) + tx.all(User, Query.where(id: id), preloads: [] of Symbol).first.should_not eq(nil) + end + end + + it "allows nesting transactions" do + Repo.delete_all(Post) + Repo.delete_all(User) + + insert_user = User.new + insert_user.name = "nested_transactions_insert_user" + invalid_user = User.new + delete_user = quick_create_user("nested_transactions_delete_user") + + Repo.transaction! do |tx| + tx.insert!(insert_user) + + expect_raises Crecto::InvalidChangeset do + Repo.transaction! do |inner_tx| + inner_tx.delete!(delete_user) + inner_tx.insert!(invalid_user) + end + end + end + + # check insert happened + Repo.all(User, Query.where(name: "nested_transactions_insert_user")).size.should eq 1 + + # check delete didn't happen + Repo.all(User, Query.where(name: "nested_transactions_delete_user")).size.should eq 1 + end end end end diff --git a/src/crecto/live_transaction.cr b/src/crecto/live_transaction.cr index f40aff3..b3ef944 100644 --- a/src/crecto/live_transaction.cr +++ b/src/crecto/live_transaction.cr @@ -1,10 +1,86 @@ -require "./multi" +require "./repo/query" module Crecto class LiveTransaction(T) + alias Query = Repo::Query + def initialize(@tx : DB::Transaction, @repo : T) end + def raw_exec(args : Array) + @repo.raw_exec(args, @tx) + end + + def raw_exec(*args) + @repo.raw_exec(*args, @tx) + end + + def raw_query(query, *args) + @repo.raw_query(query, *args, @tx) do |rs| + yield rs + end + end + + def raw_query(query, args : Array) + @repo.raw_query(query, args, @tx) + end + + def raw_query(query, *args) + @repo.raw_query(query, *args) + end + + def raw_scalar(*args) + @repo.raw_scalar(*args, @tx) + end + + def all(queryable, query : Query? = Query.new, **opts) + @repo.all(queryable, query, @tx, **opts) + end + + def all(queryable, query = Query.new) + @repo.all(queryable, query, @tx) + end + + def get(queryable, id) + @repo.get(queryable, id, @tx) + end + + def get!(queryable, id) + @repo.get!(queryable, id, @tx) + end + + def get(queryable, id, query : Query) + @repo.get(queryable, id, query, @tx) + end + + def get!(queryable, id, query : Query) + @repo.get!(queryable, id, query, @tx) + end + + def get_by(queryable, **opts) + @repo.get_by(queryable, @tx, **opts) + end + + def get_by(queryable, query) + @repo.get_by(queryable, query, @tx) + end + + def get_by!(queryable, **opts) + @repo.get_by!(queryable, @tx, **opts) + end + + def get_by!(queryable, query) + @repo.get_by!(queryable, query, @tx) + end + + def get_association(queryable_instance, association_name : Symbol) + @repo.get_association(queryable_instance, association_name, @tx) + end + + def get_association!(queryable_instance, association_name : Symbol) + @repo.get_association!(queryable_instance, association_name, @tx) + end + {% for type in %w[insert insert! delete delete! update update!] %} def {{type.id}}(queryable : Crecto::Model) @repo.{{type.id}}(queryable, @tx) @@ -26,5 +102,19 @@ module Crecto def update_all(queryable, query, update_tuple : NamedTuple) update_all(queryable, query, update_tuple.to_h) end + + def aggregate(queryable, aggregate_function : Symbol, field : Symbol) + @repo.aggregate(queryable, aggregate_function, field, @tx) + end + + def aggregate(queryable, aggregate_function : Symbol, field : Symbol, query : Crecto::Repo::Query) + @repo.aggregate(queryable, aggregate_function, field, query, @tx) + end + + def transaction! + @repo.transaction!(@tx) do |tx| + yield tx + end + end end end diff --git a/src/crecto/repo.cr b/src/crecto/repo.cr index ea3c6ba..cedefd1 100644 --- a/src/crecto/repo.cr +++ b/src/crecto/repo.cr @@ -28,35 +28,35 @@ module Crecto end # Run a raw `exec` query directly on the adapter connection - def raw_exec(args : Array) - config.get_connection.exec(args) + def raw_exec(args : Array, tx : DB::Transaction? = nil) + (tx || config.get_connection).exec(args) end # Run a raw `exec` query directly on the adapter connection - def raw_exec(*args) - config.get_connection.exec(*args) + def raw_exec(*args, tx : DB::Transaction? = nil) + (tx || config.get_connection).exec(*args) end # Run a raw `query` query directly on the adapter connection - def raw_query(query, *args) - config.get_connection.query(query, *args) do |rs| + def raw_query(query, *args, tx : DB::Transaction? = nil) + (tx || config.get_connection).query(query, *args) do |rs| yield rs end end # Run a raw `query` query directly on the adapter connection - def raw_query(query, args : Array) - config.get_connection.query(args) + def raw_query(query, args : Array, tx : DB::Transaction? = nil) + (tx || config.get_connection).query(args) end # Run a raw `query` query directly on the adapter connection - def raw_query(query, *args) - config.get_connection.query(*args) + def raw_query(query, *args, tx : DB::Transaction? = nil) + (tx || config.get_connection).query(*args) end # Run a raw `scalar` query directly on the adapter connection - def raw_scalar(*args) - config.get_connection.scalar(*args) + def raw_scalar(*args, tx : DB::Transaction? = nil) + (tx || config.get_connection).scalar(*args) end # Return a list of *queryable* instances using *query* @@ -65,15 +65,15 @@ module Crecto # query = Query.where(name: "fred") # users = Repo.all(User, query) # ``` - def all(queryable, query : Query? = Query.new, **opts) - q = config.adapter.run(config.get_connection, :all, queryable, query).as(DB::ResultSet) + def all(queryable, query : Query? = Query.new, tx : DB::Transaction? = nil, **opts) : Array + q = config.adapter.run(tx || config.get_connection, :all, queryable, query).as(DB::ResultSet) results = queryable.from_rs(q.as(DB::ResultSet)) opt_preloads = opts.fetch(:preload, [] of Symbol) preloads = query.preloads + opt_preloads.map { |a| {symbol: a, query: nil} } if preloads.any? - add_preloads(results, queryable, preloads) + add_preloads(results, queryable, preloads, tx) end results @@ -84,8 +84,8 @@ module Crecto # ``` # users = Crecto::Repo.all(User) # ``` - def all(queryable, query = Query.new) - q = config.adapter.run(config.get_connection, :all, queryable, query).as(DB::ResultSet) + def all(queryable, query = Query.new, tx : DB::Transaction? = nil) : Array + q = config.adapter.run(tx || config.get_connection, :all, queryable, query).as(DB::ResultSet) results = queryable.from_rs(q) results end @@ -95,8 +95,8 @@ module Crecto # ``` # user = Repo.get(User, 1) # ``` - def get(queryable, id) - q = config.adapter.run(config.get_connection, :get, queryable, id).as(DB::ResultSet) + def get(queryable, id, tx : DB::Transaction? = nil) + q = config.adapter.run(tx || config.get_connection, :get, queryable, id).as(DB::ResultSet) results = queryable.from_rs(q) results.first if results.any? end @@ -107,8 +107,8 @@ module Crecto # ``` # user = Repo.get(User, 1) # ``` - def get!(queryable, id) - if result = get(queryable, id) + def get!(queryable, id, tx : DB::Transaction? = nil) + if result = get(queryable, id, tx) result else raise NoResults.new("No Results") @@ -122,13 +122,13 @@ module Crecto # query = Query.preload(:posts) # user = Repo.get(User, 1, query) # ``` - def get(queryable, id, query : Query) - q = config.adapter.run(config.get_connection, :get, queryable, id).as(DB::ResultSet) + def get(queryable, id, query : Query, tx : DB::Transaction? = nil) + q = config.adapter.run(tx || config.get_connection, :get, queryable, id).as(DB::ResultSet) results = queryable.from_rs(q) if results.any? if query.preloads.any? - add_preloads(results, queryable, query.preloads) + add_preloads(results, queryable, query.preloads, tx) end results.first @@ -143,8 +143,8 @@ module Crecto # query = Query.preload(:posts) # user = Repo.get(User, 1, query) # ``` - def get!(queryable, id, query : Query) - if result = get(queryable, id, query) + def get!(queryable, id, query : Query, tx : DB::Transaction? = nil) + if result = get(queryable, id, query, tx) result else raise NoResults.new("No Results") @@ -156,8 +156,8 @@ module Crecto # ``` # user = Repo.get_by(User, name: "fred", age: 21) # ``` - def get_by(queryable, **opts) - get_by(queryable, Query.where(**opts)) + def get_by(queryable, tx : DB::Transaction? = nil, **opts) + get_by(queryable, Query.where(**opts), tx) end # Return a single nilable instance of *queryable* using the *query* param @@ -166,8 +166,8 @@ module Crecto # ``` # user = Repo.get_by(User, Query.where(name: "fred", age: 21)) # ``` - def get_by(queryable, query) - results = all(queryable, query.limit(1)) + def get_by(queryable, query, tx : DB::Transaction? = nil) + results = all(queryable, query.limit(1), tx) results.first if results.any? end @@ -177,8 +177,8 @@ module Crecto # ``` # user = Repo.get_by(User, name: "fred", age: 21) # ``` - def get_by!(queryable, **opts) - get_by!(queryable, Query.where(**opts)) + def get_by!(queryable, tx : DB::Transaction? = nil, **opts) + get_by!(queryable, Query.where(**opts), tx) end # Return a single instance of *queryable* using the *query* param @@ -187,8 +187,8 @@ module Crecto # ``` # user = Repo.get_by(User, Query.where(name: "fred", age: 21)) # ``` - def get_by!(queryable, query) - if result = get_by(queryable, query) + def get_by!(queryable, query, tx : DB::Transaction? = nil) + if result = get_by(queryable, query, tx) result else raise NoResults.new("No Results") @@ -200,15 +200,14 @@ module Crecto # ``` # user = Crecto::Repo.get(User, 1) # post = Repo.get_association(user, :post) - # ``` - def get_association(queryable_instance, association_name : Symbol, query : Query = Query.new) + def get_association(queryable_instance, association_name : Symbol, query : Query = Query.new, tx : DB::Transaction? = nil) case queryable_instance.class.association_type_for_association(association_name) when :has_many - get_has_many_association(queryable_instance, association_name, query) + get_has_many_association(queryable_instance, association_name, query, tx) when :has_one - get_has_one_association(queryable_instance, association_name, query) + get_has_one_association(queryable_instance, association_name, query, tx) when :belongs_to - get_belongs_to_association(queryable_instance, association_name, query) + get_belongs_to_association(queryable_instance, association_name, query, tx) end end @@ -220,8 +219,8 @@ module Crecto # user = Crecto::Repo.get(User, 1) # post = Repo.get_association!(user, :post) # ``` - def get_association!(queryable_instance, association_name : Symbol, query : Query = Query.new) - if result = get_association(queryable_instance, association_name, query) + def get_association!(queryable_instance, association_name : Symbol, query : Query = Query.new, tx : DB::Transaction? = nil) + if result = get_association(queryable_instance, association_name, query, tx) result else raise NoResults.new("No Results") @@ -447,8 +446,8 @@ module Crecto # ``` # Repo.query(User, "select * from users where id > ?", [30]) # ``` - def query(queryable, sql : String, params = [] of DbValue) : Array - q = config.adapter.run(config.get_connection, :sql, sql, params).as(DB::ResultSet) + def query(queryable, sql : String, params = [] of DbValue, tx : DB::Transaction? = nil) : Array + q = config.adapter.run(tx || config.get_connection, :sql, sql, params).as(DB::ResultSet) results = queryable.from_rs(q) results end @@ -463,8 +462,8 @@ module Crecto # ``` # query = Crecto::Repo.query("select * from users where id = ?", [30]) # ``` - def query(sql : String, params = [] of DbValue) : DB::ResultSet - config.adapter.run(config.get_connection, :sql, sql, params).as(DB::ResultSet) + def query(sql : String, params = [] of DbValue, tx : DB::Transaction? = nil) : DB::ResultSet + config.adapter.run(tx || config.get_connection, :sql, sql, params).as(DB::ResultSet) end def transaction(multi : Crecto::Multi) @@ -496,8 +495,8 @@ module Crecto # tx.insert!(post) # end # ``` - def transaction! - config.get_connection.transaction do |tx| + def transaction!(tx : DB::Transaction? = nil) + (tx || config.get_connection).transaction do |tx| begin yield LiveTransaction.new(tx, self) rescue error : Exception @@ -529,16 +528,16 @@ module Crecto # Calculate the given aggregate `aggregate_function` over the given `field` # Aggregate `aggregate_function` must be one of (:avg, :count, :max, :min:, :sum) - def aggregate(queryable, aggregate_function : Symbol, field : Symbol) + def aggregate(queryable, aggregate_function : Symbol, field : Symbol, tx : DB::Transaction? = nil) raise InvalidOption.new("Aggregate must be one of :avg, :count, :max, :min:, :sum") unless [:avg, :count, :max, :min, :sum].includes?(aggregate_function) - config.adapter.aggregate(config.get_connection, queryable, aggregate_function, field) + config.adapter.aggregate(tx || config.get_connection, queryable, aggregate_function, field) end - def aggregate(queryable, aggregate_function : Symbol, field : Symbol, query : Crecto::Repo::Query) + def aggregate(queryable, aggregate_function : Symbol, field : Symbol, query : Crecto::Repo::Query, tx : DB::Transaction? = nil) raise InvalidOption.new("Aggregate must be one of :avg, :count, :max, :min:, :sum") unless [:avg, :count, :max, :min, :sum].includes?(aggregate_function) - config.adapter.aggregate(config.get_connection, queryable, aggregate_function, field, query) + config.adapter.aggregate(tx || config.get_connection, queryable, aggregate_function, field, query) end private def check_dependents(changeset, tx : DB::Transaction?) : Nil @@ -609,32 +608,32 @@ module Crecto end end - private def add_preloads(results, queryable, preloads) + private def add_preloads(results, queryable, preloads, tx) preloads.each do |preload| case queryable.association_type_for_association(preload[:symbol]) when :has_many - has_many_preload(results, queryable, preload) + has_many_preload(results, queryable, preload, tx) when :has_one - has_one_preload(results, queryable, preload) + has_one_preload(results, queryable, preload, tx) when :belongs_to - belongs_to_preload(results, queryable, preload) + belongs_to_preload(results, queryable, preload, tx) end end end - private def has_one_preload(results, queryable, preload) - join_direct(results, queryable, preload, singular: true) + private def has_one_preload(results, queryable, preload, tx : DB::Transaction?) + join_direct(results, queryable, preload, tx, singular: true) end - private def has_many_preload(results, queryable, preload) + private def has_many_preload(results, queryable, preload, tx) if queryable.through_key_for_association(preload[:symbol]) - join_through(results, queryable, preload) + join_through(results, queryable, preload, tx) else - join_direct(results, queryable, preload) + join_direct(results, queryable, preload, tx) end end - private def join_direct(results, queryable, preload, singular = false) + private def join_direct(results, queryable, preload, tx, singular = false) ids = results.map(&.pkey_value.as(PkeyValue)) foreign_key = queryable.foreign_key_for_association(preload[:symbol]) return if foreign_key.nil? @@ -644,7 +643,7 @@ module Crecto end association_klass = queryable.klass_for_association(preload[:symbol]) return if association_klass.nil? - relation_items = all(association_klass, query) + relation_items = all(association_klass, query, tx) relation_items = relation_items.group_by { |t| queryable.foreign_key_value_for_association(preload[:symbol], t) } results.each do |result| @@ -655,7 +654,7 @@ module Crecto end end - private def join_through(results, queryable, preload) + private def join_through(results, queryable, preload, tx) ids = results.map(&.pkey_value.as(PkeyValue)) foreign_key = queryable.foreign_key_for_association(preload[:symbol]) return if foreign_key.nil? @@ -663,7 +662,7 @@ module Crecto # UserProjects association_klass = queryable.klass_for_association(queryable.through_key_for_association(preload[:symbol]).as(Symbol)) return if association_klass.nil? - join_table_items = all(association_klass, join_query) + join_table_items = all(association_klass, join_query, tx) # array of Project id's if join_table_items.empty? @@ -682,7 +681,7 @@ module Crecto association_query = association_query.combine(preload_query) end # Projects - relation_items = all(association_klass, association_query) + relation_items = all(association_klass, association_query, tx) # UserProject grouped by user_id join_table_items = join_table_items.group_by { |t| queryable.foreign_key_value_for_association(queryable.through_key_for_association(preload[:symbol]).as(Symbol), t) } @@ -697,7 +696,7 @@ module Crecto end end - private def belongs_to_preload(results, queryable, preload) + private def belongs_to_preload(results, queryable, preload, tx) ids = results.map { |r| queryable.foreign_key_value_for_association(preload[:symbol], r).as(PkeyValue) } ids.compact! return if ids.empty? @@ -712,7 +711,7 @@ module Crecto end association_klass = queryable.klass_for_association(preload[:symbol]) return if association_klass.nil? - relation_items = all(association_klass, query) + relation_items = all(association_klass, query, tx) unless relation_items.nil? relation_items = relation_items.group_by { |t| t.pkey_value.as(PkeyValue) } @@ -727,28 +726,28 @@ module Crecto end end - private def get_has_many_association(instance, association : Symbol, query : Query) + private def get_has_many_association(instance, association : Symbol, query : Query, tx) queryable = instance.class foreign_key = queryable.foreign_key_for_association(association) return if foreign_key.nil? query = query.where(foreign_key, instance.pkey_value) association_klass = queryable.klass_for_association(association) return if association_klass.nil? - all(association_klass, query) + all(association_klass, query, tx) end - private def get_has_one_association(instance, association : Symbol, query : Query) - many = get_has_many_association(instance, association, query) + private def get_has_one_association(instance, association : Symbol, query : Query, tx) + many = get_has_many_association(instance, association, query, tx) return if many.nil? many.first? end - private def get_belongs_to_association(instance, association : Symbol, query : Query) + private def get_belongs_to_association(instance, association : Symbol, query : Query, tx) queryable = instance.class klass_for_association = queryable.klass_for_association(association) return if klass_for_association.nil? key_for_association = queryable.foreign_key_value_for_association(association, instance) - get(klass_for_association, key_for_association, query) + get(klass_for_association, key_for_association, query, tx) end end end