diff --git a/crud/compare/conditions.lua b/crud/compare/conditions.lua index 25faa6d8..3c6e705e 100644 --- a/crud/compare/conditions.lua +++ b/crud/compare/conditions.lua @@ -13,6 +13,11 @@ conditions.operators = { LE = '<=', GT = '>', GE = '>=', + NO = '@<>', + BTW = '@>><<', + OUT = '@<<>>', + OUTN = '@<>', + INC = '@>)', } local tarantool_iter_by_cond_operators = { @@ -54,6 +59,11 @@ local cond_operators_by_func_names = { le = conditions.operators.LE, gt = conditions.operators.GT, ge = conditions.operators.GE, + no = conditions.operators.NO, + between = conditions.operators.BTW, + outer = conditions.operators.OUT, + outer_or_null = conditions.operators.OUTN, + include = conditions.operators.INC, } for func_name, operator in pairs(cond_operators_by_func_names) do @@ -75,6 +85,11 @@ local funcs_by_symbols = { ['<='] = conditions.funcs.le, ['>'] = conditions.funcs.gt, ['>='] = conditions.funcs.ge, + ['<>'] = conditions.funcs.no, + ['>><<'] = conditions.funcs.between, + ['<<>>'] = conditions.funcs.outer, + ['<>'] = conditions.funcs.outer_or_null, + ['>)'] = conditions.funcs.include } function conditions.parse(user_conditions) diff --git a/crud/compare/filters.lua b/crud/compare/filters.lua index 145e0190..4eff0b58 100644 --- a/crud/compare/filters.lua +++ b/crud/compare/filters.lua @@ -204,16 +204,22 @@ end local function gen_tuple_fields_def_code(filter_conditions) -- get field names - local fields_added = {} local fields = {} for _, cond in ipairs(filter_conditions) do - for i = 1, #cond.values do - local field = cond.fields[i] + --no matter the number of fields coincide number of values, the fields array contains only uniq field names + if #cond.fields == 1 then + local field = cond.fields[1] + table.insert(fields, field) + else + local fields_added = {} + for i = 1, #cond.fields do + local field = cond.fields[i] - if not fields_added[field] then - table.insert(fields, field) - fields_added[field] = true + if not fields_added[field] then + table.insert(fields, field) + fields_added[field] = true + end end end end @@ -264,25 +270,62 @@ local function add_collation_postfix(func_name, value_opts) error('Unsupported collation: ' .. tostring(value_opts.collation)) end +local function cond_string(cond, values_opts, field, j) + local value = cond.values[j] + local value_type = cond.types[j] + local value_opts = values_opts[j] or {} + + local func_name = 'eq' + + if value_type == 'string' then + func_name = add_collation_postfix('eq', value_opts) + if collations.is_unicode(value_opts.collation) then + func_name = add_strict_postfix(func_name, value_opts) + end + elseif value_type == 'uuid' then + func_name = 'eq_uuid' + end + + return format_comp_with_value(field, func_name, value) +end + local function format_eq(cond) local cond_strings = {} local values_opts = cond.values_opts or {} + if #cond.values == 1 then + table.insert(cond_strings, cond_string(cond, values_opts, cond.fields[1], 1)) + else + for j = 1, #cond.values do + if #cond.fields == 1 then + table.insert(cond_strings, cond_string(cond, values_opts, cond.fields[1], j)) + else + table.insert(cond_strings, cond_string(cond, values_opts, cond.fields[j], j)) + end + end + end + return cond_strings +end + +local function format_not_eq(cond) + local cond_strings = {} + local values_opts = cond.values_opts or {} + for j = 1, #cond.values do local field = cond.fields[j] local value = cond.values[j] local value_type = cond.types[j] local value_opts = values_opts[j] or {} - local func_name = 'eq' + local func_name = 'not eq' if value_type == 'string' then - func_name = add_collation_postfix('eq', value_opts) + func_name = add_collation_postfix('not eq', value_opts) if collations.is_unicode(value_opts.collation) then func_name = add_strict_postfix(func_name, value_opts) end elseif value_type == 'uuid' then - func_name = 'eq_uuid' + func_name = 'not eq_uuid' end table.insert(cond_strings, format_comp_with_value(field, func_name, value)) @@ -327,6 +370,24 @@ local function gen_eq_func_code(func_name, cond, func_args_code) local header = LIB_FUNC_HEADER_TEMPLATE:format(func_name, func_args_code) table.insert(func_code_lines, header) + local return_line = string.format( + ' return %s', concat_conditions(eq_conds, 'or') + ) + table.insert(func_code_lines, return_line) + + table.insert(func_code_lines, 'end') + + return table.concat(func_code_lines, '\n') +end + +local function gen_not_eq_func_code(func_name, cond, func_args_code) + local func_code_lines = {} + + local eq_conds = format_not_eq(cond) + + local header = LIB_FUNC_HEADER_TEMPLATE:format(func_name, func_args_code) + table.insert(func_code_lines, header) + local return_line = string.format( ' return %s', concat_conditions(eq_conds, 'and') ) @@ -369,30 +430,53 @@ local results_by_operators = { local function gen_cmp_array_func_code(operator, func_name, cond, func_args_code) local func_code_lines = {} - local eq_conds = format_eq(cond) - local lt_conds = format_lt(cond) - local header = LIB_FUNC_HEADER_TEMPLATE:format(func_name, func_args_code) table.insert(func_code_lines, header) - assert(#lt_conds == #eq_conds) + if string.sub(operator, 1, 1) == '@' then + if operator == '@<>' then + return gen_not_eq_func_code(func_name, cond, func_args_code) + end - local results = results_by_operators[operator] - assert(results ~= nil) + local return_line = '' + local field = get_field_variable_name(cond.fields[1]) + if operator == '@>><<' then + return_line = string.format('if lt_strict(%s, %s) and lt(%s, %s) then return true end', cond.values[1], field, field, cond.values[2]) + elseif operator == '@<<>>' then + return_line = string.format('if lt(%s, %s) or lt_strict(%s, %s) then return true end', field, cond.values[1], cond.values[2], field) + elseif operator == '@<>' then + return_line = string.format('if lt(%s, %s) or lt_strict(%s, %s) or eq(%s, NULL) then return true end', field, cond.values[1], cond.values[2], field, field) + elseif operator == '@>)' then + return_line = string.format('if lt(%s, %s) or lt_strict(%s, %s) or eq(%s, NULL) then return true end', field, cond.values[1], cond.values[2], field, field) + end - for i = 1, #eq_conds do - local comp_value_code = table.concat({ - string.format(' if %s then return %s end', lt_conds[i], results.le), - string.format(' if not %s then return %s end', eq_conds[i], results.not_eq), - '', - }, '\n') + table.insert(func_code_lines, return_line) + table.insert(func_code_lines, ' return false ') + else - table.insert(func_code_lines, comp_value_code) - end + local eq_conds = format_eq(cond) + local lt_conds = format_lt(cond) + + --log.error(func_name) - local return_code = (' return %s'):format(results.default) - table.insert(func_code_lines, return_code) + assert(#lt_conds == #eq_conds) + local results = results_by_operators[operator] + assert(results ~= nil) + + for i = 1, #eq_conds do + local comp_value_code = table.concat({ + string.format(' if %s then return %s end', lt_conds[i], results.le), + string.format(' if not %s then return %s end', eq_conds[i], results.not_eq), + '', + }, '\n') + + table.insert(func_code_lines, comp_value_code) + end + + local return_code = (' return %s'):format(results.default) + table.insert(func_code_lines, return_code) + end table.insert(func_code_lines, 'end') return table.concat(func_code_lines, '\n') @@ -469,6 +553,15 @@ local function gen_filter_code(filter_conditions) } end +local function gt_nullable(lhs, rhs) + if lhs == nil and rhs ~= nil then + return true + elseif rhs == nil then + return false + end + return lhs < rhs +end + local function lt_nullable(lhs, rhs) if lhs == nil and rhs ~= nil then return true