From a46ca07add069710cd37fc1451dc3af82baf4da8 Mon Sep 17 00:00:00 2001 From: cthomas-figma Date: Wed, 24 Apr 2024 17:21:45 -0400 Subject: [PATCH 1/2] Create proxy TableDefinition class. Tests pass --- lib/strong_migrations.rb | 46 ++++++++++++----------- lib/strong_migrations/migration.rb | 17 +++++++-- lib/strong_migrations/table_definition.rb | 13 +++++++ 3 files changed, 50 insertions(+), 26 deletions(-) create mode 100644 lib/strong_migrations/table_definition.rb diff --git a/lib/strong_migrations.rb b/lib/strong_migrations.rb index 97635935..f0957e4b 100644 --- a/lib/strong_migrations.rb +++ b/lib/strong_migrations.rb @@ -1,23 +1,24 @@ # dependencies -require "active_support" +require 'active_support' # adapters -require_relative "strong_migrations/adapters/abstract_adapter" -require_relative "strong_migrations/adapters/mysql_adapter" -require_relative "strong_migrations/adapters/mariadb_adapter" -require_relative "strong_migrations/adapters/postgresql_adapter" +require_relative 'strong_migrations/adapters/abstract_adapter' +require_relative 'strong_migrations/adapters/mysql_adapter' +require_relative 'strong_migrations/adapters/mariadb_adapter' +require_relative 'strong_migrations/adapters/postgresql_adapter' # modules -require_relative "strong_migrations/checks" -require_relative "strong_migrations/safe_methods" -require_relative "strong_migrations/checker" -require_relative "strong_migrations/database_tasks" -require_relative "strong_migrations/migration" -require_relative "strong_migrations/migrator" -require_relative "strong_migrations/version" +require_relative 'strong_migrations/checks' +require_relative 'strong_migrations/safe_methods' +require_relative 'strong_migrations/checker' +require_relative 'strong_migrations/database_tasks' +require_relative 'strong_migrations/migration' +require_relative 'strong_migrations/migrator' +require_relative 'strong_migrations/version' +require_relative 'strong_migrations/table_definition' # integrations -require_relative "strong_migrations/railtie" if defined?(Rails) +require_relative 'strong_migrations/railtie' if defined?(Rails) module StrongMigrations class Error < StandardError; end @@ -26,10 +27,10 @@ class UnsupportedVersion < Error; end class << self attr_accessor :auto_analyze, :start_after, :checks, :error_messages, - :target_postgresql_version, :target_mysql_version, :target_mariadb_version, - :enabled_checks, :lock_timeout, :statement_timeout, :check_down, :target_version, - :safe_by_default, :target_sql_mode, :lock_timeout_retries, :lock_timeout_retry_delay, - :alphabetize_schema + :target_postgresql_version, :target_mysql_version, :target_mariadb_version, + :enabled_checks, :lock_timeout, :statement_timeout, :check_down, :target_version, + :safe_by_default, :target_sql_mode, :lock_timeout_retries, :lock_timeout_retry_delay, + :alphabetize_schema attr_writer :lock_timeout_limit end self.auto_analyze = false @@ -43,7 +44,7 @@ class << self # private def self.developer_env? - env == "development" || env == "test" + env == 'development' || env == 'test' end # private @@ -52,7 +53,7 @@ def self.env Rails.env else # default to production for safety - ENV["RACK_ENV"] || "production" + ENV['RACK_ENV'] || 'production' end end @@ -68,7 +69,7 @@ def self.add_check(&block) end def self.enable_check(check, start_after: nil) - enabled_checks[check] = {start_after: start_after} + enabled_checks[check] = { start_after: start_after } end def self.disable_check(check) @@ -86,16 +87,17 @@ def self.check_enabled?(check, version: nil) end # load error messages -require_relative "strong_migrations/error_messages" +require_relative 'strong_migrations/error_messages' ActiveSupport.on_load(:active_record) do ActiveRecord::Migration.prepend(StrongMigrations::Migration) ActiveRecord::Migrator.prepend(StrongMigrations::Migrator) + # ActiveRecord::ConnectionAdapters::TableDefinition.prepend(StrongMigrations::TableDefinition) if defined?(ActiveRecord::Tasks::DatabaseTasks) ActiveRecord::Tasks::DatabaseTasks.singleton_class.prepend(StrongMigrations::DatabaseTasks) end - require_relative "strong_migrations/schema_dumper" + require_relative 'strong_migrations/schema_dumper' ActiveRecord::SchemaDumper.prepend(StrongMigrations::SchemaDumper) end diff --git a/lib/strong_migrations/migration.rb b/lib/strong_migrations/migration.rb index 09d039a9..e3507665 100644 --- a/lib/strong_migrations/migration.rb +++ b/lib/strong_migrations/migration.rb @@ -28,13 +28,22 @@ def revert(*) end end - def safety_assured - strong_migrations_checker.class.safety_assured do - yield + def create_table(table_name, **options, &block) + if block_given? + super do |t| + table_definition = StrongMigrations::TableDefinition.new(compatible_table_definition(t)) + yield table_definition + end + else + super end end - def stop!(message, header: "Custom check") + def safety_assured(&block) + strong_migrations_checker.class.safety_assured(&block) + end + + def stop!(message, header: 'Custom check') raise StrongMigrations::UnsafeMigration, "\n=== #{header} #strong_migrations ===\n\n#{message}\n" end diff --git a/lib/strong_migrations/table_definition.rb b/lib/strong_migrations/table_definition.rb new file mode 100644 index 00000000..fab6375d --- /dev/null +++ b/lib/strong_migrations/table_definition.rb @@ -0,0 +1,13 @@ +module StrongMigrations + class TableDefinition + def initialize(ar_table_definition) + @ar_table_definition = ar_table_definition + end + + def method_missing(method, *args, **kwargs) + p "StrongMigrations::TableDefinition method_missing called with method: #{method} and args: #{args} and kwargs: #{kwargs}" + + @ar_table_definition.send(method, *args, **kwargs) + end + end +end From d3f3f8dd5243b8045814321cf612a7db2f925b6d Mon Sep 17 00:00:00 2001 From: cthomas-figma Date: Fri, 26 Apr 2024 12:10:54 -0400 Subject: [PATCH 2/2] feature works. probably need to refactor and figure out which linter project uses --- lib/strong_migrations.rb | 8 +- lib/strong_migrations/checker.rb | 41 +-- lib/strong_migrations/checks.rb | 388 +++++++++++----------- lib/strong_migrations/migration.rb | 3 +- lib/strong_migrations/table_definition.rb | 18 +- test/create_table_test.rb | 15 + test/migrations/create_table.rb | 35 ++ test/test_helper.rb | 79 ++--- 8 files changed, 329 insertions(+), 258 deletions(-) create mode 100644 test/create_table_test.rb create mode 100644 test/migrations/create_table.rb diff --git a/lib/strong_migrations.rb b/lib/strong_migrations.rb index f0957e4b..7ea62f4c 100644 --- a/lib/strong_migrations.rb +++ b/lib/strong_migrations.rb @@ -26,7 +26,7 @@ class UnsafeMigration < Error; end class UnsupportedVersion < Error; end class << self - attr_accessor :auto_analyze, :start_after, :checks, :error_messages, + attr_accessor :auto_analyze, :start_after, :checks, :table_checks, :error_messages, :target_postgresql_version, :target_mysql_version, :target_mariadb_version, :enabled_checks, :lock_timeout, :statement_timeout, :check_down, :target_version, :safe_by_default, :target_sql_mode, :lock_timeout_retries, :lock_timeout_retry_delay, @@ -38,6 +38,7 @@ class << self self.lock_timeout_retries = 0 self.lock_timeout_retry_delay = 10 # seconds self.checks = [] + self.table_checks = [] self.safe_by_default = false self.check_down = false self.alphabetize_schema = false @@ -64,6 +65,10 @@ def self.lock_timeout_limit @lock_timeout_limit end + def self.add_table_check(&block) + table_checks << block + end + def self.add_check(&block) checks << block end @@ -92,7 +97,6 @@ def self.check_enabled?(check, version: nil) ActiveSupport.on_load(:active_record) do ActiveRecord::Migration.prepend(StrongMigrations::Migration) ActiveRecord::Migrator.prepend(StrongMigrations::Migrator) - # ActiveRecord::ConnectionAdapters::TableDefinition.prepend(StrongMigrations::TableDefinition) if defined?(ActiveRecord::Tasks::DatabaseTasks) ActiveRecord::Tasks::DatabaseTasks.singleton_class.prepend(StrongMigrations::DatabaseTasks) diff --git a/lib/strong_migrations/checker.rb b/lib/strong_migrations/checker.rb index ac9ccea7..9ac06f19 100644 --- a/lib/strong_migrations/checker.rb +++ b/lib/strong_migrations/checker.rb @@ -27,13 +27,13 @@ def self.safety_assured end end - def perform(method, *args) + def perform(method, *args, &block) check_version_supported set_timeouts check_lock_timeout if !safe? || safe_by_default_method?(method) - # TODO better pattern + # TODO: better pattern # see checks.rb for methods case method when :add_check_constraint @@ -83,7 +83,7 @@ def perform(method, *args) @committed = true end - if !safe? + unless safe? # custom checks StrongMigrations.checks.each do |check| @migration.instance_exec(method, args, &check) @@ -93,18 +93,16 @@ def perform(method, *args) result = if retry_lock_timeouts?(method) - # TODO figure out how to handle methods that generate multiple statements + # TODO: figure out how to handle methods that generate multiple statements # like add_reference(table, ref, index: {algorithm: :concurrently}) # lock timeout after first statement will cause retry to fail - retry_lock_timeouts { yield } + retry_lock_timeouts(&block) else yield end # outdated statistics + a new index can hurt performance of existing queries - if StrongMigrations.auto_analyze && direction == :up && method == :add_index - adapter.analyze_table(args[0]) - end + adapter.analyze_table(args[0]) if StrongMigrations.auto_analyze && direction == :up && method == :add_index result end @@ -129,6 +127,10 @@ def version_safe? version && version <= StrongMigrations.start_after end + def safe? + self.class.safe || ENV['SAFETY_ASSURED'] || (direction == :down && !StrongMigrations.check_down) || version_safe? || @migration.reverting? + end + private def check_version_supported @@ -138,7 +140,8 @@ def check_version_supported if min_version version = adapter.server_version if version < Gem::Version.new(min_version) - raise UnsupportedVersion, "#{adapter.name} version (#{version}) not supported in this version of Strong Migrations (#{StrongMigrations::VERSION})" + raise UnsupportedVersion, + "#{adapter.name} version (#{version}) not supported in this version of Strong Migrations (#{StrongMigrations::VERSION})" end end @@ -148,12 +151,8 @@ def check_version_supported def set_timeouts return if @timeouts_set - if StrongMigrations.statement_timeout - adapter.set_statement_timeout(StrongMigrations.statement_timeout) - end - if StrongMigrations.lock_timeout - adapter.set_lock_timeout(StrongMigrations.lock_timeout) - end + adapter.set_statement_timeout(StrongMigrations.statement_timeout) if StrongMigrations.statement_timeout + adapter.set_lock_timeout(StrongMigrations.lock_timeout) if StrongMigrations.lock_timeout @timeouts_set = true end @@ -161,17 +160,11 @@ def set_timeouts def check_lock_timeout return if defined?(@lock_timeout_checked) - if StrongMigrations.lock_timeout_limit - adapter.check_lock_timeout(StrongMigrations.lock_timeout_limit) - end + adapter.check_lock_timeout(StrongMigrations.lock_timeout_limit) if StrongMigrations.lock_timeout_limit @lock_timeout_checked = true end - def safe? - self.class.safe || ENV["SAFETY_ASSURED"] || (direction == :down && !StrongMigrations.check_down) || version_safe? || @migration.reverting? - end - def version @migration.version end @@ -201,11 +194,9 @@ def connection end def retry_lock_timeouts?(method) - ( - StrongMigrations.lock_timeout_retries > 0 && + StrongMigrations.lock_timeout_retries > 0 && !in_transaction? && method != :transaction - ) end end end diff --git a/lib/strong_migrations/checks.rb b/lib/strong_migrations/checks.rb index b0d3ee3d..0e631c40 100644 --- a/lib/strong_migrations/checks.rb +++ b/lib/strong_migrations/checks.rb @@ -1,4 +1,4 @@ -# TODO better pattern +# TODO: better pattern module StrongMigrations module Checks private @@ -7,23 +7,24 @@ def check_add_check_constraint(*args) options = args.extract_options! table, expression = args - if !new_table?(table) - if postgresql? && options[:validate] != false - add_options = options.merge(validate: false) - name = options[:name] || @migration.check_constraint_options(table, expression, options)[:name] - validate_options = {name: name} + return if new_table?(table) - if StrongMigrations.safe_by_default - safe_add_check_constraint(*args, add_options, validate_options) - throw :safe - end + if postgresql? && options[:validate] != false + add_options = options.merge(validate: false) + name = options[:name] || @migration.check_constraint_options(table, expression, options)[:name] + validate_options = { name: name } - raise_error :add_check_constraint, - add_check_constraint_code: command_str("add_check_constraint", [table, expression, add_options]), - validate_check_constraint_code: command_str("validate_check_constraint", [table, validate_options]) - elsif mysql? || mariadb? - raise_error :add_check_constraint_mysql + if StrongMigrations.safe_by_default + safe_add_check_constraint(*args, add_options, validate_options) + throw :safe end + + raise_error :add_check_constraint, + add_check_constraint_code: command_str('add_check_constraint', [table, expression, add_options]), + validate_check_constraint_code: command_str('validate_check_constraint', + [table, validate_options]) + elsif mysql? || mariadb? + raise_error :add_check_constraint_mysql end end @@ -39,7 +40,7 @@ def check_add_column(*args) # # Also, Active Record has special case for uuid columns that allows function default values # https://github.com/rails/rails/blob/v7.0.3.1/activerecord/lib/active_record/connection_adapters/postgresql/quoting.rb#L92-L93 - if options.key?(:default) && (!adapter.add_column_default_safe? || (volatile = (postgresql? && type.to_s == "uuid" && default.to_s.include?("()") && adapter.default_volatile?(default)))) + if options.key?(:default) && (!adapter.add_column_default_safe? || (volatile = postgresql? && type.to_s == 'uuid' && default.to_s.include?('()') && adapter.default_volatile?(default))) if options[:null] == false options = options.except(:null) append = "\n\nThen add the NOT NULL constraint in separate migrations." @@ -47,18 +48,18 @@ def check_add_column(*args) if default.nil? raise_error :add_column_default_null, - command: command_str("add_column", [table, column, type, options.except(:default)]), - append: append, - rewrite_blocks: adapter.rewrite_blocks + command: command_str('add_column', [table, column, type, options.except(:default)]), + append: append, + rewrite_blocks: adapter.rewrite_blocks else raise_error :add_column_default, - add_command: command_str("add_column", [table, column, type, options.except(:default)]), - change_command: command_str("change_column_default", [table, column, default]), - remove_command: command_str("remove_column", [table, column]), - code: backfill_code(table, column, default, volatile), - append: append, - rewrite_blocks: adapter.rewrite_blocks, - default_type: (volatile ? "volatile" : "non-null") + add_command: command_str('add_column', [table, column, type, options.except(:default)]), + change_command: command_str('change_column_default', [table, column, default]), + remove_command: command_str('remove_column', [table, column]), + code: backfill_code(table, column, default, volatile), + append: append, + rewrite_blocks: adapter.rewrite_blocks, + default_type: (volatile ? 'volatile' : 'non-null') end elsif default.is_a?(Proc) && postgresql? # adding a column with a VOLATILE default is not safe @@ -68,29 +69,29 @@ def check_add_column(*args) raise_error :add_column_default_callable end - if type.to_s == "json" && postgresql? + if type.to_s == 'json' && postgresql? raise_error :add_column_json, - command: command_str("add_column", [table, column, :jsonb, options]) + command: command_str('add_column', [table, column, :jsonb, options]) end - if type.to_s == "virtual" && options[:stored] + if type.to_s == 'virtual' && options[:stored] raise_error :add_column_generated_stored, rewrite_blocks: adapter.rewrite_blocks end - if adapter.auto_incrementing_types.include?(type.to_s) - append = (mysql? || mariadb?) ? "\n\nIf using statement-based replication, this can also generate different values on replicas." : "" - raise_error :add_column_auto_incrementing, - rewrite_blocks: adapter.rewrite_blocks, - append: append - end + return unless adapter.auto_incrementing_types.include?(type.to_s) + + append = mysql? || mariadb? ? "\n\nIf using statement-based replication, this can also generate different values on replicas." : '' + raise_error :add_column_auto_incrementing, + rewrite_blocks: adapter.rewrite_blocks, + append: append end def check_add_exclusion_constraint(*args) table = args[0] - unless new_table?(table) - raise_error :add_exclusion_constraint - end + return if new_table?(table) + + raise_error :add_exclusion_constraint end # unlike add_index, we don't make an exception here for new tables @@ -110,16 +111,17 @@ def check_add_foreign_key(*args) from_table, to_table = args validate = options.fetch(:validate, true) - if postgresql? && validate - if StrongMigrations.safe_by_default - safe_add_foreign_key(*args, **options) - throw :safe - end + return unless postgresql? && validate - raise_error :add_foreign_key, - add_foreign_key_code: command_str("add_foreign_key", [from_table, to_table, options.merge(validate: false)]), - validate_foreign_key_code: command_str("validate_foreign_key", [from_table, to_table]) + if StrongMigrations.safe_by_default + safe_add_foreign_key(*args, **options) + throw :safe end + + raise_error :add_foreign_key, + add_foreign_key_code: command_str('add_foreign_key', + [from_table, to_table, options.merge(validate: false)]), + validate_foreign_key_code: command_str('validate_foreign_key', [from_table, to_table]) end def check_add_index(*args) @@ -127,7 +129,7 @@ def check_add_index(*args) table, columns = args if columns.is_a?(Array) && columns.size > 3 && !options[:unique] - raise_error :add_index_columns, header: "Best practice" + raise_error :add_index_columns, header: 'Best practice' end # safe_by_default goes through this path as well @@ -137,50 +139,51 @@ def check_add_index(*args) # safe to add non-concurrently to new tables (even after inserting data) # since the table won't be in use by the application - if postgresql? && options[:algorithm] != :concurrently && !new_table?(table) - if StrongMigrations.safe_by_default - safe_add_index(*args, **options) - throw :safe - end + return unless postgresql? && options[:algorithm] != :concurrently && !new_table?(table) - raise_error :add_index, command: command_str("add_index", [table, columns, options.merge(algorithm: :concurrently)]) + if StrongMigrations.safe_by_default + safe_add_index(*args, **options) + throw :safe end + + raise_error :add_index, + command: command_str('add_index', [table, columns, options.merge(algorithm: :concurrently)]) end def check_add_reference(method, *args) options = args.extract_options! table, reference = args - if postgresql? - index_value = options.fetch(:index, true) - concurrently_set = index_value.is_a?(Hash) && index_value[:algorithm] == :concurrently - bad_index = index_value && !concurrently_set - - if bad_index || options[:foreign_key] - if index_value.is_a?(Hash) - options[:index] = options[:index].merge(algorithm: :concurrently) - elsif index_value - options = options.merge(index: {algorithm: :concurrently}) - end + return unless postgresql? - if StrongMigrations.safe_by_default - safe_add_reference(*args, **options) - throw :safe - end + index_value = options.fetch(:index, true) + concurrently_set = index_value.is_a?(Hash) && index_value[:algorithm] == :concurrently + bad_index = index_value && !concurrently_set - if options.delete(:foreign_key) - headline = "Adding a foreign key blocks writes on both tables." - append = "\n\nThen add the foreign key in separate migrations." - else - headline = "Adding an index non-concurrently locks the table." - end + return unless bad_index || options[:foreign_key] - raise_error :add_reference, - headline: headline, - command: command_str(method, [table, reference, options]), - append: append - end + if index_value.is_a?(Hash) + options[:index] = options[:index].merge(algorithm: :concurrently) + elsif index_value + options = options.merge(index: { algorithm: :concurrently }) + end + + if StrongMigrations.safe_by_default + safe_add_reference(*args, **options) + throw :safe end + + if options.delete(:foreign_key) + headline = 'Adding a foreign key blocks writes on both tables.' + append = "\n\nThen add the foreign key in separate migrations." + else + headline = 'Adding an index non-concurrently locks the table.' + end + + raise_error :add_reference, + headline: headline, + command: command_str(method, [table, reference, options]), + append: append end def check_add_unique_constraint(*args) @@ -189,13 +192,13 @@ def check_add_unique_constraint(*args) # column and using_index cannot be used together # check for column to ensure error message can be generated - if column && !new_table?(table) - index_name = connection.index_name(table, {column: column}) - raise_error :add_unique_constraint, - index_command: command_str(:add_index, [table, column, {unique: true, algorithm: :concurrently}]), - constraint_command: command_str(:add_unique_constraint, [table, {using_index: index_name}]), - remove_command: command_str(:remove_unique_constraint, [table, column]) - end + return unless column && !new_table?(table) + + index_name = connection.index_name(table, { column: column }) + raise_error :add_unique_constraint, + index_command: command_str(:add_index, [table, column, { unique: true, algorithm: :concurrently }]), + constraint_command: command_str(:add_unique_constraint, [table, { using_index: index_name }]), + remove_command: command_str(:remove_unique_constraint, [table, column]) end def check_change_column(*args) @@ -203,18 +206,20 @@ def check_change_column(*args) table, column, type = args safe = false - table_columns = connection.columns(table) rescue [] + table_columns = begin + connection.columns(table) + rescue StandardError + [] + end existing_column = table_columns.find { |c| c.name.to_s == column.to_s } if existing_column - existing_type = existing_column.sql_type.sub(/\(\d+(,\d+)?\)/, "") + existing_type = existing_column.sql_type.sub(/\(\d+(,\d+)?\)/, '') safe = adapter.change_type_safe?(table, column, type, options, existing_column, existing_type) end # unsafe to set NOT NULL for safe types with Postgres # TODO check if safe for MySQL and MariaDB - if safe && existing_column.null && options[:null] == false - raise_error :change_column_with_not_null - end + raise_error :change_column_with_not_null if safe && existing_column.null && options[:null] == false raise_error :change_column, rewrite_blocks: adapter.rewrite_blocks unless safe end @@ -225,90 +230,94 @@ def check_change_column_default(*args) # just check ActiveRecord::Base, even though can override on model partial_inserts = ar_version >= 7 ? ActiveRecord::Base.partial_inserts : ActiveRecord::Base.partial_writes - if partial_inserts && !new_column?(table, column) - raise_error :change_column_default, - config: ar_version >= 7 ? "partial_inserts" : "partial_writes" - end + return unless partial_inserts && !new_column?(table, column) + + raise_error :change_column_default, + config: ar_version >= 7 ? 'partial_inserts' : 'partial_writes' end def check_change_column_null(*args) table, column, null, default = args - if !null - if postgresql? - safe = false - safe_with_check_constraint = adapter.server_version >= Gem::Version.new("12") - if safe_with_check_constraint - safe = adapter.constraints(table).any? { |c| c["def"] == "CHECK ((#{column} IS NOT NULL))" || c["def"] == "CHECK ((#{connection.quote_column_name(column)} IS NOT NULL))" } + return if null + + if postgresql? + safe = false + safe_with_check_constraint = adapter.server_version >= Gem::Version.new('12') + if safe_with_check_constraint + safe = adapter.constraints(table).any? do |c| + c['def'] == "CHECK ((#{column} IS NOT NULL))" || c['def'] == "CHECK ((#{connection.quote_column_name(column)} IS NOT NULL))" end + end - unless safe - # match https://github.com/nullobject/rein - constraint_name = "#{table}_#{column}_null" + unless safe + # match https://github.com/nullobject/rein + constraint_name = "#{table}_#{column}_null" - add_code = constraint_str("ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s IS NOT NULL) NOT VALID", [table, constraint_name, column]) - validate_code = constraint_str("ALTER TABLE %s VALIDATE CONSTRAINT %s", [table, constraint_name]) - remove_code = constraint_str("ALTER TABLE %s DROP CONSTRAINT %s", [table, constraint_name]) + add_code = constraint_str('ALTER TABLE %s ADD CONSTRAINT %s CHECK (%s IS NOT NULL) NOT VALID', + [table, constraint_name, column]) + validate_code = constraint_str('ALTER TABLE %s VALIDATE CONSTRAINT %s', [table, constraint_name]) + remove_code = constraint_str('ALTER TABLE %s DROP CONSTRAINT %s', [table, constraint_name]) - constraint_methods = ar_version >= 6.1 + constraint_methods = ar_version >= 6.1 - validate_constraint_code = - if constraint_methods - String.new(command_str(:validate_check_constraint, [table, {name: constraint_name}])) - else - String.new(safety_assured_str(validate_code)) - end + validate_constraint_code = + if constraint_methods + String.new(command_str(:validate_check_constraint, [table, { name: constraint_name }])) + else + String.new(safety_assured_str(validate_code)) + end - if safe_with_check_constraint - change_args = [table, column, null] + if safe_with_check_constraint + change_args = [table, column, null] - validate_constraint_code << "\n #{command_str(:change_column_null, change_args)}" + validate_constraint_code << "\n #{command_str(:change_column_null, change_args)}" - if constraint_methods - validate_constraint_code << "\n #{command_str(:remove_check_constraint, [table, {name: constraint_name}])}" - else - validate_constraint_code << "\n #{safety_assured_str(remove_code)}" - end + if constraint_methods + validate_constraint_code << "\n #{command_str(:remove_check_constraint, + [table, { name: constraint_name }])}" + else + validate_constraint_code << "\n #{safety_assured_str(remove_code)}" end + end - if StrongMigrations.safe_by_default - safe_change_column_null(add_code, validate_code, change_args, remove_code, default) - throw :safe + if StrongMigrations.safe_by_default + safe_change_column_null(add_code, validate_code, change_args, remove_code, default) + throw :safe + end + + add_constraint_code = + if constraint_methods + command_str(:add_check_constraint, + [table, "#{quote_column_if_needed(column)} IS NOT NULL", + { name: constraint_name, validate: false }]) + else + safety_assured_str(add_code) end - add_constraint_code = - if constraint_methods - command_str(:add_check_constraint, [table, "#{quote_column_if_needed(column)} IS NOT NULL", {name: constraint_name, validate: false}]) - else - safety_assured_str(add_code) - end - - validate_constraint_code = - if safe_with_check_constraint - down_code = "#{add_constraint_code}\n #{command_str(:change_column_null, [table, column, true])}" - "def up\n #{validate_constraint_code}\n end\n\n def down\n #{down_code}\n end" - else - "def change\n #{validate_constraint_code}\n end" - end - - raise_error :change_column_null_postgresql, - add_constraint_code: add_constraint_code, - validate_constraint_code: validate_constraint_code - end - elsif mysql? || mariadb? - unless adapter.strict_mode? - raise_error :change_column_null_mysql - end - end + validate_constraint_code = + if safe_with_check_constraint + down_code = "#{add_constraint_code}\n #{command_str(:change_column_null, [table, column, true])}" + "def up\n #{validate_constraint_code}\n end\n\n def down\n #{down_code}\n end" + else + "def change\n #{validate_constraint_code}\n end" + end - if !default.nil? - raise_error :change_column_null, - code: backfill_code(table, column, default) + raise_error :change_column_null_postgresql, + add_constraint_code: add_constraint_code, + validate_constraint_code: validate_constraint_code end + elsif mysql? || mariadb? + raise_error :change_column_null_mysql unless adapter.strict_mode? end + + return if default.nil? + + raise_error :change_column_null, + code: backfill_code(table, column, default) end def check_change_table - raise_error :change_table, header: "Possibly dangerous operation" + raise_error :change_table, header: 'Possibly dangerous operation' end def check_create_join_table(*args) @@ -316,12 +325,12 @@ def check_create_join_table(*args) raise_error :create_table if options[:force] - # TODO keep track of new table of add_index check + # TODO: keep track of new table of add_index check end def check_create_table(*args) options = args.extract_options! - table, _ = args + table, = args raise_error :create_table if options[:force] @@ -330,14 +339,14 @@ def check_create_table(*args) end def check_execute - raise_error :execute, header: "Possibly dangerous operation" + raise_error :execute, header: 'Possibly dangerous operation' end def check_remove_column(method, *args) columns = case method when :remove_timestamps - ["created_at", "updated_at"] + %w[created_at updated_at] when :remove_column [args[1].to_s] when :remove_columns @@ -359,33 +368,34 @@ def check_remove_column(method, *args) code = "self.ignored_columns += #{columns.inspect}" raise_error :remove_column, - model: args[0].to_s.classify, - code: code, - command: command_str(method, args), - column_suffix: columns.size > 1 ? "s" : "" + model: args[0].to_s.classify, + code: code, + command: command_str(method, args), + column_suffix: columns.size > 1 ? 's' : '' end def check_remove_index(*args) options = args.extract_options! - table, _ = args + table, = args - if postgresql? && options[:algorithm] != :concurrently && !new_table?(table) - # avoid suggesting extra (invalid) args - args = args[0..1] unless StrongMigrations.safe_by_default + return unless postgresql? && options[:algorithm] != :concurrently && !new_table?(table) - # Active Record < 6.1 only supports two arguments (including options) - if args.size == 2 && ar_version < 6.1 - # arg takes precedence over option - options[:column] = args.pop - end + # avoid suggesting extra (invalid) args + args = args[0..1] unless StrongMigrations.safe_by_default - if StrongMigrations.safe_by_default - safe_remove_index(*args, **options) - throw :safe - end + # Active Record < 6.1 only supports two arguments (including options) + if args.size == 2 && ar_version < 6.1 + # arg takes precedence over option + options[:column] = args.pop + end - raise_error :remove_index, command: command_str("remove_index", args + [options.merge(algorithm: :concurrently)]) + if StrongMigrations.safe_by_default + safe_remove_index(*args, **options) + throw :safe end + + raise_error :remove_index, + command: command_str('remove_index', args + [options.merge(algorithm: :concurrently)]) end def check_rename_column @@ -397,15 +407,15 @@ def check_rename_table end def check_validate_check_constraint - if postgresql? && adapter.writes_blocked? - raise_error :validate_check_constraint - end + return unless postgresql? && adapter.writes_blocked? + + raise_error :validate_check_constraint end def check_validate_foreign_key - if postgresql? && adapter.writes_blocked? - raise_error :validate_foreign_key - end + return unless postgresql? && adapter.writes_blocked? + + raise_error :validate_foreign_key end # helpers @@ -429,16 +439,16 @@ def ar_version def raise_error(message_key, header: nil, append: nil, **vars) return unless StrongMigrations.check_enabled?(message_key, version: version) - message = StrongMigrations.error_messages[message_key] || "Missing message" - message = message + append if append + message = StrongMigrations.error_messages[message_key] || 'Missing message' + message += append if append vars[:migration_name] = @migration.class.name vars[:migration_suffix] = "[#{ActiveRecord::VERSION::MAJOR}.#{ActiveRecord::VERSION::MINOR}]" - vars[:base_model] = "ApplicationRecord" + vars[:base_model] = 'ApplicationRecord' # escape % not followed by { - message = message.gsub(/%(?!{)/, "%%") % vars if message.include?("%") - @migration.stop!(message, header: header || "Dangerous operation detected") + message = message.gsub(/%(?!{)/, '%%') % vars if message.include?('%') + @migration.stop!(message, header: header || 'Dangerous operation detected') end def constraint_str(statement, identifiers) @@ -460,17 +470,17 @@ def command_str(command, args) str_args << last_arg.map do |k, v| if v.is_a?(Hash) # pretty index: {algorithm: :concurrently} - "#{k}: {#{v.map { |k2, v2| "#{k2}: #{v2.inspect}" }.join(", ")}}" + "#{k}: {#{v.map { |k2, v2| "#{k2}: #{v2.inspect}" }.join(', ')}}" else "#{k}: #{v.inspect}" end - end.join(", ") + end.join(', ') end else str_args << last_arg.inspect end - "#{command} #{str_args.join(", ")}" + "#{command} #{str_args.join(', ')}" end def backfill_code(table, column, default, function = false) diff --git a/lib/strong_migrations/migration.rb b/lib/strong_migrations/migration.rb index e3507665..87657318 100644 --- a/lib/strong_migrations/migration.rb +++ b/lib/strong_migrations/migration.rb @@ -31,7 +31,8 @@ def revert(*) def create_table(table_name, **options, &block) if block_given? super do |t| - table_definition = StrongMigrations::TableDefinition.new(compatible_table_definition(t)) + table_definition = StrongMigrations::TableDefinition.new(compatible_table_definition(t), self, + strong_migrations_checker) yield table_definition end else diff --git a/lib/strong_migrations/table_definition.rb b/lib/strong_migrations/table_definition.rb index fab6375d..f87f444f 100644 --- a/lib/strong_migrations/table_definition.rb +++ b/lib/strong_migrations/table_definition.rb @@ -1,13 +1,27 @@ module StrongMigrations class TableDefinition - def initialize(ar_table_definition) + def initialize(ar_table_definition, migration, checker) @ar_table_definition = ar_table_definition + @migration = migration + @checker = checker end def method_missing(method, *args, **kwargs) - p "StrongMigrations::TableDefinition method_missing called with method: #{method} and args: #{args} and kwargs: #{kwargs}" + return super if is_a?(ActiveRecord::Schema) + # Active Record 7.0.2+ versioned schema + return super if defined?(ActiveRecord::Schema::Definition) && is_a?(ActiveRecord::Schema::Definition) + + unless @checker.safe? + StrongMigrations.table_checks.each do |check| + @migration.instance_exec(method, args, kwargs, &check) + end + end @ar_table_definition.send(method, *args, **kwargs) end + + def safety_assured(&block) + @checker.class.safety_assured(&block) + end end end diff --git a/test/create_table_test.rb b/test/create_table_test.rb new file mode 100644 index 00000000..c8cf19e1 --- /dev/null +++ b/test/create_table_test.rb @@ -0,0 +1,15 @@ +require_relative 'test_helper' + +class CreateTableTest < Minitest::Test + def test_create_table_with_integer + assert_unsafe CreateTableWithInteger + end + + def test_col_definition_in_safe_block + assert_safe CreateTableWithSafetyAssured + end + + def test_create_table_with_integer_column_call + assert_unsafe CreateTableWithIntegerColumnCall + end +end diff --git a/test/migrations/create_table.rb b/test/migrations/create_table.rb new file mode 100644 index 00000000..91594c96 --- /dev/null +++ b/test/migrations/create_table.rb @@ -0,0 +1,35 @@ +class CreateTableWithInteger < TestMigration + def up + create_table :test_table do |t| + t.integer :test_column + end + end + + def down + drop_table :test_table + end +end + +class CreateTableWithIntegerColumnCall < TestMigration + def up + create_table :test_table do |t| + t.column :test_column, :integer + end + end + + def down + drop_table :test_table + end +end + +class CreateTableWithSafetyAssured < TestMigration + def up + create_table :test_table do |t| + safety_assured { t.integer :test_column } + end + end + + def down + drop_table :test_table + end +end diff --git a/test/test_helper.rb b/test/test_helper.rb index 43d2bc70..77f8f3b8 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -1,36 +1,34 @@ -require "bundler/setup" +require 'bundler/setup' Bundler.require(:default) -require "minitest/autorun" -require "minitest/pride" -require "active_record" +require 'minitest/autorun' +require 'minitest/pride' +require 'active_record' # needed for target_version module Rails def self.env - ActiveSupport::StringInquirer.new("test") + ActiveSupport::StringInquirer.new('test') end end -$adapter = ENV["ADAPTER"] || "postgresql" +$adapter = ENV['ADAPTER'] || 'postgresql' connection_options = { adapter: $adapter, - database: "strong_migrations_test" + database: 'strong_migrations_test' } -if $adapter == "mysql2" - connection_options[:encoding] = "utf8mb4" - if ActiveRecord::VERSION::STRING.to_f >= 7.1 - connection_options[:prepared_statements] = true - end -elsif $adapter == "trilogy" +if $adapter == 'mysql2' + connection_options[:encoding] = 'utf8mb4' + connection_options[:prepared_statements] = true if ActiveRecord::VERSION::STRING.to_f >= 7.1 +elsif $adapter == 'trilogy' if ActiveRecord::VERSION::STRING.to_f < 7.1 - require "trilogy_adapter/connection" - ActiveRecord::Base.public_send :extend, TrilogyAdapter::Connection + require 'trilogy_adapter/connection' + ActiveRecord::Base.extend TrilogyAdapter::Connection end - connection_options[:host] = "127.0.0.1" + connection_options[:host] = '127.0.0.1' end ActiveRecord::Base.establish_connection(**connection_options) -if ENV["VERBOSE"] +if ENV['VERBOSE'] ActiveRecord::Base.logger = ActiveSupport::Logger.new($stdout) else ActiveRecord::Migration.verbose = false @@ -54,18 +52,18 @@ def schema_migration schema_migration.create_table ActiveRecord::Schema.define do - if $adapter == "postgresql" + if $adapter == 'postgresql' # for change column - enable_extension "citext" + enable_extension 'citext' # for exclusion constraints - enable_extension "btree_gist" + enable_extension 'btree_gist' # for gen_random_uuid() in Postgres < 13 - enable_extension "pgcrypto" + enable_extension 'pgcrypto' end - [:users, :new_users, :orders, :devices, :cities_users].each do |table| + %i[users new_users orders devices cities_users test_table].each do |table| drop_table(table) if table_exists?(table) end @@ -77,7 +75,7 @@ def schema_migration t.string :country, limit: 20 t.string :interval t.text :description - t.citext :code if $adapter == "postgresql" + t.citext :code if $adapter == 'postgresql' t.references :order end @@ -93,15 +91,15 @@ class User < ActiveRecord::Base module Helpers def postgresql? - $adapter == "postgresql" + $adapter == 'postgresql' end def mysql? - ($adapter == "mysql2" || $adapter == "trilogy") && !ActiveRecord::Base.connection.mariadb? + ($adapter == 'mysql2' || $adapter == 'trilogy') && !ActiveRecord::Base.connection.mariadb? end def mariadb? - ($adapter == "mysql2" || $adapter == "trilogy") && ActiveRecord::Base.connection.mariadb? + ($adapter == 'mysql2' || $adapter == 'trilogy') && ActiveRecord::Base.connection.mariadb? end end @@ -133,8 +131,9 @@ def migrate(migration, direction: :up, version: 123) end ActiveRecord::Migrator.new(direction, [migration], *args).migrate true - rescue => e + rescue StandardError => e raise e.cause if e.cause + raise e end @@ -142,7 +141,7 @@ def assert_unsafe(migration, message = nil, **options) error = assert_raises(StrongMigrations::UnsafeMigration) do migrate(migration, **options) end - puts error.message if ENV["VERBOSE"] + puts error.message if ENV['VERBOSE'] assert_match message, error.message if message end @@ -174,16 +173,12 @@ def with_target_version(version) StrongMigrations.target_version = nil end - def with_safety_assured - StrongMigrations::Checker.stub(:safe, true) do - yield - end + def with_safety_assured(&block) + StrongMigrations::Checker.stub(:safe, true, &block) end - def outside_developer_env - StrongMigrations.stub(:developer_env?, false) do - yield - end + def outside_developer_env(&block) + StrongMigrations.stub(:developer_env?, false, &block) end def check_constraints? @@ -192,11 +187,17 @@ def check_constraints? end StrongMigrations.add_check do |method, args| - if method == :add_column && args[1].to_s == "forbidden" - stop! "Cannot add forbidden column" + stop! 'Cannot add forbidden column' if method == :add_column && args[1].to_s == 'forbidden' +end + +StrongMigrations.add_table_check do |method, args, _kwargs| + stop!('Use bigint instead to avoid overflow') if method == :integer + + if method == :column && (args[1].to_sym == :integer || args[1].to_sym == :int) + stop!('Use bigint instead to avoid overflow') end end -Dir.glob("migrations/*.rb", base: __dir__).sort.each do |file| +Dir.glob('migrations/*.rb', base: __dir__).sort.each do |file| require_relative file end