diff --git a/pkg/sqlcmd/format.go b/pkg/sqlcmd/format.go index 55bd2e25..58d3f2cb 100644 --- a/pkg/sqlcmd/format.go +++ b/pkg/sqlcmd/format.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "strconv" "strings" "time" @@ -59,6 +60,11 @@ const ( ControlReplaceConsecutive ) +const ( + realDefaultWidth int64 = 14 // REAL and SMALLMONEY + floatDefaultWidth int64 = 24 // FLOAT and MONEY +) + type columnDetail struct { displayWidth int64 leftJustify bool @@ -371,11 +377,11 @@ func calcColumnDetails(cols []*sql.ColumnType, fixed int64, variable int64) ([]c columnDetails[i].displayWidth = max64(21, nameLen) case "REAL", "SMALLMONEY": columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(14, nameLen) + columnDetails[i].displayWidth = max64(realDefaultWidth, nameLen) columnDetails[i].zeroesAfterDecimal = true case "FLOAT", "MONEY": columnDetails[i].leftJustify = false - columnDetails[i].displayWidth = max64(24, nameLen) + columnDetails[i].displayWidth = max64(floatDefaultWidth, nameLen) columnDetails[i].zeroesAfterDecimal = true case "DECIMAL": columnDetails[i].leftJustify = false @@ -530,6 +536,10 @@ func (f *sqlCmdFormatterType) scanRow(rows *sql.Rows) ([]string, error) { } else { row[n] = "0" } + case float64: + row[n] = formatFloat(x, 64, f.columnDetails[n]) + case float32: + row[n] = formatFloat(float64(x), 32, f.columnDetails[n]) default: var err error if row[n], err = fmt.Sprintf("%v", x), nil; err != nil { @@ -552,6 +562,28 @@ func dateTimeFormatString(scale int, addOffset bool) string { return format } +// formatFloat formats a float value to match ODBC sqlcmd behavior. +// Uses decimal notation for typical values, falls back to scientific for extreme values. +func formatFloat(x float64, bitSize int, col columnDetail) string { + formatted := strconv.FormatFloat(x, 'f', -1, bitSize) + + // Determine width threshold for fallback to scientific notation + threshold := col.displayWidth + if threshold == 0 { + typeName := col.col.DatabaseTypeName() + if typeName == "REAL" || typeName == "SMALLMONEY" { + threshold = realDefaultWidth + } else { + threshold = floatDefaultWidth + } + } + + if int64(len(formatted)) > threshold { + formatted = strconv.FormatFloat(x, 'g', -1, bitSize) + } + return formatted +} + // Prints the final version of a cell based on formatting variables and command line parameters func (f *sqlCmdFormatterType) printColumnValue(val string, col int) { c := f.columnDetails[col] diff --git a/pkg/sqlcmd/format_test.go b/pkg/sqlcmd/format_test.go index f4bee464..b0c5ed0e 100644 --- a/pkg/sqlcmd/format_test.go +++ b/pkg/sqlcmd/format_test.go @@ -162,3 +162,30 @@ func TestFormatterXmlMode(t *testing.T) { assert.NoError(t, err, "runSqlCmd returned error") assert.Equal(t, ``+SqlcmdEol, buf.buf.String()) } + +func TestFormatterFloatDecimalNotation(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer func() { _ = buf.Close() }() + + s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "256") + query := `SELECT CAST(4713347.3103808956 AS FLOAT) as val` + err := runSqlCmd(t, s, []string{query, "GO"}) + assert.NoError(t, err) + + output := buf.buf.String() + assert.NotContains(t, output, "e+", "typical floats should use decimal notation") + assert.Contains(t, output, "4713347.310380", "should contain decimal value") +} + +func TestFormatterFloatScientificFallback(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer func() { _ = buf.Close() }() + + s.vars.Set(SQLCMDMAXVARTYPEWIDTH, "256") + query := `SELECT CAST(1e100 AS FLOAT) as val` + err := runSqlCmd(t, s, []string{query, "GO"}) + assert.NoError(t, err) + + output := buf.buf.String() + assert.Contains(t, output, "e+", "extreme values should use scientific notation") +}