forgejo/vendor/github.com/denisenkom/go-mssqldb/buf.go
btrepp 25b5ffb6af Enables mssql support ()
* Enables mssql support

Port of dlobs work in gogs.
Enables options in index.js
Enables MSSQL as a database option in go.
Sets ID to 0 on initial migration. Required for
MSSQL insert statements.

Signed-off-by: Beau Trepp <beautrepp@gmail.com>

* Vendors in denisenkom/go-mssqldb

Includes golang.org/x/crypto/md4
as this is required by go-msssqldb

Signed-off-by: Beau Trepp <beautrepp@gmail.com>
2016-12-24 09:37:35 +08:00

222 lines
4 KiB
Go

package mssql
import (
"encoding/binary"
"io"
"errors"
)
type header struct {
PacketType uint8
Status uint8
Size uint16
Spid uint16
PacketNo uint8
Pad uint8
}
type tdsBuffer struct {
buf []byte
pos uint16
transport io.ReadWriteCloser
size uint16
final bool
packet_type uint8
afterFirst func()
}
func newTdsBuffer(bufsize int, transport io.ReadWriteCloser) *tdsBuffer {
buf := make([]byte, bufsize)
w := new(tdsBuffer)
w.buf = buf
w.pos = 8
w.transport = transport
w.size = 0
return w
}
func (w *tdsBuffer) flush() (err error) {
// writing packet size
binary.BigEndian.PutUint16(w.buf[2:], w.pos)
// writing packet into underlying transport
if _, err = w.transport.Write(w.buf[:w.pos]); err != nil {
return err
}
// execute afterFirst hook if it is set
if w.afterFirst != nil {
w.afterFirst()
w.afterFirst = nil
}
w.pos = 8
// packet number
w.buf[6] += 1
return nil
}
func (w *tdsBuffer) Write(p []byte) (total int, err error) {
total = 0
for {
copied := copy(w.buf[w.pos:], p)
w.pos += uint16(copied)
total += copied
if copied == len(p) {
break
}
if err = w.flush(); err != nil {
return
}
p = p[copied:]
}
return
}
func (w *tdsBuffer) WriteByte(b byte) error {
if int(w.pos) == len(w.buf) {
if err := w.flush(); err != nil {
return err
}
}
w.buf[w.pos] = b
w.pos += 1
return nil
}
func (w *tdsBuffer) BeginPacket(packet_type byte) {
w.buf[0] = packet_type
w.buf[1] = 0 // packet is incomplete
w.buf[4] = 0 // spid
w.buf[5] = 0
w.buf[6] = 1 // packet id
w.buf[7] = 0 // window
w.pos = 8
}
func (w *tdsBuffer) FinishPacket() error {
w.buf[1] = 1 // this is last packet
return w.flush()
}
func (r *tdsBuffer) readNextPacket() error {
header := header{}
var err error
err = binary.Read(r.transport, binary.BigEndian, &header)
if err != nil {
return err
}
offset := uint16(binary.Size(header))
if int(header.Size) > len(r.buf) {
return errors.New("Invalid packet size, it is longer than buffer size")
}
if int(offset) > int(header.Size) {
return errors.New("Invalid packet size, it is shorter than header size")
}
_, err = io.ReadFull(r.transport, r.buf[offset:header.Size])
if err != nil {
return err
}
r.pos = offset
r.size = header.Size
r.final = header.Status != 0
r.packet_type = header.PacketType
return nil
}
func (r *tdsBuffer) BeginRead() (uint8, error) {
err := r.readNextPacket()
if err != nil {
return 0, err
}
return r.packet_type, nil
}
func (r *tdsBuffer) ReadByte() (res byte, err error) {
if r.pos == r.size {
if r.final {
return 0, io.EOF
}
err = r.readNextPacket()
if err != nil {
return 0, err
}
}
res = r.buf[r.pos]
r.pos++
return res, nil
}
func (r *tdsBuffer) byte() byte {
b, err := r.ReadByte()
if err != nil {
badStreamPanic(err)
}
return b
}
func (r *tdsBuffer) ReadFull(buf []byte) {
_, err := io.ReadFull(r, buf[:])
if err != nil {
badStreamPanic(err)
}
}
func (r *tdsBuffer) uint64() uint64 {
var buf [8]byte
r.ReadFull(buf[:])
return binary.LittleEndian.Uint64(buf[:])
}
func (r *tdsBuffer) int32() int32 {
return int32(r.uint32())
}
func (r *tdsBuffer) uint32() uint32 {
var buf [4]byte
r.ReadFull(buf[:])
return binary.LittleEndian.Uint32(buf[:])
}
func (r *tdsBuffer) uint16() uint16 {
var buf [2]byte
r.ReadFull(buf[:])
return binary.LittleEndian.Uint16(buf[:])
}
func (r *tdsBuffer) BVarChar() string {
l := int(r.byte())
return r.readUcs2(l)
}
func (r *tdsBuffer) UsVarChar() string {
l := int(r.uint16())
return r.readUcs2(l)
}
func (r *tdsBuffer) readUcs2(numchars int) string {
b := make([]byte, numchars*2)
r.ReadFull(b)
res, err := ucs22str(b)
if err != nil {
badStreamPanic(err)
}
return res
}
func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
copied = 0
err = nil
if r.pos == r.size {
if r.final {
return 0, io.EOF
}
err = r.readNextPacket()
if err != nil {
return
}
}
copied = copy(buf, r.buf[r.pos:r.size])
r.pos += uint16(copied)
return
}