diff --git a/modules/module-postgres/src/api/PostgresRouteAPIAdapter.ts b/modules/module-postgres/src/api/PostgresRouteAPIAdapter.ts index e35f2e4f2..c40eddc7a 100644 --- a/modules/module-postgres/src/api/PostgresRouteAPIAdapter.ts +++ b/modules/module-postgres/src/api/PostgresRouteAPIAdapter.ts @@ -9,8 +9,11 @@ import { getDebugTableInfo } from '../replication/replication-utils.js'; import { KEEPALIVE_STATEMENT, PUBLICATION_NAME } from '../replication/WalStream.js'; import * as types from '../types/types.js'; import { getApplicationName } from '../utils/application-name.js'; +import { CustomTypeRegistry } from '../types/registry.js'; +import { PostgresTypeResolver } from '../types/resolver.js'; export class PostgresRouteAPIAdapter implements api.RouteAPI { + private typeCache: PostgresTypeResolver; connectionTag: string; // TODO this should probably be configurable one day publicationName = PUBLICATION_NAME; @@ -31,6 +34,7 @@ export class PostgresRouteAPIAdapter implements api.RouteAPI { connectionTag?: string, private config?: types.ResolvedConnectionConfig ) { + this.typeCache = new PostgresTypeResolver(config?.typeRegistry ?? new CustomTypeRegistry(), pool); this.connectionTag = connectionTag ?? sync_rules.DEFAULT_TAG; } @@ -297,6 +301,7 @@ LEFT JOIN ( SELECT attrelid, attname, + atttypid, format_type(atttypid, atttypmod) as data_type, (SELECT typname FROM pg_catalog.pg_type WHERE oid = atttypid) as pg_type, attnum, @@ -311,6 +316,7 @@ LEFT JOIN ( ) GROUP BY schemaname, tablename, quoted_name` ); + await this.typeCache.fetchTypesForSchema(); const rows = pgwire.pgwireRows(results); let schemas: Record = {}; @@ -332,9 +338,11 @@ GROUP BY schemaname, tablename, quoted_name` if (pg_type.startsWith('_')) { pg_type = `${pg_type.substring(1)}[]`; } + + const knownType = this.typeCache.registry.lookupType(Number(column.atttypid)); table.columns.push({ name: column.attname, - sqlite_type: sync_rules.expressionTypeFromPostgresType(pg_type).typeFlags, + sqlite_type: sync_rules.ExpressionType.fromTypeText(knownType.sqliteType()).typeFlags, type: column.data_type, internal_type: column.data_type, pg_type: pg_type diff --git a/modules/module-postgres/src/index.ts b/modules/module-postgres/src/index.ts index ec110750f..3b0d87195 100644 --- a/modules/module-postgres/src/index.ts +++ b/modules/module-postgres/src/index.ts @@ -1,3 +1 @@ export * from './module/PostgresModule.js'; - -export * as pg_utils from './utils/pgwire_utils.js'; diff --git a/modules/module-postgres/src/module/PostgresModule.ts b/modules/module-postgres/src/module/PostgresModule.ts index d762110c0..48aefd9f5 100644 --- a/modules/module-postgres/src/module/PostgresModule.ts +++ b/modules/module-postgres/src/module/PostgresModule.ts @@ -19,8 +19,11 @@ import { WalStreamReplicator } from '../replication/WalStreamReplicator.js'; import * as types from '../types/types.js'; import { PostgresConnectionConfig } from '../types/types.js'; import { getApplicationName } from '../utils/application-name.js'; +import { CustomTypeRegistry } from '../types/registry.js'; export class PostgresModule extends replication.ReplicationModule { + private customTypes: CustomTypeRegistry = new CustomTypeRegistry(); + constructor() { super({ name: 'Postgres', @@ -48,7 +51,7 @@ export class PostgresModule extends replication.ReplicationModule[] = []; constructor( public options: NormalizedPostgresConnectionConfig, - public poolOptions: pgwire.PgPoolOptions + public poolOptions: PgManagerOptions ) { // The pool is lazy - no connections are opened until a query 
is performed. this.pool = pgwire.connectPgWirePool(this.options, poolOptions); + this.types = new PostgresTypeResolver(poolOptions.registry, this.pool); } public get connectionTag() { @@ -41,9 +51,7 @@ export class PgManager { * @returns The Postgres server version in a parsed Semver instance */ async getServerVersion(): Promise { - const result = await this.pool.query(`SHOW server_version;`); - // The result is usually of the form "16.2 (Debian 16.2-1.pgdg120+2)" - return semver.coerce(result.rows[0][0].split(' ')[0]); + return await getServerVersion(this.pool); } /** diff --git a/modules/module-postgres/src/replication/PgRelation.ts b/modules/module-postgres/src/replication/PgRelation.ts index cc3d9a840..3e665c5f2 100644 --- a/modules/module-postgres/src/replication/PgRelation.ts +++ b/modules/module-postgres/src/replication/PgRelation.ts @@ -30,3 +30,12 @@ export function getPgOutputRelation(source: PgoutputRelation): storage.SourceEnt replicaIdColumns: getReplicaIdColumns(source) } satisfies storage.SourceEntityDescriptor; } + +export function referencedColumnTypeIds(source: PgoutputRelation): number[] { + const oids = new Set(); + for (const column of source.columns) { + oids.add(column.typeOid); + } + + return [...oids]; +} diff --git a/modules/module-postgres/src/replication/WalStream.ts b/modules/module-postgres/src/replication/WalStream.ts index b5e26db4f..ee34b6d0f 100644 --- a/modules/module-postgres/src/replication/WalStream.ts +++ b/modules/module-postgres/src/replication/WalStream.ts @@ -29,10 +29,9 @@ import { TablePattern, toSyncRulesRow } from '@powersync/service-sync-rules'; -import * as pg_utils from '../utils/pgwire_utils.js'; import { PgManager } from './PgManager.js'; -import { getPgOutputRelation, getRelId } from './PgRelation.js'; +import { getPgOutputRelation, getRelId, referencedColumnTypeIds } from './PgRelation.js'; import { checkSourceConfiguration, checkTableRls, getReplicationIdentityColumns } from './replication-utils.js'; import { ReplicationMetric } from '@powersync/service-types'; import { @@ -189,28 +188,30 @@ export class WalStream { let tableRows: any[]; const prefix = tablePattern.isWildcard ? 
tablePattern.tablePrefix : undefined; - if (tablePattern.isWildcard) { - const result = await db.query({ - statement: `SELECT c.oid AS relid, c.relname AS table_name + + { + let query = ` + SELECT + c.oid AS relid, + c.relname AS table_name, + (SELECT + json_agg(DISTINCT a.atttypid) + FROM pg_attribute a + WHERE a.attnum > 0 AND NOT a.attisdropped AND a.attrelid = c.oid) + AS column_types FROM pg_class c JOIN pg_namespace n ON n.oid = c.relnamespace WHERE n.nspname = $1 - AND c.relkind = 'r' - AND c.relname LIKE $2`, - params: [ - { type: 'varchar', value: schema }, - { type: 'varchar', value: tablePattern.tablePattern } - ] - }); - tableRows = pgwire.pgwireRows(result); - } else { + AND c.relkind = 'r'`; + + if (tablePattern.isWildcard) { + query += ' AND c.relname LIKE $2'; + } else { + query += ' AND c.relname = $2'; + } + const result = await db.query({ - statement: `SELECT c.oid AS relid, c.relname AS table_name - FROM pg_class c - JOIN pg_namespace n ON n.oid = c.relnamespace - WHERE n.nspname = $1 - AND c.relkind = 'r' - AND c.relname = $2`, + statement: query, params: [ { type: 'varchar', value: schema }, { type: 'varchar', value: tablePattern.tablePattern } @@ -219,6 +220,7 @@ export class WalStream { tableRows = pgwire.pgwireRows(result); } + let result: storage.SourceTable[] = []; for (let row of tableRows) { @@ -258,16 +260,18 @@ export class WalStream { const cresult = await getReplicationIdentityColumns(db, relid); - const table = await this.handleRelation( + const columnTypes = (JSON.parse(row.column_types) as string[]).map((e) => Number(e)); + const table = await this.handleRelation({ batch, - { + descriptor: { name, schema, objectId: relid, replicaIdColumns: cresult.replicationColumns } as SourceEntityDescriptor, - false - ); + snapshot: false, + referencedTypeIds: columnTypes + }); result.push(table); } @@ -683,7 +687,14 @@ WHERE oid = $1::regclass`, } } - async handleRelation(batch: storage.BucketStorageBatch, descriptor: SourceEntityDescriptor, snapshot: boolean) { + async handleRelation(options: { + batch: storage.BucketStorageBatch; + descriptor: SourceEntityDescriptor; + snapshot: boolean; + referencedTypeIds: number[]; + }) { + const { batch, descriptor, snapshot, referencedTypeIds } = options; + if (!descriptor.objectId && typeof descriptor.objectId != 'number') { throw new ReplicationAssertionError(`objectId expected, got ${typeof descriptor.objectId}`); } @@ -699,6 +710,9 @@ WHERE oid = $1::regclass`, // Drop conflicting tables. This includes for example renamed tables. await batch.drop(result.dropTables); + // Ensure we have a description for custom types referenced in the table. + await this.connections.types.fetchTypes(referencedTypeIds); + // Snapshot if: // 1. Snapshot is requested (false for initial snapshot, since that process handles it elsewhere) // 2. Snapshot is not already done, AND: @@ -789,7 +803,7 @@ WHERE oid = $1::regclass`, if (msg.tag == 'insert') { this.metrics.getCounter(ReplicationMetric.ROWS_REPLICATED).add(1); - const baseRecord = pg_utils.constructAfterRecord(msg); + const baseRecord = this.connections.types.constructAfterRecord(msg); return await batch.save({ tag: storage.SaveOperationTag.INSERT, sourceTable: table, @@ -802,8 +816,8 @@ WHERE oid = $1::regclass`, this.metrics.getCounter(ReplicationMetric.ROWS_REPLICATED).add(1); // "before" may be null if the replica id columns are unchanged // It's fine to treat that the same as an insert. 
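        // (pgoutput only sends the old tuple when a replica identity column changed, or when the
        // table uses REPLICA IDENTITY FULL.)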
- const before = pg_utils.constructBeforeRecord(msg); - const after = pg_utils.constructAfterRecord(msg); + const before = this.connections.types.constructBeforeRecord(msg); + const after = this.connections.types.constructAfterRecord(msg); return await batch.save({ tag: storage.SaveOperationTag.UPDATE, sourceTable: table, @@ -814,7 +828,7 @@ WHERE oid = $1::regclass`, }); } else if (msg.tag == 'delete') { this.metrics.getCounter(ReplicationMetric.ROWS_REPLICATED).add(1); - const before = pg_utils.constructBeforeRecord(msg)!; + const before = this.connections.types.constructBeforeRecord(msg)!; return await batch.save({ tag: storage.SaveOperationTag.DELETE, @@ -955,7 +969,12 @@ WHERE oid = $1::regclass`, for (const msg of messages) { if (msg.tag == 'relation') { - await this.handleRelation(batch, getPgOutputRelation(msg), true); + await this.handleRelation({ + batch, + descriptor: getPgOutputRelation(msg), + snapshot: true, + referencedTypeIds: referencedColumnTypeIds(msg) + }); } else if (msg.tag == 'begin') { // This may span multiple transactions in the same chunk, or even across chunks. skipKeepalive = true; diff --git a/modules/module-postgres/src/types/registry.ts b/modules/module-postgres/src/types/registry.ts new file mode 100644 index 000000000..5ab6b9231 --- /dev/null +++ b/modules/module-postgres/src/types/registry.ts @@ -0,0 +1,278 @@ +import { + applyValueContext, + CompatibilityContext, + CompatibilityOption, + CustomSqliteValue, + DatabaseInputValue, + SqliteValue, + SqliteValueType, + toSyncRulesValue +} from '@powersync/service-sync-rules'; +import * as pgwire from '@powersync/service-jpgwire'; + +interface BaseType { + sqliteType: () => SqliteValueType; +} + +/** A type natively supported by {@link pgwire.PgType.decode}. */ +interface BuiltinType extends BaseType { + type: 'builtin'; + oid: number; +} + +/** + * An array type. + */ +interface ArrayType extends BaseType { + type: 'array'; + innerId: number; + separatorCharCode: number; +} + +/** + * A domain type, like `CREATE DOMAIN api.rating_value AS FLOAT CHECK (VALUE BETWEEN 0 AND 5);` + * + * This type gets decoded and synced as the inner type (`FLOAT` in the example above). + */ +interface DomainType extends BaseType { + type: 'domain'; + innerId: number; +} + +/** + * A composite type as created by `CREATE TYPE AS`. + * + * These types are encoded as a tuple of values, so we recover attribute names to restore them as a JSON object. + */ +interface CompositeType extends BaseType { + type: 'composite'; + members: { name: string; typeId: number }[]; +} + +/** + * A type created with `CREATE TYPE AS RANGE`. + * + * Ranges are represented as {@link pgwire.Range}. Multiranges are represented as arrays thereof. 
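+ * With the `customTypes` compatibility option enabled, `int4range(2, 4)` ends up synced as the JSON object
+ * `{"lower":2,"upper":4,"lower_exclusive":0,"upper_exclusive":1}` (see the range tests added below).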
+ */ +interface RangeType extends BaseType { + type: 'range' | 'multirange'; + innerId: number; +} + +type KnownType = BuiltinType | ArrayType | DomainType | DomainType | CompositeType | RangeType; + +interface UnknownType extends BaseType { + type: 'unknown'; +} + +type MaybeKnownType = KnownType | UnknownType; + +const UNKNOWN_TYPE: UnknownType = { + type: 'unknown', + sqliteType: () => 'text' +}; + +class CustomTypeValue extends CustomSqliteValue { + constructor( + readonly oid: number, + readonly cache: CustomTypeRegistry, + readonly rawValue: string + ) { + super(); + } + + private lookup(): KnownType | UnknownType { + return this.cache.lookupType(this.oid); + } + + private decodeToDatabaseInputValue(context: CompatibilityContext): DatabaseInputValue { + if (context.isEnabled(CompatibilityOption.customTypes)) { + try { + return this.cache.decodeWithCustomTypes(this.rawValue, this.oid); + } catch (_e) { + return this.rawValue; + } + } else { + return pgwire.PgType.decode(this.rawValue, this.oid); + } + } + + toSqliteValue(context: CompatibilityContext): SqliteValue { + const value = toSyncRulesValue(this.decodeToDatabaseInputValue(context)); + return applyValueContext(value, context); + } + + get sqliteType(): SqliteValueType { + return this.lookup().sqliteType(); + } +} + +/** + * A registry of custom types. + * + * These extend the builtin decoding behavior in {@link pgwire.PgType.decode} for user-defined types like `DOMAIN`s or + * composite types. + */ +export class CustomTypeRegistry { + private readonly byOid: Map; + + constructor() { + this.byOid = new Map(); + + for (const builtin of Object.values(pgwire.PgTypeOid)) { + if (typeof builtin == 'number') { + // We need to know the SQLite type of builtins to implement CustomSqliteValue.sqliteType for DOMAIN types. + let sqliteType: SqliteValueType; + switch (builtin) { + case pgwire.PgTypeOid.TEXT: + case pgwire.PgTypeOid.UUID: + case pgwire.PgTypeOid.VARCHAR: + case pgwire.PgTypeOid.DATE: + case pgwire.PgTypeOid.TIMESTAMP: + case pgwire.PgTypeOid.TIMESTAMPTZ: + case pgwire.PgTypeOid.TIME: + case pgwire.PgTypeOid.JSON: + case pgwire.PgTypeOid.JSONB: + case pgwire.PgTypeOid.PG_LSN: + sqliteType = 'text'; + break; + case pgwire.PgTypeOid.BYTEA: + sqliteType = 'blob'; + break; + case pgwire.PgTypeOid.BOOL: + case pgwire.PgTypeOid.INT2: + case pgwire.PgTypeOid.INT4: + case pgwire.PgTypeOid.OID: + case pgwire.PgTypeOid.INT8: + sqliteType = 'integer'; + break; + case pgwire.PgTypeOid.FLOAT4: + case pgwire.PgTypeOid.FLOAT8: + sqliteType = 'real'; + break; + default: + sqliteType = 'text'; + } + + this.byOid.set(builtin, { + type: 'builtin', + oid: builtin, + sqliteType: () => sqliteType + }); + } + } + + for (const [arrayId, innerId] of pgwire.ARRAY_TO_ELEM_OID.entries()) { + // We can just use the default decoder, except for box[] because those use a different delimiter. We don't fix + // this in PgType._decodeArray for backwards-compatibility. 
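+ // For example, the box[] literal '{(3,4),(1,2);(7,8),(5,6)}' splits on ';' into the two box
+ // literals '(3,4),(1,2)' and '(7,8),(5,6)', which are then synced as a JSON array of strings.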
+ if (innerId == 603) { + this.byOid.set(arrayId, { + type: 'array', + innerId, + sqliteType: () => 'text', // these get encoded as JSON arrays + separatorCharCode: 0x3b // ";" + }); + } else { + this.byOid.set(arrayId, { + type: 'builtin', + oid: arrayId, + sqliteType: () => 'text' // these get encoded as JSON arrays + }); + } + } + } + + knows(oid: number): boolean { + return this.byOid.has(oid); + } + + set(oid: number, value: KnownType) { + this.byOid.set(oid, value); + } + + setDomainType(oid: number, inner: number) { + this.set(oid, { + type: 'domain', + innerId: inner, + sqliteType: () => this.lookupType(inner).sqliteType() + }); + } + + decodeWithCustomTypes(raw: string, oid: number): DatabaseInputValue { + const resolved = this.lookupType(oid); + switch (resolved.type) { + case 'builtin': + case 'unknown': + return pgwire.PgType.decode(raw, oid); + case 'domain': + return this.decodeWithCustomTypes(raw, resolved.innerId); + case 'composite': { + const parsed: [string, any][] = []; + + new pgwire.StructureParser(raw).parseComposite((raw) => { + const nextMember = resolved.members[parsed.length]; + if (nextMember) { + const value = raw == null ? null : this.decodeWithCustomTypes(raw, nextMember.typeId); + parsed.push([nextMember.name, value]); + } + }); + return Object.fromEntries(parsed); + } + case 'array': { + // Nornalize "array of array of T" types into just "array of T", because Postgres arrays are natively multi- + // dimensional. This may be required when we have a DOMAIN wrapper around an array followed by another array + // around that domain. + let innerId = resolved.innerId; + while (true) { + const resolvedInner = this.lookupType(innerId); + if (resolvedInner.type == 'domain') { + innerId = resolvedInner.innerId; + } else if (resolvedInner.type == 'array') { + innerId = resolvedInner.innerId; + } else { + break; + } + } + + return new pgwire.StructureParser(raw).parseArray( + (source) => this.decodeWithCustomTypes(source, innerId), + resolved.separatorCharCode + ); + } + case 'range': + return new pgwire.StructureParser(raw).parseRange((s) => this.decodeWithCustomTypes(s, resolved.innerId)); + case 'multirange': + return new pgwire.StructureParser(raw).parseMultiRange((s) => this.decodeWithCustomTypes(s, resolved.innerId)); + } + } + + lookupType(type: number): KnownType | UnknownType { + return this.byOid.get(type) ?? UNKNOWN_TYPE; + } + + private isParsedWithoutCustomTypesSupport(type: MaybeKnownType): boolean { + switch (type.type) { + case 'builtin': + case 'unknown': + return true; + case 'array': + return ( + type.separatorCharCode == pgwire.CHAR_CODE_COMMA && + this.isParsedWithoutCustomTypesSupport(this.lookupType(type.innerId)) + ); + default: + return false; + } + } + + decodeDatabaseValue(value: string, oid: number): DatabaseInputValue { + const resolved = this.lookupType(oid); + // For backwards-compatibility, some types are only properly parsed with a compatibility option. Others are synced + // in the raw text representation by default, and are only parsed as JSON values when necessary. 
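+ // For example, given a hypothetical domain OID 1337 registered via setDomainType(1337, PgTypeOid.INT4),
+ // decodeDatabaseValue('12', 1337) returns a CustomTypeValue that renders as the raw text '12' under full
+ // backwards compatibility and as the integer 12 once the customTypes option is enabled.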
+ if (this.isParsedWithoutCustomTypesSupport(resolved)) { + return pgwire.PgType.decode(value, oid); + } else { + return new CustomTypeValue(oid, this, value); + } + } +} diff --git a/modules/module-postgres/src/types/resolver.ts b/modules/module-postgres/src/types/resolver.ts new file mode 100644 index 000000000..694b53537 --- /dev/null +++ b/modules/module-postgres/src/types/resolver.ts @@ -0,0 +1,210 @@ +import { DatabaseInputRow, SqliteInputRow, toSyncRulesRow } from '@powersync/service-sync-rules'; +import * as pgwire from '@powersync/service-jpgwire'; +import { CustomTypeRegistry } from './registry.js'; +import semver from 'semver'; +import { getServerVersion } from '../utils/postgres_version.js'; + +/** + * Resolves descriptions used to decode values for custom postgres types. + * + * Custom types are resolved from the source database, which also involves crawling inner types (e.g. for composites). + */ +export class PostgresTypeResolver { + private cachedVersion: semver.SemVer | null = null; + + constructor( + readonly registry: CustomTypeRegistry, + private readonly pool: pgwire.PgClient + ) { + this.registry = new CustomTypeRegistry(); + } + + private async fetchVersion(): Promise { + if (this.cachedVersion == null) { + this.cachedVersion = (await getServerVersion(this.pool)) ?? semver.parse('0.0.1'); + } + + return this.cachedVersion!; + } + + /** + * @returns Whether the Postgres instance this type cache is connected to has support for the multirange type (which + * is the case for Postgres 14 and later). + */ + async supportsMultiRanges() { + const version = await this.fetchVersion(); + return version.compare(PostgresTypeResolver.minVersionForMultirange) >= 0; + } + + /** + * Fetches information about indicated types. + * + * If a type references another custom type (e.g. because it's a composite type with a custom field), these are + * automatically crawled as well. + */ + public async fetchTypes(oids: number[]) { + const multiRangeSupport = await this.supportsMultiRanges(); + + let pending = oids.filter((id) => !this.registry.knows(id)); + // For details on columns, see https://www.postgresql.org/docs/current/catalog-pg-type.html + const multiRangeDesc = `WHEN 'm' THEN json_build_object('inner', (SELECT rngsubtype FROM pg_range WHERE rngmultitypid = t.oid))`; + const statement = ` +SELECT oid, t.typtype, + CASE t.typtype + WHEN 'b' THEN json_build_object('element_type', t.typelem, 'delim', (SELECT typdelim FROM pg_type i WHERE i.oid = t.typelem)) + WHEN 'd' THEN json_build_object('type', t.typbasetype) + WHEN 'c' THEN json_build_object( + 'elements', + (SELECT json_agg(json_build_object('name', a.attname, 'type', a.atttypid)) + FROM pg_attribute a + WHERE a.attrelid = t.typrelid) + ) + WHEN 'r' THEN json_build_object('inner', (SELECT rngsubtype FROM pg_range WHERE rngtypid = t.oid)) + ${multiRangeSupport ? 
multiRangeDesc : ''} + ELSE NULL + END AS desc +FROM pg_type t +WHERE t.oid = ANY($1) +`; + + while (pending.length != 0) { + // 1016: int8 array + const query = await this.pool.query({ statement, params: [{ type: 1016, value: pending }] }); + const stillPending: number[] = []; + + const requireType = (oid: number) => { + if (!this.registry.knows(oid) && !pending.includes(oid) && !stillPending.includes(oid)) { + stillPending.push(oid); + } + }; + + for (const row of pgwire.pgwireRows(query)) { + const oid = Number(row.oid); + const desc = JSON.parse(row.desc); + + switch (row.typtype) { + case 'b': + const { element_type, delim } = desc; + + if (!this.registry.knows(oid)) { + // This type is an array of another custom type. + const inner = Number(element_type); + if (inner != 0) { + // Some array types like macaddr[] don't seem to have their inner type set properly - skip! + requireType(inner); + this.registry.set(oid, { + type: 'array', + innerId: inner, + separatorCharCode: (delim as string).charCodeAt(0), + sqliteType: () => 'text' // Since it's JSON + }); + } + } + break; + case 'c': + // For composite types, we sync the JSON representation. + const elements: { name: string; typeId: number }[] = []; + for (const { name, type } of desc.elements) { + const typeId = Number(type); + elements.push({ name, typeId }); + requireType(typeId); + } + + this.registry.set(oid, { + type: 'composite', + members: elements, + sqliteType: () => 'text' // Since it's JSON + }); + break; + case 'd': + // For domain values like CREATE DOMAIN api.rating_value AS FLOAT CHECK (VALUE BETWEEN 0 AND 5), we sync + // the inner type (pg_type.typbasetype). + const inner = Number(desc.type); + this.registry.setDomainType(oid, inner); + requireType(inner); + break; + case 'r': + case 'm': { + const inner = Number(desc.inner); + this.registry.set(oid, { + type: row.typtype == 'r' ? 'range' : 'multirange', + innerId: inner, + sqliteType: () => 'text' // Since it's JSON + }); + } + } + } + + pending = stillPending; + } + } + + /** + * Crawls all custom types referenced by table columns in the current database. + */ + public async fetchTypesForSchema() { + const sql = ` +SELECT DISTINCT a.atttypid AS type_oid +FROM pg_attribute a +JOIN pg_class c ON c.oid = a.attrelid +JOIN pg_namespace cn ON cn.oid = c.relnamespace +JOIN pg_type t ON t.oid = a.atttypid +JOIN pg_namespace tn ON tn.oid = t.typnamespace +WHERE a.attnum > 0 + AND NOT a.attisdropped + AND cn.nspname not in ('information_schema', 'pg_catalog', 'pg_toast') + `; + + const query = await this.pool.query({ statement: sql }); + let ids: number[] = []; + for (const row of pgwire.pgwireRows(query)) { + ids.push(Number(row.type_oid)); + } + + await this.fetchTypes(ids); + } + + /** + * pgwire message -> SQLite row. + * @param message + */ + constructAfterRecord(message: pgwire.PgoutputInsert | pgwire.PgoutputUpdate): SqliteInputRow { + const rawData = (message as any).afterRaw; + + const record = this.decodeTuple(message.relation, rawData); + return toSyncRulesRow(record); + } + + /** + * pgwire message -> SQLite row. 
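+ * The raw `after` tuple is decoded with {@link decodeTuple}, so custom column types are resolved
+ * through the registry rather than pgwire's builtin decoder.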
+ * @param message + */ + constructBeforeRecord(message: pgwire.PgoutputDelete | pgwire.PgoutputUpdate): SqliteInputRow | undefined { + const rawData = (message as any).beforeRaw; + if (rawData == null) { + return undefined; + } + const record = this.decodeTuple(message.relation, rawData); + return toSyncRulesRow(record); + } + + /** + * We need a high level of control over how values are decoded, to make sure there is no loss + * of precision in the process. + */ + decodeTuple(relation: pgwire.PgoutputRelation, tupleRaw: Record): DatabaseInputRow { + let result: Record = {}; + for (let columnName in tupleRaw) { + const rawval = tupleRaw[columnName]; + const typeOid = (relation as any)._tupleDecoder._typeOids.get(columnName); + if (typeof rawval == 'string' && typeOid) { + result[columnName] = this.registry.decodeDatabaseValue(rawval, typeOid); + } else { + result[columnName] = rawval; + } + } + return result; + } + + private static minVersionForMultirange: semver.SemVer = semver.parse('14.0.0')!; +} diff --git a/modules/module-postgres/src/types/types.ts b/modules/module-postgres/src/types/types.ts index 4de2ac0ed..3fd3c9fbb 100644 --- a/modules/module-postgres/src/types/types.ts +++ b/modules/module-postgres/src/types/types.ts @@ -1,6 +1,7 @@ import * as lib_postgres from '@powersync/lib-service-postgres'; import * as service_types from '@powersync/service-types'; import * as t from 'ts-codec'; +import { CustomTypeRegistry } from './registry.js'; // Maintain backwards compatibility by exporting these export const validatePort = lib_postgres.validatePort; @@ -24,7 +25,10 @@ export type PostgresConnectionConfig = t.Decoded SQLite row. - * @param message - */ -export function constructAfterRecord(message: pgwire.PgoutputInsert | pgwire.PgoutputUpdate): SqliteInputRow { - const rawData = (message as any).afterRaw; - - const record = decodeTuple(message.relation, rawData); - return toSyncRulesRow(record); -} - -/** - * pgwire message -> SQLite row. - * @param message - */ -export function constructBeforeRecord( - message: pgwire.PgoutputDelete | pgwire.PgoutputUpdate -): SqliteInputRow | undefined { - const rawData = (message as any).beforeRaw; - if (rawData == null) { - return undefined; - } - const record = decodeTuple(message.relation, rawData); - return toSyncRulesRow(record); -} - -/** - * We need a high level of control over how values are decoded, to make sure there is no loss - * of precision in the process. 
- */ -export function decodeTuple(relation: pgwire.PgoutputRelation, tupleRaw: Record): DatabaseInputRow { - let result: Record = {}; - for (let columnName in tupleRaw) { - const rawval = tupleRaw[columnName]; - const typeOid = (relation as any)._tupleDecoder._typeOids.get(columnName); - if (typeof rawval == 'string' && typeOid) { - result[columnName] = pgwire.PgType.decode(rawval, typeOid); - } else { - result[columnName] = rawval; - } - } - return result; -} diff --git a/modules/module-postgres/src/utils/postgres_version.ts b/modules/module-postgres/src/utils/postgres_version.ts new file mode 100644 index 000000000..7e2a7e9ce --- /dev/null +++ b/modules/module-postgres/src/utils/postgres_version.ts @@ -0,0 +1,8 @@ +import * as pgwire from '@powersync/service-jpgwire'; +import semver, { type SemVer } from 'semver'; + +export async function getServerVersion(db: pgwire.PgClient): Promise { + const result = await db.query(`SHOW server_version;`); + // The result is usually of the form "16.2 (Debian 16.2-1.pgdg120+2)" + return semver.coerce(result.rows[0][0].split(' ')[0]); +} diff --git a/modules/module-postgres/test/src/pg_test.test.ts b/modules/module-postgres/test/src/pg_test.test.ts index 116d95cbc..e41c61f63 100644 --- a/modules/module-postgres/test/src/pg_test.test.ts +++ b/modules/module-postgres/test/src/pg_test.test.ts @@ -1,4 +1,3 @@ -import { constructAfterRecord } from '@module/utils/pgwire_utils.js'; import * as pgwire from '@powersync/service-jpgwire'; import { applyRowContext, @@ -11,6 +10,8 @@ import { import { describe, expect, test } from 'vitest'; import { clearTestDb, connectPgPool, connectPgWire, TEST_URI } from './util.js'; import { WalStream } from '@module/replication/WalStream.js'; +import { PostgresTypeResolver } from '@module/types/resolver.js'; +import { CustomTypeRegistry } from '@module/types/registry.js'; describe('pg data types', () => { async function setupTable(db: pgwire.PgClient) { @@ -382,7 +383,7 @@ VALUES(10, ARRAY['null']::TEXT[]); } }); - const transformed = await getReplicationTx(replicationStream); + const transformed = await getReplicationTx(db, replicationStream); await pg.end(); checkResults(transformed); @@ -419,7 +420,7 @@ VALUES(10, ARRAY['null']::TEXT[]); } }); - const transformed = await getReplicationTx(replicationStream); + const transformed = await getReplicationTx(db, replicationStream); await pg.end(); checkResultArrays(transformed.map((e) => applyRowContext(e, CompatibilityContext.FULL_BACKWARDS_COMPATIBILITY))); @@ -470,17 +471,163 @@ INSERT INTO test_data(id, time, timestamp, timestamptz) VALUES (1, '17:42:01.12' await db.end(); } }); + + test('test replication - custom types', async () => { + const db = await connectPgPool(); + try { + await clearTestDb(db); + await db.query(`CREATE DOMAIN rating_value AS FLOAT CHECK (VALUE BETWEEN 0 AND 5);`); + await db.query(`CREATE TYPE composite AS (foo rating_value[], bar TEXT);`); + await db.query(`CREATE TYPE nested_composite AS (a BOOLEAN, b composite);`); + await db.query(`CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')`); + + await db.query(`CREATE TABLE test_custom( + id serial primary key, + rating rating_value, + composite composite, + nested_composite nested_composite, + boxes box[], + mood mood + );`); + + const slotName = 'test_slot'; + + await db.query({ + statement: 'SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE slot_name = $1', + params: [{ type: 'varchar', value: slotName }] + }); + + await db.query({ + statement: `SELECT slot_name, lsn FROM 
pg_catalog.pg_create_logical_replication_slot($1, 'pgoutput')`, + params: [{ type: 'varchar', value: slotName }] + }); + + await db.query(` + INSERT INTO test_custom + (rating, composite, nested_composite, boxes, mood) + VALUES ( + 1, + (ARRAY[2,3], 'bar'), + (TRUE, (ARRAY[2,3], 'bar')), + ARRAY[box(point '(1,2)', point '(3,4)'), box(point '(5, 6)', point '(7,8)')], + 'happy' + ); + `); + + const pg: pgwire.PgConnection = await pgwire.pgconnect({ replication: 'database' }, TEST_URI); + const replicationStream = await pg.logicalReplication({ + slot: slotName, + options: { + proto_version: '1', + publication_names: 'powersync' + } + }); + + const [transformed] = await getReplicationTx(db, replicationStream); + await pg.end(); + + const oldFormat = applyRowContext(transformed, CompatibilityContext.FULL_BACKWARDS_COMPATIBILITY); + expect(oldFormat).toMatchObject({ + rating: '1', + composite: '("{2,3}",bar)', + nested_composite: '(t,"(""{2,3}"",bar)")', + boxes: '["(3","4)","(1","2);(7","8)","(5","6)"]', + mood: 'happy' + }); + + const newFormat = applyRowContext(transformed, new CompatibilityContext(CompatibilityEdition.SYNC_STREAMS)); + expect(newFormat).toMatchObject({ + rating: 1, + composite: '{"foo":[2.0,3.0],"bar":"bar"}', + nested_composite: '{"a":1,"b":{"foo":[2.0,3.0],"bar":"bar"}}', + boxes: JSON.stringify(['(3,4),(1,2)', '(7,8),(5,6)']), + mood: 'happy' + }); + } finally { + await db.end(); + } + }); + + test('test replication - multiranges', async () => { + const db = await connectPgPool(); + + if (!(await new PostgresTypeResolver(new CustomTypeRegistry(), db).supportsMultiRanges())) { + // This test requires Postgres 14 or later. + return; + } + + try { + await clearTestDb(db); + + await db.query(`CREATE TABLE test_custom( + id serial primary key, + ranges int4multirange[] + );`); + + const slotName = 'test_slot'; + + await db.query({ + statement: 'SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE slot_name = $1', + params: [{ type: 'varchar', value: slotName }] + }); + + await db.query({ + statement: `SELECT slot_name, lsn FROM pg_catalog.pg_create_logical_replication_slot($1, 'pgoutput')`, + params: [{ type: 'varchar', value: slotName }] + }); + + await db.query(` + INSERT INTO test_custom + (ranges) + VALUES ( + ARRAY[int4multirange(int4range(2, 4), int4range(5, 7, '(]'))]::int4multirange[] + ); + `); + + const pg: pgwire.PgConnection = await pgwire.pgconnect({ replication: 'database' }, TEST_URI); + const replicationStream = await pg.logicalReplication({ + slot: slotName, + options: { + proto_version: '1', + publication_names: 'powersync' + } + }); + + const [transformed] = await getReplicationTx(db, replicationStream); + await pg.end(); + + const oldFormat = applyRowContext(transformed, CompatibilityContext.FULL_BACKWARDS_COMPATIBILITY); + expect(oldFormat).toMatchObject({ + ranges: '{"{[2,4),[6,8)}"}' + }); + + const newFormat = applyRowContext(transformed, new CompatibilityContext(CompatibilityEdition.SYNC_STREAMS)); + expect(newFormat).toMatchObject({ + ranges: JSON.stringify([ + [ + { lower: 2, upper: 4, lower_exclusive: 0, upper_exclusive: 1 }, + { lower: 6, upper: 8, lower_exclusive: 0, upper_exclusive: 1 } + ] + ]) + }); + } finally { + await db.end(); + } + }); }); /** * Return all the inserts from the first transaction in the replication stream. 
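 * A type resolver is created against the given connection and pre-fetches the schema's custom types,
 * so decoded rows match what replication would produce.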
*/ -async function getReplicationTx(replicationStream: pgwire.ReplicationStream) { +async function getReplicationTx(db: pgwire.PgClient, replicationStream: pgwire.ReplicationStream) { + const typeCache = new PostgresTypeResolver(new CustomTypeRegistry(), db); + await typeCache.fetchTypesForSchema(); + let transformed: SqliteInputRow[] = []; for await (const batch of replicationStream.pgoutputDecode()) { for (const msg of batch.messages) { if (msg.tag == 'insert') { - transformed.push(constructAfterRecord(msg)); + transformed.push(typeCache.constructAfterRecord(msg)); } else if (msg.tag == 'commit') { return transformed; } diff --git a/modules/module-postgres/test/src/route_api_adapter.test.ts b/modules/module-postgres/test/src/route_api_adapter.test.ts new file mode 100644 index 000000000..98f16930c --- /dev/null +++ b/modules/module-postgres/test/src/route_api_adapter.test.ts @@ -0,0 +1,60 @@ +import { describe, expect, test } from 'vitest'; +import { clearTestDb, connectPgPool } from './util.js'; +import { PostgresRouteAPIAdapter } from '@module/api/PostgresRouteAPIAdapter.js'; +import { TYPE_INTEGER, TYPE_REAL, TYPE_TEXT } from '@powersync/service-sync-rules'; + +describe('PostgresRouteAPIAdapter tests', () => { + test('infers connection schema', async () => { + const db = await connectPgPool(); + try { + await clearTestDb(db); + const api = new PostgresRouteAPIAdapter(db); + + await db.query(`CREATE DOMAIN rating_value AS FLOAT CHECK (VALUE BETWEEN 0 AND 5)`); + await db.query(` + CREATE TABLE test_users ( + id TEXT NOT NULL PRIMARY KEY, + is_admin BOOLEAN, + rating RATING_VALUE + ); + `); + + const schema = await api.getConnectionSchema(); + expect(schema).toStrictEqual([ + { + name: 'public', + tables: [ + { + name: 'test_users', + columns: [ + { + internal_type: 'text', + name: 'id', + pg_type: 'text', + sqlite_type: TYPE_TEXT, + type: 'text' + }, + { + internal_type: 'boolean', + name: 'is_admin', + pg_type: 'bool', + sqlite_type: TYPE_INTEGER, + type: 'boolean' + }, + { + internal_type: 'rating_value', + name: 'rating', + pg_type: 'rating_value', + sqlite_type: TYPE_REAL, + type: 'rating_value' + } + ] + } + ] + } + ]); + } finally { + await db.end(); + } + }); +}); diff --git a/modules/module-postgres/test/src/schema_changes.test.ts b/modules/module-postgres/test/src/schema_changes.test.ts index 26a4fa3c8..c1994e7a8 100644 --- a/modules/module-postgres/test/src/schema_changes.test.ts +++ b/modules/module-postgres/test/src/schema_changes.test.ts @@ -590,4 +590,36 @@ function defineTests(factory: storage.TestStorageFactory) { expect(failures).toEqual([]); }); + + test('custom types', async () => { + await using context = await WalStreamTestContext.open(factory); + + await context.updateSyncRules(` +streams: + stream: + query: SELECT * FROM "test_data" + +config: + edition: 2 +`); + + const { pool } = context; + await pool.query(`CREATE TABLE test_data(id text primary key);`); + await pool.query(`INSERT INTO test_data(id) VALUES ('t1')`); + + await context.replicateSnapshot(); + context.startStreaming(); + + await pool.query( + { statement: `CREATE TYPE composite AS (foo bool, bar int4);` }, + { statement: `ALTER TABLE test_data ADD COLUMN other composite;` }, + { statement: `UPDATE test_data SET other = ROW(TRUE, 2)::composite;` } + ); + + const data = await context.getBucketData('1#stream|0[]'); + expect(data).toMatchObject([ + putOp('test_data', { id: 't1' }), + putOp('test_data', { id: 't1', other: '{"foo":1,"bar":2}' }) + ]); + }); } diff --git 
a/modules/module-postgres/test/src/slow_tests.test.ts b/modules/module-postgres/test/src/slow_tests.test.ts index ae5294887..49dd23952 100644 --- a/modules/module-postgres/test/src/slow_tests.test.ts +++ b/modules/module-postgres/test/src/slow_tests.test.ts @@ -19,6 +19,7 @@ import { METRICS_HELPER, test_utils } from '@powersync/service-core-tests'; import * as mongo_storage from '@powersync/service-module-mongodb-storage'; import * as postgres_storage from '@powersync/service-module-postgres-storage'; import * as timers from 'node:timers/promises'; +import { CustomTypeRegistry } from '@module/types/registry.js'; describe.skipIf(!(env.CI || env.SLOW_TESTS))('slow tests', function () { describeWithStorage({ timeout: 120_000 }, function (factory) { @@ -68,7 +69,7 @@ function defineSlowTests(factory: storage.TestStorageFactory) { }); async function testRepeatedReplication(testOptions: { compact: boolean; maxBatchSize: number; numBatches: number }) { - const connections = new PgManager(TEST_CONNECTION_OPTIONS, {}); + const connections = new PgManager(TEST_CONNECTION_OPTIONS, { registry: new CustomTypeRegistry() }); const replicationConnection = await connections.replicationConnection(); const pool = connections.pool; await clearTestDb(pool); @@ -329,7 +330,7 @@ bucket_definitions: await pool.query(`SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE active = FALSE`); i += 1; - const connections = new PgManager(TEST_CONNECTION_OPTIONS, {}); + const connections = new PgManager(TEST_CONNECTION_OPTIONS, { registry: new CustomTypeRegistry() }); const replicationConnection = await connections.replicationConnection(); abortController = new AbortController(); diff --git a/modules/module-postgres/test/src/types/registry.test.ts b/modules/module-postgres/test/src/types/registry.test.ts new file mode 100644 index 000000000..d76f78f7e --- /dev/null +++ b/modules/module-postgres/test/src/types/registry.test.ts @@ -0,0 +1,149 @@ +import { describe, expect, test, beforeEach } from 'vitest'; +import { CustomTypeRegistry } from '@module/types/registry.js'; +import { CHAR_CODE_COMMA, PgTypeOid } from '@powersync/service-jpgwire'; +import { + applyValueContext, + CompatibilityContext, + CompatibilityEdition, + toSyncRulesValue +} from '@powersync/service-sync-rules'; + +describe('custom type registry', () => { + let registry: CustomTypeRegistry; + + beforeEach(() => { + registry = new CustomTypeRegistry(); + }); + + function checkResult(raw: string, type: number, old: any, fixed: any) { + const input = registry.decodeDatabaseValue(raw, type); + const syncRulesValue = toSyncRulesValue(input); + + expect(applyValueContext(syncRulesValue, CompatibilityContext.FULL_BACKWARDS_COMPATIBILITY)).toStrictEqual(old); + expect( + applyValueContext(syncRulesValue, new CompatibilityContext(CompatibilityEdition.SYNC_STREAMS)) + ).toStrictEqual(fixed); + } + + test('domain types', () => { + registry.setDomainType(1337, PgTypeOid.INT4); // create domain wrapping integer + checkResult('12', 1337, '12', 12n); // Should be raw text value without fix, parsed as inner type if enabled + }); + + test('array of domain types', () => { + registry.setDomainType(1337, PgTypeOid.INT4); + registry.set(1338, { type: 'array', separatorCharCode: CHAR_CODE_COMMA, innerId: 1337, sqliteType: () => 'text' }); + + checkResult('{1,2,3}', 1338, '{1,2,3}', '[1,2,3]'); + }); + + test('nested array through domain type', () => { + registry.setDomainType(1337, PgTypeOid.INT4); + registry.set(1338, { type: 'array', separatorCharCode: 
CHAR_CODE_COMMA, innerId: 1337, sqliteType: () => 'text' }); + registry.setDomainType(1339, 1338); + + checkResult('{1,2,3}', 1339, '{1,2,3}', '[1,2,3]'); + + registry.set(1400, { type: 'array', separatorCharCode: CHAR_CODE_COMMA, innerId: 1339, sqliteType: () => 'text' }); + checkResult('{{1,2,3}}', 1400, '{{1,2,3}}', '[[1,2,3]]'); + }); + + test('structure', () => { + // create type c1 AS (a bool, b integer, c text[]); + registry.set(1337, { + type: 'composite', + sqliteType: () => 'text', + members: [ + { name: 'a', typeId: PgTypeOid.BOOL }, + { name: 'b', typeId: PgTypeOid.INT4 }, + { name: 'c', typeId: 1009 } // text array + ] + }); + + // SELECT (TRUE, 123, ARRAY['foo', 'bar'])::c1; + checkResult('(t,123,"{foo,bar}")', 1337, '(t,123,"{foo,bar}")', '{"a":1,"b":123,"c":["foo","bar"]}'); + }); + + test('array of structure', () => { + // create type c1 AS (a bool, b integer, c text[]); + registry.set(1337, { + type: 'composite', + sqliteType: () => 'text', + members: [ + { name: 'a', typeId: PgTypeOid.BOOL }, + { name: 'b', typeId: PgTypeOid.INT4 }, + { name: 'c', typeId: 1009 } // text array + ] + }); + registry.set(1338, { type: 'array', separatorCharCode: CHAR_CODE_COMMA, innerId: 1337, sqliteType: () => 'text' }); + + // SELECT ARRAY[(TRUE, 123, ARRAY['foo', 'bar']),(FALSE, NULL, ARRAY[]::text[])]::c1[]; + checkResult( + '{"(t,123,\\"{foo,bar}\\")","(f,,{})"}', + 1338, + '{"(t,123,\\"{foo,bar}\\")","(f,,{})"}', + '[{"a":1,"b":123,"c":["foo","bar"]},{"a":0,"b":null,"c":[]}]' + ); + }); + + test('domain type of structure', () => { + registry.set(1337, { + type: 'composite', + sqliteType: () => 'text', + members: [ + { name: 'a', typeId: PgTypeOid.BOOL }, + { name: 'b', typeId: PgTypeOid.INT4 } + ] + }); + registry.setDomainType(1338, 1337); + + checkResult('(t,123)', 1337, '(t,123)', '{"a":1,"b":123}'); + }); + + test('structure of another structure', () => { + // CREATE TYPE c2 AS (a BOOLEAN, b INTEGER); + registry.set(1337, { + type: 'composite', + sqliteType: () => 'text', + members: [ + { name: 'a', typeId: PgTypeOid.BOOL }, + { name: 'b', typeId: PgTypeOid.INT4 } + ] + }); + registry.set(1338, { type: 'array', separatorCharCode: CHAR_CODE_COMMA, innerId: 1337, sqliteType: () => 'text' }); + // CREATE TYPE c3 (c c2[]); + registry.set(1339, { + type: 'composite', + sqliteType: () => 'text', + members: [{ name: 'c', typeId: 1338 }] + }); + + // SELECT ROW(ARRAY[(FALSE,2)]::c2[])::c3; + checkResult('("{""(f,2)""}")', 1339, '("{""(f,2)""}")', '{"c":[{"a":0,"b":2}]}'); + }); + + test('range', () => { + registry.set(1337, { + type: 'range', + sqliteType: () => 'text', + innerId: PgTypeOid.INT2 + }); + + checkResult('[1,2]', 1337, '[1,2]', '{"lower":1,"upper":2,"lower_exclusive":0,"upper_exclusive":0}'); + }); + + test('multirange', () => { + registry.set(1337, { + type: 'multirange', + sqliteType: () => 'text', + innerId: PgTypeOid.INT2 + }); + + checkResult( + '{[1,2),[3,4)}', + 1337, + '{[1,2),[3,4)}', + '[{"lower":1,"upper":2,"lower_exclusive":0,"upper_exclusive":1},{"lower":3,"upper":4,"lower_exclusive":0,"upper_exclusive":1}]' + ); + }); +}); diff --git a/modules/module-postgres/test/src/util.ts b/modules/module-postgres/test/src/util.ts index 130b70fe9..410dd50e2 100644 --- a/modules/module-postgres/test/src/util.ts +++ b/modules/module-postgres/test/src/util.ts @@ -59,6 +59,22 @@ export async function clearTestDb(db: pgwire.PgClient) { await db.query(`DROP TABLE public.${lib_postgres.escapeIdentifier(name)}`); } } + + const domainRows = pgwire.pgwireRows( + await db.query(` + 
SELECT typname,typtype + FROM pg_type t + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace + WHERE n.nspname = 'public' AND typarray != 0 + `) + ); + for (let row of domainRows) { + if (row.typtype == 'd') { + await db.query(`DROP DOMAIN public.${lib_postgres.escapeIdentifier(row.typname)} CASCADE`); + } else { + await db.query(`DROP TYPE public.${lib_postgres.escapeIdentifier(row.typname)} CASCADE`); + } + } } export async function connectPgWire(type?: 'replication' | 'standard') { diff --git a/modules/module-postgres/test/src/wal_stream.test.ts b/modules/module-postgres/test/src/wal_stream.test.ts index 53f426171..80e545773 100644 --- a/modules/module-postgres/test/src/wal_stream.test.ts +++ b/modules/module-postgres/test/src/wal_stream.test.ts @@ -324,4 +324,28 @@ bucket_definitions: // creating a new replication slot. } }); + + test('custom types', async () => { + await using context = await WalStreamTestContext.open(factory); + + await context.updateSyncRules(` +streams: + stream: + query: SELECT id, * FROM "test_data" + +config: + edition: 2 +`); + + const { pool } = context; + await pool.query(`DROP TABLE IF EXISTS test_data`); + await pool.query(`CREATE TYPE composite AS (foo bool, bar int4);`); + await pool.query(`CREATE TABLE test_data(id text primary key, description composite);`); + + await context.initializeReplication(); + await pool.query(`INSERT INTO test_data(id, description) VALUES ('t1', ROW(TRUE, 2)::composite)`); + + const data = await context.getBucketData('1#stream|0[]'); + expect(data).toMatchObject([putOp('test_data', { id: 't1', description: '{"foo":1,"bar":2}' })]); + }); } diff --git a/modules/module-postgres/test/src/wal_stream_utils.ts b/modules/module-postgres/test/src/wal_stream_utils.ts index b94c73c5a..40ea5aaba 100644 --- a/modules/module-postgres/test/src/wal_stream_utils.ts +++ b/modules/module-postgres/test/src/wal_stream_utils.ts @@ -12,6 +12,7 @@ import { import { METRICS_HELPER, test_utils } from '@powersync/service-core-tests'; import * as pgwire from '@powersync/service-jpgwire'; import { clearTestDb, getClientCheckpoint, TEST_CONNECTION_OPTIONS } from './util.js'; +import { CustomTypeRegistry } from '@module/types/registry.js'; export class WalStreamTestContext implements AsyncDisposable { private _walStream?: WalStream; @@ -32,7 +33,7 @@ export class WalStreamTestContext implements AsyncDisposable { options?: { doNotClear?: boolean; walStreamOptions?: Partial } ) { const f = await factory({ doNotClear: options?.doNotClear }); - const connectionManager = new PgManager(TEST_CONNECTION_OPTIONS, {}); + const connectionManager = new PgManager(TEST_CONNECTION_OPTIONS, { registry: new CustomTypeRegistry() }); if (!options?.doNotClear) { await clearTestDb(connectionManager.pool); diff --git a/packages/jpgwire/package.json b/packages/jpgwire/package.json index 0218ead66..245e11780 100644 --- a/packages/jpgwire/package.json +++ b/packages/jpgwire/package.json @@ -15,12 +15,15 @@ "type": "module", "scripts": { "clean": "rm -r ./dist && tsc -b --clean", - "build": "tsc -b" + "build": "tsc -b", + "build:tests": "tsc -b test/tsconfig.json", + "test": "vitest" }, "dependencies": { "@powersync/service-jsonbig": "workspace:^", "@powersync/service-sync-rules": "workspace:^", "date-fns": "^4.1.0", - "pgwire": "github:kagis/pgwire#f1cb95f9a0f42a612bb5a6b67bb2eb793fc5fc87" + "pgwire": "github:kagis/pgwire#f1cb95f9a0f42a612bb5a6b67bb2eb793fc5fc87", + "vitest": "^3.0.5" } } diff --git a/packages/jpgwire/src/index.ts b/packages/jpgwire/src/index.ts index 
0caa6b349..53a02fe44 100644 --- a/packages/jpgwire/src/index.ts +++ b/packages/jpgwire/src/index.ts @@ -3,3 +3,4 @@ export * from './certs.js'; export * from './util.js'; export * from './metrics.js'; export * from './pgwire_types.js'; +export * from './structure_parser.js'; diff --git a/packages/jpgwire/src/pgwire_types.ts b/packages/jpgwire/src/pgwire_types.ts index d71a14d34..555de6e3a 100644 --- a/packages/jpgwire/src/pgwire_types.ts +++ b/packages/jpgwire/src/pgwire_types.ts @@ -1,8 +1,9 @@ // Adapted from https://github.com/kagis/pgwire/blob/0dc927f9f8990a903f238737326e53ba1c8d094f/mod.js#L2218 import { JsonContainer } from '@powersync/service-jsonbig'; -import { TimeValue, type DatabaseInputValue } from '@powersync/service-sync-rules'; +import { CustomSqliteValue, TimeValue, type DatabaseInputValue } from '@powersync/service-sync-rules'; import { dateToSqlite, lsnMakeComparable, timestampToSqlite, timestamptzToSqlite } from './util.js'; +import { StructureParser } from './structure_parser.js'; export enum PgTypeOid { TEXT = 25, @@ -27,7 +28,7 @@ export enum PgTypeOid { // Generate using: // select '[' || typarray || ', ' || oid || '], // ' || typname from pg_catalog.pg_type WHERE typarray != 0; -const ARRAY_TO_ELEM_OID = new Map([ +export const ARRAY_TO_ELEM_OID = new Map([ [1000, 16], // bool [1001, 17], // bytea [1002, 18], // char @@ -141,52 +142,21 @@ export class PgType { case PgTypeOid.PG_LSN: return lsnMakeComparable(text); } - const elemTypeid = this._elemTypeOid(typeOid); + const elemTypeid = this.elemTypeOid(typeOid); if (elemTypeid != null) { return this._decodeArray(text, elemTypeid); } return text; // unknown type } - static _elemTypeOid(arrayTypeOid: number): number | undefined { + static elemTypeOid(arrayTypeOid: number): number | undefined { // select 'case ' || typarray || ': return ' || oid || '; // ' || typname from pg_catalog.pg_type WHERE typarray != 0; return ARRAY_TO_ELEM_OID.get(arrayTypeOid); } - static _decodeArray(text: string, elemTypeOid: number): any { + static _decodeArray(text: string, elemTypeOid: number): DatabaseInputValue[] { text = text.replace(/^\[.+=/, ''); // skip dimensions - let result: any; - for (let i = 0, inQuotes = false, elStart = 0, stack: any[] = []; i < text.length; i++) { - const ch = text.charCodeAt(i); - if (ch == 0x5c /*\*/) { - i++; // escape - } else if (ch == 0x22 /*"*/) { - inQuotes = !inQuotes; - } else if (inQuotes) { - } else if (ch == 0x7b /*{*/) { - // continue - stack.unshift([]), (elStart = i + 1); - } else if (ch == 0x7d /*}*/ || ch == 0x2c /*,*/) { - // TODO configurable delimiter - // TODO ensure .slice is cheap enough to do it unconditionally - const escaped = text.slice(elStart, i); // TODO trim ' \t\n\r\v\f' - if (result) { - stack[0].push(result); - } else if (/^NULL$/i.test(escaped)) { - stack[0].push(null); - } else if (escaped.length) { - const unescaped = escaped.replace(/^"|"$|(? PgType.decode(raw, elemTypeOid)); } static _decodeBytea(text: string): Uint8Array { diff --git a/packages/jpgwire/src/structure_parser.ts b/packages/jpgwire/src/structure_parser.ts new file mode 100644 index 000000000..a9acfeba0 --- /dev/null +++ b/packages/jpgwire/src/structure_parser.ts @@ -0,0 +1,300 @@ +import { delimiter } from 'path'; + +/** + * Utility to parse encoded structural values, such as arrays, composite types, ranges and multiranges. 
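+ * For example, the composite literal `(t,123,"{foo,bar}")` yields the raw fields `t`, `123` and `{foo,bar}`,
+ * which the caller then decodes using the known member types.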
+ */ +export class StructureParser { + private offset: number; + + constructor(readonly source: string) { + this.offset = 0; + } + + private currentCharCode(): number { + return this.source.charCodeAt(this.offset); + } + + private get isAtEnd(): boolean { + return this.offset == this.source.length; + } + + private checkNotAtEnd() { + if (this.isAtEnd) { + this.error('Unexpected end of input'); + } + } + + private error(msg: string): never { + throw new Error(`Error decoding Postgres sequence at position ${this.offset}: ${msg}`); + } + + private check(expected: number) { + if (this.currentCharCode() != expected) { + this.error(`Expected ${String.fromCharCode(expected)}, got ${String.fromCharCode(this.currentCharCode())}`); + } + } + + private peek(): number { + this.checkNotAtEnd(); + + return this.source.charCodeAt(this.offset + 1); + } + + private advance() { + this.checkNotAtEnd(); + this.offset++; + } + + private consume(expected: number) { + this.check(expected); + this.advance(); + } + + private maybeConsume(expected: number): boolean { + if (this.currentCharCode() == expected) { + this.advance(); + return true; + } else { + return false; + } + } + + /** + * Assuming that the current position contains a opening double quote for an escaped string, parses the value until + * the closing quote. + * + * The returned value applies escape characters, so `"foo\"bar"` would return the string `foo"bar"`. + */ + private quotedString(allowEscapingWithDoubleDoubleQuote: boolean = false): string { + this.consume(CHAR_CODE_DOUBLE_QUOTE); + + const start = this.offset; + const charCodes: number[] = []; + let previousWasBackslash = false; + + while (true) { + const char = this.currentCharCode(); + + if (previousWasBackslash) { + if (char != CHAR_CODE_DOUBLE_QUOTE && char != CHAR_CODE_BACKSLASH) { + this.error('Expected escaped double quote or escaped backslash'); + } + charCodes.push(char); + previousWasBackslash = false; + } else if (char == CHAR_CODE_DOUBLE_QUOTE) { + if (this.offset != start && allowEscapingWithDoubleDoubleQuote) { + // If the next character is also a double quote, that escapes a single double quote + if (this.offset < this.source.length - 1 && this.peek() == CHAR_CODE_DOUBLE_QUOTE) { + this.offset += 2; + charCodes.push(CHAR_CODE_DOUBLE_QUOTE); + continue; + } + } + + break; // End of string. + } else if (char == CHAR_CODE_BACKSLASH) { + previousWasBackslash = true; + } else { + charCodes.push(char); + } + + this.advance(); + } + + this.consume(CHAR_CODE_DOUBLE_QUOTE); + return String.fromCharCode(...charCodes); + } + + unquotedString(endedBy: number[], illegal: number[]): string { + const start = this.offset; + this.advance(); + + let next = this.currentCharCode(); + while (endedBy.indexOf(next) == -1) { + if (illegal.indexOf(next) != -1) { + this.error('illegal char, should require escaping'); + } + + this.advance(); + next = this.currentCharCode(); + } + + return this.source.substring(start, this.offset); + } + + checkAtEnd() { + if (this.offset < this.source.length) { + this.error('Unexpected trailing text'); + } + } + + // https://www.postgresql.org/docs/current/arrays.html#ARRAYS-IO + parseArray(parseElement: (value: string) => T, delimiter: number = CHAR_CODE_COMMA): ElementOrArray[] { + const array = this.parseArrayInner(delimiter, parseElement); + this.checkAtEnd(); + return array; + } + + // Recursively parses a (potentially multi-dimensional) array. 
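+ // For example, with an identity element parser, '{a,"b,c",NULL}' parses to ['a', 'b,c', null];
+ // a nested '{...}' recurses into a sub-array.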
+ private parseArrayInner(delimiter: number, parseElement: (value: string) => T): ElementOrArray[] { + this.consume(CHAR_CODE_LEFT_BRACE); + if (this.maybeConsume(CHAR_CODE_RIGHT_BRACE)) { + return []; // Empty array ({}) + } + + const elements: ElementOrArray[] = []; + do { + // Parse a value in the array. This can either be an escaped string, an unescaped string, or a nested array. + const currentChar = this.currentCharCode(); + if (currentChar == CHAR_CODE_LEFT_BRACE) { + // Nested array + elements.push(this.parseArrayInner(delimiter, parseElement)); + } else if (currentChar == CHAR_CODE_DOUBLE_QUOTE) { + elements.push(parseElement(this.quotedString())); + } else { + const value = this.unquotedString( + [delimiter, CHAR_CODE_RIGHT_BRACE], + [CHAR_CODE_DOUBLE_QUOTE, CHAR_CODE_LEFT_BRACE] + ); + elements.push(value == 'NULL' ? null : parseElement(value)); + } + } while (this.maybeConsume(delimiter)); + + this.consume(CHAR_CODE_RIGHT_BRACE); + return elements; + } + + // https://www.postgresql.org/docs/current/rowtypes.html#ROWTYPES-IO-SYNTAX + parseComposite(onElement: (value: string | null) => void) { + this.consume(CHAR_CODE_LEFT_PAREN); + do { + // Parse a composite value. This can either be an escaped string, an unescaped string, or an empty string. + const currentChar = this.currentCharCode(); + if (currentChar == CHAR_CODE_COMMA) { + // Empty value. The comma is consumed by the while() below. + onElement(null); + } else if (currentChar == CHAR_CODE_RIGHT_PAREN) { + // Empty value before end. The right parent is consumed by the line after the loop. + onElement(null); + } else if (currentChar == CHAR_CODE_DOUBLE_QUOTE) { + onElement(this.quotedString(true)); + } else { + const value = this.unquotedString( + [CHAR_CODE_COMMA, CHAR_CODE_RIGHT_PAREN], + [CHAR_CODE_DOUBLE_QUOTE, CHAR_CODE_LEFT_PAREN] + ); + onElement(value); + } + } while (this.maybeConsume(CHAR_CODE_COMMA)); + this.consume(CHAR_CODE_RIGHT_PAREN); + this.checkAtEnd(); + } + + // https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-IO + private parseRangeInner(parseInner: (value: string) => T): Range { + const empty = 'empty'; + + // Parse [ or ( to start the range + let lowerBoundExclusive; + switch (this.currentCharCode()) { + case CHAR_CODE_LEFT_PAREN: + lowerBoundExclusive = true; + this.advance(); + break; + case CHAR_CODE_LEFT_BRACKET: + lowerBoundExclusive = false; + this.advance(); + break; + case empty.charCodeAt(0): + // Consume the string "empty" + for (let i = 0; i < empty.length; i++) { + this.consume(empty.charCodeAt(i)); + } + return empty; + default: + this.error('Expected [, ( or string empty'); + } + + // Parse value until comma (which may be empty) + let lower = null; + if (this.currentCharCode() == CHAR_CODE_DOUBLE_QUOTE) { + lower = parseInner(this.quotedString()); + } else if (this.currentCharCode() != CHAR_CODE_COMMA) { + lower = parseInner(this.unquotedString([CHAR_CODE_COMMA], [])); + } + + this.consume(CHAR_CODE_COMMA); + + let upper = null; + if (this.currentCharCode() == CHAR_CODE_DOUBLE_QUOTE) { + upper = parseInner(this.quotedString()); + } else if (this.currentCharCode() != CHAR_CODE_RIGHT_PAREN && this.currentCharCode() != CHAR_CODE_RIGHT_BRACKET) { + upper = parseInner(this.unquotedString([CHAR_CODE_RIGHT_PAREN, CHAR_CODE_RIGHT_BRACKET], [])); + } + + let upperBoundExclusive; + switch (this.currentCharCode()) { + case CHAR_CODE_RIGHT_PAREN: + upperBoundExclusive = true; + this.advance(); + break; + case CHAR_CODE_RIGHT_BRACKET: + upperBoundExclusive = false; + this.advance(); 
+        break;
+      default:
+        this.error('Expected ] or )');
+    }
+
+    return {
+      lower: lower,
+      upper: upper,
+      lower_exclusive: lowerBoundExclusive,
+      upper_exclusive: upperBoundExclusive
+    };
+  }
+
+  parseRange<T>(parseInner: (value: string) => T): Range<T> {
+    const range = this.parseRangeInner(parseInner);
+    this.checkAtEnd();
+    return range;
+  }
+
+  parseMultiRange<T>(parseInner: (value: string) => T): Range<T>[] {
+    this.consume(CHAR_CODE_LEFT_BRACE);
+    if (this.maybeConsume(CHAR_CODE_RIGHT_BRACE)) {
+      return [];
+    }
+
+    const values: Range<T>[] = [];
+    do {
+      values.push(this.parseRangeInner(parseInner));
+    } while (this.maybeConsume(CHAR_CODE_COMMA));
+
+    this.consume(CHAR_CODE_RIGHT_BRACE);
+    this.checkAtEnd();
+    return values;
+  }
+}
+
+export type Range<T> =
+  | {
+      lower: T | null;
+      upper: T | null;
+      lower_exclusive: boolean;
+      upper_exclusive: boolean;
+    }
+  | 'empty';
+
+export type ElementOrArray<T> = null | T | ElementOrArray<T>[];
+
+const CHAR_CODE_DOUBLE_QUOTE = 0x22;
+const CHAR_CODE_BACKSLASH = 0x5c;
+export const CHAR_CODE_COMMA = 0x2c;
+export const CHAR_CODE_SEMICOLON = 0x3b;
+const CHAR_CODE_LEFT_BRACE = 0x7b;
+const CHAR_CODE_RIGHT_BRACE = 0x7d;
+const CHAR_CODE_LEFT_PAREN = 0x28;
+const CHAR_CODE_RIGHT_PAREN = 0x29;
+const CHAR_CODE_LEFT_BRACKET = 0x5b;
+const CHAR_CODE_RIGHT_BRACKET = 0x5d;
diff --git a/packages/jpgwire/test/structure_parser.test.ts b/packages/jpgwire/test/structure_parser.test.ts
new file mode 100644
index 000000000..d8dbbfc41
--- /dev/null
+++ b/packages/jpgwire/test/structure_parser.test.ts
@@ -0,0 +1,203 @@
+import { describe, expect, test } from 'vitest';
+import { CHAR_CODE_COMMA, CHAR_CODE_SEMICOLON, StructureParser } from '../src/index';
+
+describe('StructureParser', () => {
+  describe('array', () => {
+    const parseArray = (source: string, delimiter: number = CHAR_CODE_COMMA) => {
+      return new StructureParser(source).parseArray((source) => source, delimiter);
+    };
+
+    test('empty', () => {
+      expect(parseArray('{}')).toStrictEqual([]);
+    });
+
+    test('regular', () => {
+      expect(parseArray('{foo,bar}')).toStrictEqual(['foo', 'bar']);
+    });
+
+    test('custom delimiter', () => {
+      expect(parseArray('{foo;bar}', CHAR_CODE_SEMICOLON)).toStrictEqual(['foo', 'bar']);
+    });
+
+    test('null elements', () => {
+      expect(parseArray('{null}')).toStrictEqual(['null']);
+      expect(parseArray('{NULL}')).toStrictEqual([null]);
+    });
+
+    test('escaped', () => {
+      expect(parseArray('{""}')).toStrictEqual(['']);
+      expect(parseArray('{"foo"}')).toStrictEqual(['foo']);
+      expect(parseArray('{"fo\\"o"}')).toStrictEqual(['fo"o']);
+      expect(parseArray('{"fo\\\\o"}')).toStrictEqual(['fo\\o']);
+    });
+
+    test('nested', () => {
+      expect(parseArray('{0,{0,{}}}')).toStrictEqual(['0', ['0', []]]);
+    });
+
+    test('trailing data', () => {
+      expect(() => parseArray('{foo}bar')).toThrow(/Unexpected trailing text/);
+    });
+
+    test('unclosed array', () => {
+      expect(() => parseArray('{')).toThrow(/Unexpected end of input/);
+    });
+
+    test('improper escaped string', () => {
+      expect(() => parseArray('{foo,"bar}')).toThrow(/Unexpected end of input/);
+    });
+
+    test('illegal escape sequence', () => {
+      expect(() => parseArray('{foo,"b\\ar"}')).toThrow(/Expected escaped double quote or escaped backslash/);
+    });
+
+    test('illegal delimiter in value', () => {
+      expect(() => parseArray('{foo{}')).toThrow(/illegal char, should require escaping/);
+    });
+
+    test('illegal quote in value', () => {
+      expect(() => parseArray('{foo"}')).toThrow(/illegal char, should require escaping/);
+    });
+  });
+
+  describe('composite', () => {
+    const parseComposite = (source: string) => {
+      const events: any[] = [];
+      new StructureParser(source).parseComposite((e) => events.push(e));
+      return events;
+    };
+
+    test('empty composite', () => {
+      // Both of the following render as '()':
+      // create type foo as (); select ROW()::foo; create type foo2 as (foo integer);
+      // SELECT ROW()::foo, ROW(NULL)::foo2;
+      // Here, we resolve the ambiguity by parsing () as a one-element composite - callers need to be aware of this.
+      expect(parseComposite('()')).toStrictEqual([null]);
+    });
+
+    test('only null entries', () => {
+      expect(parseComposite('(,)')).toStrictEqual([null, null]);
+      expect(parseComposite('(,,)')).toStrictEqual([null, null, null]);
+    });
+
+    test('null before element', () => {
+      expect(parseComposite('(,foo)')).toStrictEqual([null, 'foo']);
+    });
+
+    test('null after element', () => {
+      expect(parseComposite('(foo,)')).toStrictEqual(['foo', null]);
+    });
+
+    test('nested', () => {
+      expect(parseComposite('(foo,bar,{baz})')).toStrictEqual(['foo', 'bar', '{baz}']);
+    });
+
+    test('escaped strings', () => {
+      expect(parseComposite('("foo""bar")')).toStrictEqual(['foo"bar']);
+      expect(parseComposite('("")')).toStrictEqual(['']);
+    });
+  });
+
+  describe('range', () => {
+    const parseIntRange = (source: string) => {
+      return new StructureParser(source).parseRange((source) => Number(source));
+    };
+
+    test('empty', () => {
+      // select '(3, 3)'::int4range
+      expect(parseIntRange('empty')).toStrictEqual('empty');
+    });
+
+    test('regular', () => {
+      expect(parseIntRange('[1,2]')).toStrictEqual({
+        lower: 1,
+        upper: 2,
+        lower_exclusive: false,
+        upper_exclusive: false
+      });
+      expect(parseIntRange('[1,2)')).toStrictEqual({
+        lower: 1,
+        upper: 2,
+        lower_exclusive: false,
+        upper_exclusive: true
+      });
+      expect(parseIntRange('(1,2]')).toStrictEqual({
+        lower: 1,
+        upper: 2,
+        lower_exclusive: true,
+        upper_exclusive: false
+      });
+      expect(parseIntRange('(1,2)')).toStrictEqual({
+        lower: 1,
+        upper: 2,
+        lower_exclusive: true,
+        upper_exclusive: true
+      });
+    });
+
+    test('no lower bound', () => {
+      expect(parseIntRange('(,3]')).toStrictEqual({
+        lower: null,
+        upper: 3,
+        lower_exclusive: true,
+        upper_exclusive: false
+      });
+    });
+
+    test('no upper bound', () => {
+      expect(parseIntRange('(3,]')).toStrictEqual({
+        lower: 3,
+        upper: null,
+        lower_exclusive: true,
+        upper_exclusive: false
+      });
+    });
+
+    test('no bounds', () => {
+      expect(parseIntRange('(,)')).toStrictEqual({
+        lower: null,
+        upper: null,
+        lower_exclusive: true,
+        upper_exclusive: true
+      });
+    });
+  });
+
+  describe('multirange', () => {
+    const parseIntMultiRange = (source: string) => {
+      return new StructureParser(source).parseMultiRange((source) => Number(source));
+    };
+
+    test('empty', () => {
+      expect(parseIntMultiRange('{}')).toStrictEqual([]);
+    });
+
+    test('single', () => {
+      expect(parseIntMultiRange('{[3,7)}')).toStrictEqual([
+        {
+          lower: 3,
+          upper: 7,
+          lower_exclusive: false,
+          upper_exclusive: true
+        }
+      ]);
+    });
+
+    test('multiple', () => {
+      expect(parseIntMultiRange('{[3,7),[8,9)}')).toStrictEqual([
+        {
+          lower: 3,
+          upper: 7,
+          lower_exclusive: false,
+          upper_exclusive: true
+        },
+        {
+          lower: 8,
+          upper: 9,
+          lower_exclusive: false,
+          upper_exclusive: true
+        }
+      ]);
+    });
+  });
+});
diff --git a/packages/jpgwire/test/tsconfig.json b/packages/jpgwire/test/tsconfig.json
new file mode 100644
index 000000000..5b5f74483
--- /dev/null
+++ b/packages/jpgwire/test/tsconfig.json
@@ -0,0 +1,17 @@
+{
"../../../tsconfig.base.json", + "compilerOptions": { + "rootDir": "src", + "noEmit": true, + "baseUrl": "./", + "esModuleInterop": true, + "skipLibCheck": true, + "sourceMap": true + }, + "include": ["src"], + "references": [ + { + "path": "../" + } + ] +} diff --git a/packages/sync-rules/src/ExpressionType.ts b/packages/sync-rules/src/ExpressionType.ts index 3e625c717..9e120cb36 100644 --- a/packages/sync-rules/src/ExpressionType.ts +++ b/packages/sync-rules/src/ExpressionType.ts @@ -77,7 +77,7 @@ export class ExpressionType { } /** - * Here only for backwards-compatibility only. + * @deprecated Here only for backwards-compatibility only. */ export function expressionTypeFromPostgresType(type: string | undefined): ExpressionType { if (type?.endsWith('[]')) { diff --git a/packages/sync-rules/src/compatibility.ts b/packages/sync-rules/src/compatibility.ts index 091365b55..24ba9738b 100644 --- a/packages/sync-rules/src/compatibility.ts +++ b/packages/sync-rules/src/compatibility.ts @@ -34,10 +34,17 @@ export class CompatibilityOption { CompatibilityEdition.SYNC_STREAMS ); + static customTypes = new CompatibilityOption( + 'custom_postgres_types', + 'Map custom Postgres types into appropriate structures instead of syncing the raw string.', + CompatibilityEdition.SYNC_STREAMS + ); + static byName: Record = Object.freeze({ timestamps_iso8601: this.timestampsIso8601, versioned_bucket_ids: this.versionedBucketIds, - fixed_json_extract: this.fixedJsonExtract + fixed_json_extract: this.fixedJsonExtract, + custom_postgres_types: this.customTypes }); } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index d24c9a0e9..3541c5456 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -428,6 +428,9 @@ importers: pgwire: specifier: github:kagis/pgwire#f1cb95f9a0f42a612bb5a6b67bb2eb793fc5fc87 version: https://codeload.github.com/kagis/pgwire/tar.gz/f1cb95f9a0f42a612bb5a6b67bb2eb793fc5fc87 + vitest: + specifier: ^3.0.5 + version: 3.0.5(@types/node@22.16.2)(yaml@2.5.0) packages/jsonbig: dependencies: