Skip to content

Commit

Permalink
[MySQL] Add unsigned floating point types + Fix unsigned integer type…
Browse files Browse the repository at this point in the history
… bugs in Kit (drizzle-team#3284)

* Fix bugs with MySQL introspection tests

* Update float data type in MySQL

* Better support for float types in MySQL

* Handle existing unsigned numerical types in MySQL

* Add unsigned to floating point types in MySQL

* Handle unsigned floating point types in MySQL

* Update decimal data type

---------

Co-authored-by: Andrii Sherman <andreysherman11@gmail.com>
  • Loading branch information
L-Mario564 and AndriiSherman authored Nov 3, 2024
1 parent d43eee3 commit 19f042a
Show file tree
Hide file tree
Showing 9 changed files with 206 additions and 56 deletions.
71 changes: 55 additions & 16 deletions drizzle-kit/src/introspect-mysql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ export const schemaToTypeScript = (
patched = patched.startsWith('varbinary(') ? 'varbinary' : patched;
patched = patched.startsWith('int(') ? 'int' : patched;
patched = patched.startsWith('double(') ? 'double' : patched;
patched = patched.startsWith('float(') ? 'float' : patched;
patched = patched.startsWith('int unsigned') ? 'int' : patched;
patched = patched.startsWith('tinyint unsigned') ? 'tinyint' : patched;
patched = patched.startsWith('smallint unsigned') ? 'smallint' : patched;
patched = patched.startsWith('mediumint unsigned') ? 'mediumint' : patched;
patched = patched.startsWith('bigint unsigned') ? 'bigint' : patched;
return patched;
})
.filter((type) => {
Expand Down Expand Up @@ -207,6 +213,12 @@ export const schemaToTypeScript = (
patched = patched.startsWith('varbinary(') ? 'varbinary' : patched;
patched = patched.startsWith('int(') ? 'int' : patched;
patched = patched.startsWith('double(') ? 'double' : patched;
patched = patched.startsWith('float(') ? 'float' : patched;
patched = patched.startsWith('int unsigned') ? 'int' : patched;
patched = patched.startsWith('tinyint unsigned') ? 'tinyint' : patched;
patched = patched.startsWith('smallint unsigned') ? 'smallint' : patched;
patched = patched.startsWith('mediumint unsigned') ? 'mediumint' : patched;
patched = patched.startsWith('bigint unsigned') ? 'bigint' : patched;
return patched;
})
.filter((type) => {
Expand Down Expand Up @@ -397,8 +409,9 @@ const column = (

if (lowered.startsWith('int')) {
const isUnsigned = lowered.startsWith('int unsigned');
let out = `${casing(name)}: int(${dbColumnName({ name, casing: rawCasing, withMode: isUnsigned })}${
isUnsigned ? '{ unsigned: true }' : ''
const columnName = dbColumnName({ name, casing: rawCasing, withMode: isUnsigned });
let out = `${casing(name)}: int(${columnName}${
isUnsigned ? `${columnName.length > 0 ? ', ' : ''}{ unsigned: true }` : ''
})`;
out += autoincrement ? `.autoincrement()` : '';
out += typeof defaultValue !== 'undefined'
Expand All @@ -409,9 +422,10 @@ const column = (

if (lowered.startsWith('tinyint')) {
const isUnsigned = lowered.startsWith('tinyint unsigned');
const columnName = dbColumnName({ name, casing: rawCasing, withMode: isUnsigned });
// let out = `${name.camelCase()}: tinyint("${name}")`;
let out: string = `${casing(name)}: tinyint(${dbColumnName({ name, casing: rawCasing, withMode: isUnsigned })}${
isUnsigned ? ', { unsigned: true }' : ''
let out: string = `${casing(name)}: tinyint(${columnName}${
isUnsigned ? `${columnName.length > 0 ? ', ' : ''}{ unsigned: true }` : ''
})`;
out += autoincrement ? `.autoincrement()` : '';
out += typeof defaultValue !== 'undefined'
Expand All @@ -422,8 +436,9 @@ const column = (

if (lowered.startsWith('smallint')) {
const isUnsigned = lowered.startsWith('smallint unsigned');
let out = `${casing(name)}: smallint(${dbColumnName({ name, casing: rawCasing, withMode: isUnsigned })}${
isUnsigned ? ', { unsigned: true }' : ''
const columnName = dbColumnName({ name, casing: rawCasing, withMode: isUnsigned });
let out = `${casing(name)}: smallint(${columnName}${
isUnsigned ? `${columnName.length > 0 ? ', ' : ''}{ unsigned: true }` : ''
})`;
out += autoincrement ? `.autoincrement()` : '';
out += defaultValue
Expand All @@ -434,8 +449,9 @@ const column = (

if (lowered.startsWith('mediumint')) {
const isUnsigned = lowered.startsWith('mediumint unsigned');
let out = `${casing(name)}: mediumint(${dbColumnName({ name, casing: rawCasing, withMode: isUnsigned })}${
isUnsigned ? ', { unsigned: true }' : ''
const columnName = dbColumnName({ name, casing: rawCasing, withMode: isUnsigned });
let out = `${casing(name)}: mediumint(${columnName}${
isUnsigned ? `${columnName.length > 0 ? ', ' : ''}{ unsigned: true }` : ''
})`;
out += autoincrement ? `.autoincrement()` : '';
out += defaultValue
Expand Down Expand Up @@ -466,16 +482,20 @@ const column = (

if (lowered.startsWith('double')) {
let params:
| { precision: string | undefined; scale: string | undefined }
| { precision?: string; scale?: string; unsigned?: boolean }
| undefined;

if (lowered.length > 6) {
if (lowered.length > (lowered.includes('unsigned') ? 15 : 6)) {
const [precision, scale] = lowered
.slice(7, lowered.length - 1)
.slice(7, lowered.length - (1 + (lowered.includes('unsigned') ? 9 : 0)))
.split(',');
params = { precision, scale };
}

if (lowered.includes('unsigned')) {
params = { ...(params ?? {}), unsigned: true };
}

const timeConfigParams = params ? timeConfig(params) : undefined;

let out = params
Expand All @@ -491,8 +511,23 @@ const column = (
return out;
}

if (lowered === 'float') {
let out = `${casing(name)}: float(${dbColumnName({ name, casing: rawCasing })})`;
if (lowered.startsWith('float')) {
let params:
| { precision?: string; scale?: string; unsigned?: boolean }
| undefined;

if (lowered.length > (lowered.includes('unsigned') ? 14 : 5)) {
const [precision, scale] = lowered
.slice(6, lowered.length - (1 + (lowered.includes('unsigned') ? 9 : 0)))
.split(',');
params = { precision, scale };
}

if (lowered.includes('unsigned')) {
params = { ...(params ?? {}), unsigned: true };
}

let out = `${casing(name)}: float(${dbColumnName({ name, casing: rawCasing })}${params ? timeConfig(params) : ''})`;
out += defaultValue
? `.default(${mapColumnDefault(defaultValue, isExpression)})`
: '';
Expand Down Expand Up @@ -700,16 +735,20 @@ const column = (

if (lowered.startsWith('decimal')) {
let params:
| { precision: string | undefined; scale: string | undefined }
| { precision?: string; scale?: string; unsigned?: boolean }
| undefined;

if (lowered.length > 7) {
if (lowered.length > (lowered.includes('unsigned') ? 16 : 7)) {
const [precision, scale] = lowered
.slice(8, lowered.length - 1)
.slice(8, lowered.length - (1 + (lowered.includes('unsigned') ? 9 : 0)))
.split(',');
params = { precision, scale };
}

if (lowered.includes('unsigned')) {
params = { ...(params ?? {}), unsigned: true };
}

const timeConfigParams = params ? timeConfig(params) : undefined;

let out = params
Expand Down
4 changes: 2 additions & 2 deletions drizzle-kit/src/serializer/mysqlSerializer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -651,8 +651,8 @@ export const fromDatabase = async (
}
}

if (columnType.startsWith('tinyint')) {
changedType = 'tinyint';
if (columnType.includes('decimal(10,0)')) {
changedType = columnType.replace('decimal(10,0)', 'decimal');
}

let onUpdate: boolean | undefined = undefined;
Expand Down
90 changes: 72 additions & 18 deletions drizzle-kit/tests/introspect/mysql.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
import 'dotenv/config';
import Docker from 'dockerode';
import { SQL, sql } from 'drizzle-orm';
import { char, check, int, mysqlTable, mysqlView, serial, text, varchar } from 'drizzle-orm/mysql-core';
import {
bigint,
char,
check,
decimal,
double,
float,
int,
mediumint,
mysqlTable,
mysqlView,
serial,
smallint,
text,
tinyint,
varchar,
} from 'drizzle-orm/mysql-core';
import * as fs from 'fs';
import getPort from 'get-port';
import { Connection, createConnection } from 'mysql2/promise';
import { introspectMySQLToFile } from 'tests/schemaDiffer';
import { v4 as uuid } from 'uuid';
import { afterAll, beforeAll, expect, test } from 'vitest';
import { afterAll, beforeAll, beforeEach, expect, test } from 'vitest';

let client: Connection;
let mysqlContainer: Docker.Container;
Expand Down Expand Up @@ -71,6 +88,12 @@ afterAll(async () => {
await mysqlContainer?.stop().catch(console.error);
});

beforeEach(async () => {
await client.query(`drop database if exists \`drizzle\`;`);
await client.query(`create database \`drizzle\`;`);
await client.query(`use \`drizzle\`;`);
});

if (!fs.existsSync('tests/introspect/mysql')) {
fs.mkdirSync('tests/introspect/mysql');
}
Expand All @@ -95,8 +118,6 @@ test('generated always column: link to another column', async () => {

expect(statements.length).toBe(0);
expect(sqlStatements.length).toBe(0);

await client.query(`drop table users;`);
});

test('generated always column virtual: link to another column', async () => {
Expand All @@ -120,8 +141,6 @@ test('generated always column virtual: link to another column', async () => {

expect(statements.length).toBe(0);
expect(sqlStatements.length).toBe(0);

await client.query(`drop table users;`);
});

test('Default value of character type column: char', async () => {
Expand All @@ -141,8 +160,6 @@ test('Default value of character type column: char', async () => {

expect(statements.length).toBe(0);
expect(sqlStatements.length).toBe(0);

await client.query(`drop table users;`);
});

test('Default value of character type column: varchar', async () => {
Expand All @@ -162,8 +179,6 @@ test('Default value of character type column: varchar', async () => {

expect(statements.length).toBe(0);
expect(sqlStatements.length).toBe(0);

await client.query(`drop table users;`);
});

test('introspect checks', async () => {
Expand All @@ -186,8 +201,6 @@ test('introspect checks', async () => {

expect(statements.length).toBe(0);
expect(sqlStatements.length).toBe(0);

await client.query(`drop table users;`);
});

test('view #1', async () => {
Expand All @@ -210,14 +223,9 @@ test('view #1', async () => {

expect(statements.length).toBe(0);
expect(sqlStatements.length).toBe(0);

await client.query(`drop view some_view;`);
await client.query(`drop table users;`);
});

test('view #2', async () => {
// await client.query(`drop view some_view;`);

const users = mysqlTable('some_users', { id: int('id') });
const testView = mysqlView('some_view', { id: int('id') }).algorithm('temptable').sqlSecurity('definer').as(
sql`SELECT * FROM ${users}`,
Expand All @@ -237,6 +245,52 @@ test('view #2', async () => {

expect(statements.length).toBe(0);
expect(sqlStatements.length).toBe(0);
});

test('handle float type', async () => {
const schema = {
table: mysqlTable('table', {
col1: float(),
col2: float({ precision: 2 }),
col3: float({ precision: 2, scale: 1 }),
}),
};

const { statements, sqlStatements } = await introspectMySQLToFile(
client,
schema,
'handle-float-type',
'drizzle',
);

expect(statements.length).toBe(0);
expect(sqlStatements.length).toBe(0);
});

test('handle unsigned numerical types', async () => {
const schema = {
table: mysqlTable('table', {
col1: int({ unsigned: true }),
col2: tinyint({ unsigned: true }),
col3: smallint({ unsigned: true }),
col4: mediumint({ unsigned: true }),
col5: bigint({ mode: 'number', unsigned: true }),
col6: float({ unsigned: true }),
col7: float({ precision: 2, scale: 1, unsigned: true }),
col8: double({ unsigned: true }),
col9: double({ precision: 2, scale: 1, unsigned: true }),
col10: decimal({ unsigned: true }),
col11: decimal({ precision: 2, scale: 1, unsigned: true }),
}),
};

await client.query(`drop table some_users;`);
const { statements, sqlStatements } = await introspectMySQLToFile(
client,
schema,
'handle-unsigned-numerical-types',
'drizzle',
);

expect(statements.length).toBe(0);
expect(sqlStatements.length).toBe(0);
});
5 changes: 4 additions & 1 deletion drizzle-kit/tests/push/common.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { afterAll, beforeAll, test } from 'vitest';
import { afterAll, beforeAll, beforeEach, test } from 'vitest';

export interface DialectSuite {
allTypes(context?: any): Promise<void>;
Expand All @@ -22,10 +22,13 @@ export const run = (
suite: DialectSuite,
beforeAllFn?: (context: any) => Promise<void>,
afterAllFn?: (context: any) => Promise<void>,
beforeEachFn?: (context: any) => Promise<void>,
) => {
let context: any = {};
beforeAll(beforeAllFn ? () => beforeAllFn(context) : () => {});

beforeEach(beforeEachFn ? () => beforeEachFn(context) : () => {});

test('No diffs for all database types', () => suite.allTypes(context));
test('Adding basic indexes', () => suite.addBasicIndexes(context));
test('Dropping basic index', () => suite.dropIndex(context));
Expand Down
6 changes: 6 additions & 0 deletions drizzle-kit/tests/push/mysql.test.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import 'dotenv/config';
import Docker from 'dockerode';
import { SQL, sql } from 'drizzle-orm';
import {
Expand Down Expand Up @@ -696,4 +697,9 @@ run(
await context.client?.end().catch(console.error);
await context.mysqlContainer?.stop().catch(console.error);
},
async (context: any) => {
await context.client?.query(`drop database if exists \`drizzle\`;`);
await context.client?.query(`create database \`drizzle\`;`);
await context.client?.query(`use \`drizzle\`;`);
},
);
Loading

0 comments on commit 19f042a

Please sign in to comment.