diff --git a/cipher b/cipher new file mode 100644 index 0000000..356b2f8 Binary files /dev/null and b/cipher differ diff --git a/src/auth.ts b/src/auth.ts index fece8b7..63a86fe 100644 --- a/src/auth.ts +++ b/src/auth.ts @@ -1,10 +1,5 @@ import { createHash, encode, SupportedAlgorithm } from "../deps.ts"; - -function xor(a: Uint8Array, b: Uint8Array): Uint8Array { - return a.map((byte, index) => { - return byte ^ b[index]; - }); -} +import { xor } from "./util.ts"; function hash(algorithm: SupportedAlgorithm, data: Uint8Array): Uint8Array { return new Uint8Array(createHash(algorithm).update(data).digest()); @@ -39,8 +34,7 @@ export default function auth( return mysqlNativePassword(password, seed); case "caching_sha2_password": - // TODO - // return cachingSha2Password(password, seed); + return cachingSha2Password(password, seed); default: throw new Error("Not supported"); } diff --git a/src/auth_plugin/caching_sha2_password.ts b/src/auth_plugin/caching_sha2_password.ts new file mode 100644 index 0000000..c3ed0df --- /dev/null +++ b/src/auth_plugin/caching_sha2_password.ts @@ -0,0 +1,69 @@ +import { xor } from "../util.ts"; +import { ReceivePacket } from "../packets/packet.ts"; +import { encryptWithPublicKey } from "./crypt.ts"; + +interface handler { + done: boolean; + quickRead?: boolean; + next?: (packet: ReceivePacket) => any; + data?: Uint8Array; +} + +let scramble: Uint8Array, password: string; +function start(scramble_: Uint8Array, password_: string): handler { + scramble = scramble_; + password = password_; + return { done: false, next: authMoreResponse }; +} +function authMoreResponse(packet: ReceivePacket): handler { + const enum AuthStatusFlags { + FullAuth = 0x04, + FastPath = 0x03, + } + const REQUEST_PUBLIC_KEY = 0x02; + const statusFlag = packet.body.skip(1).readUint8(); + let authMoreData, done = true, next, quickRead = false; + if (statusFlag === AuthStatusFlags.FullAuth) { + authMoreData = new Uint8Array([REQUEST_PUBLIC_KEY]); + done = false; + next = encryptWithKey; + } + if (statusFlag === AuthStatusFlags.FastPath) { + done = false; + quickRead = true; + next = terminate; + } + return { done, next, quickRead, data: authMoreData }; +} + +function encryptWithKey(packet: ReceivePacket): handler { + const publicKey = parsePublicKey(packet); + const len = password.length; + let passwordBuffer: Uint8Array = new Uint8Array(len + 1); + for (let n = 0; n < len; n++) { + passwordBuffer[n] = password.charCodeAt(n); + } + passwordBuffer[len] = 0x00; + + const encryptedPassword = encrypt(passwordBuffer, scramble, publicKey); + return { done: false, next: terminate, data: encryptedPassword }; +} + +function parsePublicKey(packet: ReceivePacket): string { + return packet.body.skip(1).readNullTerminatedString(); +} +function encrypt( + password: Uint8Array, + scramble: Uint8Array, + key: string, +): Uint8Array { + const stage1 = xor(password, scramble); + const encrypted = encryptWithPublicKey(key, stage1); + return encrypted; +} + +function terminate() { + return { done: true }; +} + +export { start }; diff --git a/src/auth_plugin/crypt.ts b/src/auth_plugin/crypt.ts new file mode 100644 index 0000000..6e12394 --- /dev/null +++ b/src/auth_plugin/crypt.ts @@ -0,0 +1,7 @@ +import { RSA } from "https://deno.land/x/god_crypto@v0.2.0/mod.ts"; +function encryptWithPublicKey(key: string, data: Uint8Array): Uint8Array { + const publicKey = RSA.parseKey(key); + return RSA.encrypt(data, publicKey); +} + +export { encryptWithPublicKey }; diff --git a/src/auth_plugin/index.ts b/src/auth_plugin/index.ts new file mode 100644 index 0000000..198e023 --- /dev/null +++ b/src/auth_plugin/index.ts @@ -0,0 +1,4 @@ +import * as caching_sha2_password from "./caching_sha2_password.ts"; +export default { + caching_sha2_password, +}; diff --git a/src/connection.ts b/src/connection.ts index 0acc301..8e36347 100644 --- a/src/connection.ts +++ b/src/connection.ts @@ -1,4 +1,4 @@ -import { delay } from "../deps.ts"; +import { byteFormat, delay } from "../deps.ts"; import { ClientConfig } from "./client.ts"; import { ConnnectionError, @@ -11,8 +11,14 @@ import { buildAuth } from "./packets/builders/auth.ts"; import { buildQuery } from "./packets/builders/query.ts"; import { ReceivePacket, SendPacket } from "./packets/packet.ts"; import { parseError } from "./packets/parsers/err.ts"; -import { parseHandshake } from "./packets/parsers/handshake.ts"; +import { + AuthResult, + parseAuth, + parseHandshake, +} from "./packets/parsers/handshake.ts"; import { FieldInfo, parseField, parseRow } from "./packets/parsers/result.ts"; +import { PacketType } from "./constant/packet.ts"; +import authPlugin from "./auth_plugin/index.ts"; /** * Connection state @@ -53,7 +59,8 @@ export class Connection { private async _connect() { // TODO: implement connect timeout - const { hostname, port = 3306, socketPath } = this.config; + const { hostname, port = 3306, socketPath, username = "", password } = + this.config; log.info(`connecting ${this.remoteAddr}`); this.conn = !socketPath ? await Deno.connect({ @@ -70,16 +77,51 @@ export class Connection { let receive = await this.nextPacket(); const handshakePacket = parseHandshake(receive.body); const data = buildAuth(handshakePacket, { - username: this.config.username ?? "", - password: this.config.password, + username, + password, db: this.config.db, }); + await new SendPacket(data, 0x1).send(this.conn); + this.state = ConnectionState.CONNECTING; this.serverVersion = handshakePacket.serverVersion; this.capabilities = handshakePacket.serverCapabilities; receive = await this.nextPacket(); + + const authResult = parseAuth(receive); + let handler; + + switch (authResult) { + case AuthResult.AuthMoreRequired: + const adaptedPlugin = + (authPlugin as any)[handshakePacket.authPluginName]; + handler = adaptedPlugin; + break; + case AuthResult.MethodMismatch: + // TODO: Negotiate + throw new Error("Currently cannot support auth method mismatch!"); + } + + let result; + if (handler) { + result = handler.start(handshakePacket.seed, password!); + while (!result.done) { + if (result.data) { + const sequenceNumber = receive.header.no + 1; + await new SendPacket(result.data, sequenceNumber).send(this.conn); + receive = await this.nextPacket(); + } + if (result.quickRead) { + await this.nextPacket(); + } + if (result.next) { + result = result.next(receive); + } + } + } + const header = receive.body.readUint8(); if (header === 0xff) { const error = parseError(receive.body, this); @@ -137,7 +179,7 @@ export class Connection { this.close(); throw new ReadError("Connection closed unexpectedly"); } - if (packet.type === "ERR") { + if (packet.type === PacketType.ERR_Packet) { packet.body.skip(1); const error = parseError(packet.body, this); throw new Error(error.message); @@ -212,13 +254,13 @@ export class Connection { try { await new SendPacket(data, 0).send(this.conn!); let receive = await this.nextPacket(); - if (receive.type === "OK") { + if (receive.type === PacketType.OK_Packet) { receive.body.skip(1); return { affectedRows: receive.body.readEncodedLen(), lastInsertId: receive.body.readEncodedLen(), }; - } else if (receive.type !== "RESULT") { + } else if (receive.type !== PacketType.Result) { throw new ProtocolError(); } let fieldCount = receive.body.readEncodedLen(); @@ -235,14 +277,14 @@ export class Connection { if (this.lessThan57()) { // EOF(less than 5.7) receive = await this.nextPacket(); - if (receive.type !== "EOF") { + if (receive.type !== PacketType.EOF_Packet) { throw new ProtocolError(); } } while (true) { receive = await this.nextPacket(); - if (receive.type === "EOF") { + if (receive.type === PacketType.EOF_Packet) { break; } else { const row = parseRow(receive.body, fields); diff --git a/src/constant/capabilities.ts b/src/constant/capabilities.ts index 12ab767..6477e1a 100644 --- a/src/constant/capabilities.ts +++ b/src/constant/capabilities.ts @@ -11,6 +11,10 @@ enum ServerCapabilities { CLIENT_SECURE_CONNECTION = 0x8000, CLIENT_FOUND_ROWS = 0x00000002, CLIENT_CONNECT_ATTRS = 0x00100000, + CLIENT_IGNORE_SPACE = 0x00000100, + CLIENT_IGNORE_SIGPIPE = 0x00001000, + CLIENT_RESERVED = 0x00004000, + CLIENT_PS_MULTI_RESULTS = 0x00040000, } export default ServerCapabilities; diff --git a/src/constant/packet.ts b/src/constant/packet.ts new file mode 100644 index 0000000..715e411 --- /dev/null +++ b/src/constant/packet.ts @@ -0,0 +1,6 @@ +export enum PacketType { + OK_Packet = 0x00, + EOF_Packet = 0xfe, + ERR_Packet = 0xff, + Result, +} diff --git a/src/packets/packet.ts b/src/packets/packet.ts index 6d42a77..b0f7c7d 100644 --- a/src/packets/packet.ts +++ b/src/packets/packet.ts @@ -2,6 +2,7 @@ import { byteFormat } from "../../deps.ts"; import { BufferReader, BufferWriter } from "../buffer.ts"; import { WriteError } from "../constant/errors.ts"; import { debug, log } from "../logger.ts"; +import { PacketType } from "../../src/constant/packet.ts"; /** @ignore */ interface PacketHeader { @@ -39,7 +40,7 @@ export class SendPacket { export class ReceivePacket { header!: PacketHeader; body!: BufferReader; - type!: "EOF" | "OK" | "ERR" | "RESULT"; + type!: PacketType; async parse(reader: Deno.Reader): Promise { const header = new BufferReader(new Uint8Array(4)); @@ -57,18 +58,19 @@ export class ReceivePacket { if (nread === null) return null; readCount += nread; + const { OK_Packet, ERR_Packet, EOF_Packet, Result } = PacketType; switch (this.body.buffer[0]) { - case 0x00: - this.type = "OK"; + case OK_Packet: + this.type = OK_Packet; break; case 0xff: - this.type = "ERR"; + this.type = ERR_Packet; break; case 0xfe: - this.type = "EOF"; + this.type = EOF_Packet; break; default: - this.type = "RESULT"; + this.type = Result; break; } diff --git a/src/packets/parsers/handshake.ts b/src/packets/parsers/handshake.ts index 5d4663b..959e028 100644 --- a/src/packets/parsers/handshake.ts +++ b/src/packets/parsers/handshake.ts @@ -1,5 +1,7 @@ import { BufferReader, BufferWriter } from "../../buffer.ts"; import ServerCapabilities from "../../constant/capabilities.ts"; +import { PacketType } from "../../constant/packet.ts"; +import { ReceivePacket } from "../packet.ts"; /** @ignore */ export interface HandshakeBody { @@ -65,3 +67,21 @@ export function parseHandshake(reader: BufferReader): HandshakeBody { authPluginName, }; } + +export enum AuthResult { + AuthPassed, + MethodMismatch, + AuthMoreRequired, +} +export function parseAuth(packet: ReceivePacket): AuthResult { + switch (packet.type) { + case PacketType.EOF_Packet: + return AuthResult.MethodMismatch; + case PacketType.Result: + return AuthResult.AuthMoreRequired; + case PacketType.OK_Packet: + return AuthResult.AuthPassed; + default: + return AuthResult.AuthPassed; + } +} diff --git a/src/util.ts b/src/util.ts new file mode 100644 index 0000000..1e4efd9 --- /dev/null +++ b/src/util.ts @@ -0,0 +1,5 @@ +export function xor(a: Uint8Array, b: Uint8Array): Uint8Array { + return a.map((byte, index) => { + return byte ^ b[index]; + }); +}