Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement necessary exported types for compatibility with Teleport #5

Open
wants to merge 1 commit into
base: teleport.1
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
461 changes: 238 additions & 223 deletions azuread/configuration.go

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion azuread/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"context"
"database/sql"
"database/sql/driver"

mssql "github.com/microsoft/go-mssqldb"
"github.com/microsoft/go-mssqldb/msdsn"
)

// DriverName is the name used to register the driver
Expand Down Expand Up @@ -43,6 +43,16 @@ func NewConnector(dsn string) (*mssql.Connector, error) {
return newConnectorConfig(config)
}

// NewConnectorFromConfig returns a new connector with the provided configuration and additional parameters
func NewConnectorFromConfig(dsnConfig msdsn.Config, params map[string]string) (*mssql.Connector, error) {
config, err := newConfig(dsnConfig, params)
if err != nil {
return nil, err
}

return newConnectorConfig(config)
}

// newConnectorConfig creates a Connector from config.
func newConnectorConfig(config *azureFedAuthConfig) (*mssql.Connector, error) {
switch config.fedAuthLibrary {
Expand Down
10 changes: 10 additions & 0 deletions buf.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ var bufpool = sync.Pool{
},
}

type TDSBuffer = tdsBuffer

// tdsBuffer reads and writes TDS packets of data to the transport.
// The write and read buffers are separate to make sending attn signals
// possible without locks. Currently attn signals are only sent during
Expand Down Expand Up @@ -59,6 +61,14 @@ type tdsBuffer struct {
afterFirst func()
}

// NewTdsBuffer returns an exported version of *tdsBuffer
func NewTdsBuffer(buff []byte, size int) *TDSBuffer {
return &tdsBuffer{
rbuf: buff,
rsize: size,
}
}

func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer {

// pull an existing buf if one is available or get and add a new buf to the bufpool
Expand Down
42 changes: 42 additions & 0 deletions error.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package mssql

import (
"bytes"
"database/sql/driver"
"encoding/binary"
"fmt"
)

Expand All @@ -23,6 +25,46 @@ type Error struct {
All []Error
}

// Marshal marshals the error to the wire protocol token.
func (e *Error) Marshal() ([]byte, error) {
buf := bytes.NewBuffer([]byte{
byte(tokenError),
})
length := 2 + // length
4 + // number
1 + // state
1 + // class
(2 + 2*len(e.Message)) + // message
(1 + 2*len(e.ServerName)) + // server name
(1 + 2*len(e.ProcName)) + // proc name
4 // line no
if err := binary.Write(buf, binary.LittleEndian, uint16(length)); err != nil {
return nil, err
}
if err := binary.Write(buf, binary.LittleEndian, e.Number); err != nil {
return nil, err
}
if err := buf.WriteByte(e.State); err != nil {
return nil, err
}
if err := buf.WriteByte(e.Class); err != nil {
return nil, err
}
if err := writeUsVarChar(buf, e.Message); err != nil {
return nil, err
}
if err := writeBVarChar(buf, e.ServerName); err != nil {
return nil, err
}
if err := writeBVarChar(buf, e.ProcName); err != nil {
return nil, err
}
if err := binary.Write(buf, binary.LittleEndian, e.LineNo); err != nil {
return nil, err
}
return buf.Bytes(), nil
}

func (e Error) Error() string {
return "mssql: " + e.Message
}
Expand Down
65 changes: 44 additions & 21 deletions mssql.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import (

// ReturnStatus may be used to return the return value from a proc.
//
// var rs mssql.ReturnStatus
// _, err := db.Exec("theproc", &rs)
// log.Printf("return status = %d", rs)
// var rs mssql.ReturnStatus
// _, err := db.Exec("theproc", &rs)
// log.Printf("return status = %d", rs)
type ReturnStatus int32

var driverInstance = &Driver{processQueryText: true}
Expand Down Expand Up @@ -150,6 +150,12 @@ func NewConnectorConfig(config msdsn.Config) *Connector {
}
}

type auth interface {
InitialBytes() ([]byte, error)
NextBytes([]byte) ([]byte, error)
Free()
}

// Connector holds the parsed DSN and is ready to make a new connection
// at any time.
//
Expand All @@ -169,6 +175,9 @@ type Connector struct {
// callback that can provide a security token during ADAL login
adalTokenProvider func(ctx context.Context, serverSPN, stsURL string) (string, error)

// auth allows to provide a custom authenticator.
auth auth

// SessionInitSQL is executed after marking a given session to be reset.
// When not present, the next query will still reset the session to the
// database defaults.
Expand Down Expand Up @@ -231,6 +240,16 @@ func (c *Conn) IsValid() bool {
return c.connectionGood
}

// GetUnderlyingConn returns underlying raw server connection.
func (c *Conn) GetUnderlyingConn() io.ReadWriteCloser {
return c.sess.buf.transport
}

// GetLoginFlags returns tokens returned by server during login handshake.
func (c *Conn) GetLoginFlags() []Token {
return c.sess.loginFlags
}

// checkBadConn marks the connection as bad based on the characteristics
// of the supplied error. Bad connections will be dropped from the connection
// pool rather than reused.
Expand Down Expand Up @@ -878,22 +897,24 @@ func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
// not a variable length type ok should return false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
//
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
return makeGoLangTypeLength(r.cols[index].ti)
}

// It should return
// the precision and scale for decimal types. If not applicable, ok should be false.
// The following are examples of returned values for various types:
// decimal(38, 4) (38, 4, true)
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
//
// decimal(38, 4) (38, 4, true)
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
return makeGoLangTypePrecisionScale(r.cols[index].ti)
}
Expand Down Expand Up @@ -1320,22 +1341,24 @@ func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string {
// not a variable length type ok should return false.
// If length is not limited other than system limits, it should return math.MaxInt64.
// The following are examples of returned values for various types:
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
//
// TEXT (math.MaxInt64, true)
// varchar(10) (10, true)
// nvarchar(10) (10, true)
// decimal (0, false)
// int (0, false)
// bytea(30) (30, true)
func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) {
return makeGoLangTypeLength(r.cols[index].ti)
}

// It should return
// the precision and scale for decimal types. If not applicable, ok should be false.
// The following are examples of returned values for various types:
// decimal(38, 4) (38, 4, true)
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
//
// decimal(38, 4) (38, 4, true)
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
func (r *Rowsq) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
return makeGoLangTypePrecisionScale(r.cols[index].ti)
}
Expand Down
28 changes: 26 additions & 2 deletions tds.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ type tdsSession struct {
logger ContextLogger
routedServer string
routedPort uint16
loginFlags []Token
}

const (
Expand All @@ -168,9 +169,23 @@ func (p keySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }

// http://msdn.microsoft.com/en-us/library/dd357559.aspx
func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte) error {
w.BeginPacket(packetType, false)
if err := WritePreLoginFields(w, fields); err != nil {
return err
}
return w.FinishPacket()
}

// Writer is an interface that combines Writer and ByteWriter.
type Writer interface {
io.Writer
io.ByteWriter
}

// WritePreLoginFields writes provided Pre-Login packet fields into the writer.
func WritePreLoginFields(w Writer, fields map[uint8][]byte) error {
var err error

w.BeginPacket(packetType, false)
offset := uint16(5*len(fields) + 1)
keys := make(keySlice, 0, len(fields))
for k := range fields {
Expand Down Expand Up @@ -210,7 +225,7 @@ func writePrelogin(packetType packetType, w *tdsBuffer, fields map[uint8][]byte)
return errors.New("Write method didn't write the whole value")
}
}
return w.FinishPacket()
return nil
}

func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
Expand Down Expand Up @@ -1195,6 +1210,15 @@ initiate_connection:
break
}

// Save options returned by the server so callers implementing
// proxies can pass them back to the original client.
switch tok.(type) {
case envChangeStruct, loginAckStruct, doneStruct:
if token, ok := tok.(Token); ok {
sess.loginFlags = append(sess.loginFlags, token)
}
}

switch token := tok.(type) {
case sspiMsg:
sspi_msg, err := auth.NextBytes(token)
Expand Down
Loading