Skip to content

Commit

Permalink
[PECO-1016-2] Add handling for special types (databricks#158)
Browse files Browse the repository at this point in the history
In this PR, we add further handling so that special (driver-specific) parameter
types can be supplied via DBSqlParam values.
  • Loading branch information
nithinkdb authored Sep 7, 2023
2 parents 34599c4 + 27d9a87 commit 73073d2
Show file tree
Hide file tree
Showing 5 changed files with 200 additions and 102 deletions.
104 changes: 14 additions & 90 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ package dbsql
import (
"context"
"database/sql/driver"
"fmt"
"strconv"
"time"

"github.com/databricks/databricks-sql-go/driverctx"
Expand Down Expand Up @@ -102,9 +100,6 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
defer log.Duration(msg, start)

ctx = driverctx.NewContextWithConnId(ctx, c.id)
if len(args) > 0 {
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrParametersNotSupported, nil)
}

exStmtResp, opStatusResp, err := c.runQuery(ctx, query, args)

Expand Down Expand Up @@ -145,9 +140,6 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
msg, start := log.Track("QueryContext")

ctx = driverctx.NewContextWithConnId(ctx, c.id)
if len(args) > 0 {
return nil, dbsqlerrint.NewDriverError(ctx, dbsqlerr.ErrParametersNotSupported, nil)
}

// first we try to get the results synchronously.
// at any point in time that the context is done we must cancel and return
Expand Down Expand Up @@ -288,7 +280,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
MaxRows: int64(c.cfg.MaxRows),
},
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
Parameters: namedValuesToTSparkParams(args),
Parameters: convertNamedValuesToSparkParams(args),
}

if c.cfg.UseArrowBatches {
Expand Down Expand Up @@ -342,87 +334,6 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
return resp, err
}

// namedValuesToTSparkParams converts Go-level query arguments into the Thrift
// TSparkParameter structs sent with an execute-statement request. Named
// arguments keep their name; unnamed ones are addressed by ordinal. Most
// numeric values ride in DoubleValue; 64-bit integers are stringified so they
// do not lose precision in a float64.
func namedValuesToTSparkParams(args []driver.NamedValue) []*cli_service.TSparkParameter {
	// Server-side type names. Pointers to these locals are shared by every
	// parameter of the same type (same lifetime as the old ts slice).
	var (
		typeString    = "STRING"
		typeDouble    = "DOUBLE"
		typeBoolean   = "BOOLEAN"
		typeTimestamp = "TIMESTAMP"
		typeFloat     = "FLOAT"
		typeInteger   = "INTEGER"
		typeTinyInt   = "TINYINT"
		typeSmallInt  = "SMALLINT"
		typeBigInt    = "BIGINT"
	)

	params := make([]*cli_service.TSparkParameter, 0, len(args))
	for i := range args {
		arg := args[i]
		param := cli_service.TSparkParameter{Value: &cli_service.TSparkParameterValue{}}
		if arg.Name != "" {
			param.Name = &arg.Name
		} else {
			ordinal := int32(arg.Ordinal)
			param.Ordinal = &ordinal
		}

		switch t := arg.Value.(type) {
		case bool:
			param.Value.BooleanValue = &t
			param.Type = &typeBoolean
		case string:
			param.Value.StringValue = &t
			param.Type = &typeString
		case []byte:
			// database/sql's DefaultParameterConverter may hand us []byte;
			// send its string content rather than the %v byte-list rendering.
			s := string(t)
			param.Value.StringValue = &s
			param.Type = &typeString
		case int:
			f := float64(t)
			param.Value.DoubleValue = &f
			param.Type = &typeInteger
		case uint:
			f := float64(t)
			param.Value.DoubleValue = &f
			param.Type = &typeInteger
		case int8:
			f := float64(t)
			param.Value.DoubleValue = &f
			param.Type = &typeTinyInt
		case uint8:
			f := float64(t)
			param.Value.DoubleValue = &f
			param.Type = &typeTinyInt
		case int16:
			f := float64(t)
			param.Value.DoubleValue = &f
			param.Type = &typeSmallInt
		case uint16:
			f := float64(t)
			param.Value.DoubleValue = &f
			param.Type = &typeSmallInt
		case int32:
			f := float64(t)
			param.Value.DoubleValue = &f
			param.Type = &typeInteger
		case uint32:
			f := float64(t)
			param.Value.DoubleValue = &f
			param.Type = &typeInteger
		case int64:
			s := strconv.FormatInt(t, 10)
			param.Value.StringValue = &s
			param.Type = &typeBigInt
		case uint64:
			s := strconv.FormatUint(t, 10)
			param.Value.StringValue = &s
			param.Type = &typeBigInt
		case float32:
			f := float64(t)
			param.Value.DoubleValue = &f
			param.Type = &typeFloat
		case float64:
			// Previously missing: float64 (the type DefaultParameterConverter
			// produces for floats) fell into the default branch and was
			// mangled by %s into "%!s(float64=…)".
			param.Value.DoubleValue = &t
			param.Type = &typeDouble
		case time.Time:
			s := t.String()
			param.Value.StringValue = &s
			param.Type = &typeTimestamp
		default:
			// %v, not %s: %s corrupts non-string values.
			s := fmt.Sprintf("%v", arg.Value)
			param.Value.StringValue = &s
			param.Type = &typeString
		}

		params = append(params, &param)
	}
	return params
}

func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperationHandle) (*cli_service.TGetOperationStatusResp, error) {
corrId := driverctx.CorrelationIdFromContext(ctx)
log := logger.WithContext(c.id, corrId, client.SprintGuid(opHandle.OperationId.GUID))
Expand Down Expand Up @@ -481,6 +392,18 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
return statusResp, nil
}

// CheckNamedValue implements driver.NamedValueChecker (asserted at the bottom
// of this file). It lets DBSqlParam values pass through the database/sql
// layer: the wrapper's name is promoted onto the NamedValue and its inner
// value is run through the default converter. All other values are converted
// in place.
func (c *conn) CheckNamedValue(nv *driver.NamedValue) error {
	if param, ok := nv.Value.(DBSqlParam); ok {
		nv.Name = param.Name
		converted, err := driver.DefaultParameterConverter.ConvertValue(param.Value)
		if err != nil {
			return err
		}
		// Write the converted value back. The previous code converted into a
		// local copy of the DBSqlParam and discarded the result, so nv.Value
		// still held the unconverted inner value.
		param.Value = converted
		nv.Value = param
		return nil
	}

	var err error
	nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
	return err
}

var _ driver.Conn = (*conn)(nil)
var _ driver.Pinger = (*conn)(nil)
var _ driver.SessionResetter = (*conn)(nil)
Expand All @@ -489,3 +412,4 @@ var _ driver.ExecerContext = (*conn)(nil)
var _ driver.QueryerContext = (*conn)(nil)
var _ driver.ConnPrepareContext = (*conn)(nil)
var _ driver.ConnBeginTx = (*conn)(nil)
var _ driver.NamedValueChecker = (*conn)(nil)
1 change: 0 additions & 1 deletion errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ const (
// Driver errors
ErrNotImplemented = "not implemented"
ErrTransactionsNotSupported = "transactions are not supported"
ErrParametersNotSupported = "query parameters are not supported"
ErrReadQueryStatus = "could not read query status"
ErrSentinelTimeout = "sentinel timed out waiting for operation to complete"

Expand Down
36 changes: 36 additions & 0 deletions parameter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package dbsql

import (
"database/sql/driver"
"strconv"
"testing"
"time"

"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/stretchr/testify/assert"
)

// TestParameter_Inference checks that convertNamedValuesToSparkParams infers
// the server-side type and string form for plain Go values, and honors the
// explicit type carried by a DBSqlParam.
func TestParameter_Inference(t *testing.T) {
	t.Run("Should infer types correctly", func(t *testing.T) {
		namedValues := []driver.NamedValue{
			{Name: "", Value: float32(5.1)},
			{Name: "", Value: time.Now()},
			{Name: "", Value: int64(5)},
			{Name: "", Value: true},
			{Name: "", Value: DBSqlParam{Value: "6.2", Type: Decimal}},
		}

		parameters := convertNamedValuesToSparkParams(namedValues)

		assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
		assert.NotNil(t, parameters[1].Value.StringValue)
		assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
		assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("5")}, parameters[2].Value)
		assert.Equal(t, string("true"), *parameters[3].Value.StringValue)
		assert.Equal(t, string("DECIMAL"), *parameters[4].Type)
		assert.Equal(t, string("6.2"), *parameters[4].Value.StringValue)
	})
}
// TestParameters_Names checks that parameter names are propagated: from the
// driver.NamedValue itself, and from the Name field of a DBSqlParam wrapper.
func TestParameters_Names(t *testing.T) {
	// Subtest renamed: the old "Should infer types correctly" was copy-pasted
	// from the inference test and did not describe what is asserted here.
	t.Run("Should set parameter names correctly", func(t *testing.T) {
		values := []driver.NamedValue{
			{Name: "1", Value: int(26)},
			{Name: "", Value: DBSqlParam{Name: "2", Type: Decimal, Value: "6.2"}},
		}

		parameters := convertNamedValuesToSparkParams(values)

		assert.Equal(t, string("1"), *parameters[0].Name)
		assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
		assert.Equal(t, string("2"), *parameters[1].Name)
		assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
		assert.Equal(t, string("DECIMAL"), *parameters[1].Type)
	})
}
150 changes: 150 additions & 0 deletions parameters.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
package dbsql

import (
"database/sql/driver"
"fmt"
"strconv"
"time"

"github.com/databricks/databricks-sql-go/internal/cli_service"
)

// DBSqlParam is a driver-specific query argument. It lets a caller supply an
// explicit parameter name and server-side SQL type alongside the value,
// bypassing type inference (inferTypes honors these fields verbatim).
type DBSqlParam struct {
	Name  string  // parameter name; promoted onto driver.NamedValue by CheckNamedValue
	Type  SqlType // explicit server-side type; overrides the inferred type
	Value any     // parameter value; convertNamedValuesToSparkParams expects a string here after inference
}

// SqlType enumerates the server-side SQL type names a parameter can declare.
type SqlType int64

const (
	String SqlType = iota
	Date
	Timestamp
	Float
	Decimal
	Double
	Integer
	BigInt
	SmallInt
	TinyInt
	Boolean
	IntervalMonth
	IntervalDay
)

// sqlTypeNames maps each SqlType to the name sent to the server; the indices
// follow the iota order of the constants above.
var sqlTypeNames = [...]string{
	String:        "STRING",
	Date:          "DATE",
	Timestamp:     "TIMESTAMP",
	Float:         "FLOAT",
	Decimal:       "DECIMAL",
	Double:        "DOUBLE",
	Integer:       "INTEGER",
	BigInt:        "BIGINT",
	SmallInt:      "SMALLINT",
	TinyInt:       "TINYINT",
	Boolean:       "BOOLEAN",
	IntervalMonth: "INTERVAL MONTH",
	IntervalDay:   "INTERVAL DAY",
}

// String returns the server-side name of the type, or "unknown" for values
// outside the declared range.
func (s SqlType) String() string {
	if s < 0 || s >= SqlType(len(sqlTypeNames)) {
		return "unknown"
	}
	return sqlTypeNames[s]
}

// valuesToDBSQLParams wraps each driver.NamedValue in a DBSqlParam, copying
// over the name and raw value. The Type field is left at its zero value for
// inferTypes to fill in afterwards.
func valuesToDBSQLParams(namedValues []driver.NamedValue) []DBSqlParam {
	var params []DBSqlParam
	for _, nv := range namedValues {
		params = append(params, DBSqlParam{
			Name:  nv.Name,
			Value: nv.Value,
		})
	}
	return params
}

// inferTypes stringifies every parameter value in place and assigns the
// matching SqlType. Values are converted to strings because
// convertNamedValuesToSparkParams transports them as string-typed Thrift
// values. A caller-supplied DBSqlParam overrides both the name and the
// declared type; its inner value is passed through unconverted.
func inferTypes(params []DBSqlParam) {
	for i := range params {
		param := &params[i]
		switch value := param.Value.(type) {
		case bool:
			param.Value = strconv.FormatBool(value)
			param.Type = Boolean
		case string:
			// Already a string; nothing to convert.
			param.Type = String
		case []byte:
			// database/sql's DefaultParameterConverter may hand us []byte;
			// use its string content rather than letting it fall into the
			// default branch (which would print a byte list).
			param.Value = string(value)
			param.Type = String
		case int:
			param.Value = strconv.Itoa(value)
			param.Type = Integer
		case uint:
			param.Value = strconv.FormatUint(uint64(value), 10)
			param.Type = Integer
		case int8:
			param.Value = strconv.Itoa(int(value))
			param.Type = Integer
		case uint8:
			param.Value = strconv.FormatUint(uint64(value), 10)
			param.Type = Integer
		case int16:
			param.Value = strconv.Itoa(int(value))
			param.Type = Integer
		case uint16:
			param.Value = strconv.FormatUint(uint64(value), 10)
			param.Type = Integer
		case int32:
			param.Value = strconv.Itoa(int(value))
			param.Type = Integer
		case uint32:
			param.Value = strconv.FormatUint(uint64(value), 10)
			param.Type = Integer
		case int64:
			// FormatInt, not Itoa(int(value)): the int conversion truncated
			// values on 32-bit platforms.
			param.Value = strconv.FormatInt(value, 10)
			param.Type = Integer
		case uint64:
			param.Value = strconv.FormatUint(value, 10)
			param.Type = Integer
		case float32:
			param.Value = strconv.FormatFloat(float64(value), 'f', -1, 32)
			param.Type = Float
		case float64:
			// Previously missing: float64 (the type DefaultParameterConverter
			// produces for floats) fell into the default branch and was
			// mangled by %s into "%!s(float64=…)".
			param.Value = strconv.FormatFloat(value, 'f', -1, 64)
			param.Type = Double
		case time.Time:
			param.Value = value.String()
			param.Type = Timestamp
		case DBSqlParam:
			param.Name = value.Name
			param.Value = value.Value
			param.Type = value.Type
		default:
			// %v, not %s: %s corrupts non-string values.
			param.Value = fmt.Sprintf("%v", value)
			param.Type = String
		}
	}
}
// convertNamedValuesToSparkParams converts driver-level arguments into the
// Thrift TSparkParameter structs attached to an execute-statement request.
// Values are first wrapped as DBSqlParams and run through type inference, then
// emitted as string-typed Thrift parameter values.
func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
	sqlParams := valuesToDBSQLParams(values)
	inferTypes(sqlParams)

	var sparkParams []*cli_service.TSparkParameter
	for i := range sqlParams {
		sqlParam := sqlParams[i]

		// inferTypes stringifies plain values, but a caller-supplied
		// DBSqlParam keeps its raw inner value; fall back to %v formatting
		// instead of panicking on an unchecked type assertion.
		sparkParamValue, ok := sqlParam.Value.(string)
		if !ok {
			sparkParamValue = fmt.Sprintf("%v", sqlParam.Value)
		}
		sparkParamType := sqlParam.Type.String()

		sparkParam := cli_service.TSparkParameter{
			Name:  &sqlParam.Name,
			Type:  &sparkParamType,
			Value: &cli_service.TSparkParameterValue{StringValue: &sparkParamValue},
		}
		sparkParams = append(sparkParams, &sparkParam)
	}
	return sparkParams
}
11 changes: 0 additions & 11 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"context"
"database/sql/driver"
"testing"
"time"

"github.com/apache/thrift/lib/go/thrift"
"github.com/databricks/databricks-sql-go/internal/cli_service"
Expand Down Expand Up @@ -166,13 +165,3 @@ func TestStmt_QueryContext(t *testing.T) {
assert.Equal(t, testQuery, savedQueryString)
})
}
// TestParameters checks the type and value casting performed by
// namedValuesToTSparkParams for float32, time.Time, and int64 arguments.
func TestParameters(t *testing.T) {
	t.Run("Parameter casting should be correct", func(t *testing.T) {
		args := []driver.NamedValue{
			{Ordinal: 1, Name: "", Value: float32(5)},
			{Ordinal: 2, Name: "", Value: time.Now()},
			{Ordinal: 3, Name: "", Value: int64(5)},
		}

		parameters := namedValuesToTSparkParams(args)

		assert.Equal(t, &cli_service.TSparkParameterValue{DoubleValue: thrift.Float64Ptr(5)}, parameters[0].Value)
		assert.NotNil(t, parameters[1].Value.StringValue)
		assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
		assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("5")}, parameters[2].Value)
	})
}

0 comments on commit 73073d2

Please sign in to comment.