From 62622d223dc1c4ada753412fc35a0340b1ac38ce Mon Sep 17 00:00:00 2001 From: Dwi Siswanto Date: Thu, 23 Apr 2026 04:32:08 +0700 Subject: [PATCH] fix: numeric helper argument coercion Several DSL helpers assumed numeric arguments always came as float64. This broke nested constant expressions where another helper returned `int` or `int64` like `"rand_base(rand_int(...))"` or `"rand_text_numeric(rand_int(...))"`. Switch the affected numeric paths to use cast- backed integer conversion. It still rejects non- numeric inputs and keeps the existing optional decompression limit fallback. Fixes #301 Signed-off-by: Dwi Siswanto --- dsl.go | 66 ++++++++++++++++++++++++++++++++-------------- dsl_test.go | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 1 + go.sum | 12 +++++++-- util.go | 47 ++++++++++++++++++++++++++++++--- 5 files changed, 177 insertions(+), 24 deletions(-) diff --git a/dsl.go b/dsl.go index c3f026d..b95a21a 100644 --- a/dsl.go +++ b/dsl.go @@ -333,8 +333,9 @@ func init() { readLimit := DefaultMaxDecompressionSize if len(args) > 1 { - if limit, ok := args[1].(float64); ok { - readLimit = int64(limit) + limit, err := toInt64(args[1]) + if err == nil { + readLimit = limit } } @@ -377,8 +378,9 @@ func init() { readLimit := DefaultMaxDecompressionSize if len(args) > 1 { - if limit, ok := args[1].(float64); ok { - readLimit = int64(limit) + limit, err := toInt64(args[1]) + if err == nil { + readLimit = limit } } @@ -425,8 +427,9 @@ func init() { readLimit := DefaultMaxDecompressionSize if len(args) > 1 { - if limit, ok := args[1].(float64); ok { - readLimit = int64(limit) + limit, err := toInt64(args[1]) + if err == nil { + readLimit = limit } } @@ -864,7 +867,6 @@ func init() { "(length uint, optionalCharSet string) string", false, func(args ...interface{}) (interface{}, error) { - var length int charSet := letters + numbers argSize := len(args) @@ -872,7 +874,10 @@ func init() { return nil, ErrInvalidDslFunction } - length = int(args[0].(float64)) + length, err := toInt(args[0]) + if err != nil { + return nil, err + } if argSize == 2 { inputCharSet := toString(args[1]) @@ -886,7 +891,6 @@ func init() { "(length uint, optionalBadChars string) string", false, func(args ...interface{}) (interface{}, error) { - length := 0 badChars := "" argSize := len(args) @@ -894,7 +898,10 @@ func init() { return nil, ErrInvalidDslFunction } - length = int(args[0].(float64)) + length, err := toInt(args[0]) + if err != nil { + return nil, err + } if argSize == 2 { badChars = toString(args[1]) @@ -906,7 +913,6 @@ func init() { "(length uint, optionalBadChars string) string", false, func(args ...interface{}) (interface{}, error) { - var length int badChars := "" argSize := len(args) @@ -914,7 +920,10 @@ func init() { return nil, ErrInvalidDslFunction } - length = int(args[0].(float64)) + length, err := toInt(args[0]) + if err != nil { + return nil, err + } if argSize == 2 { badChars = toString(args[1]) @@ -931,7 +940,10 @@ func init() { return nil, ErrInvalidDslFunction } - length := int(args[0].(float64)) + length, err := toInt(args[0]) + if err != nil { + return nil, err + } badNumbers := "" if argSize == 2 { @@ -954,10 +966,18 @@ func init() { max := math.MaxInt32 if argSize >= 1 { - min = int(args[0].(float64)) + convertedMin, err := toInt(args[0]) + if err != nil { + return nil, err + } + min = convertedMin } if argSize == 2 { - max = int(args[1].(float64)) + convertedMax, err := toInt(args[1]) + if err != nil { + return nil, err + } + max = convertedMax } rint, err := randint.IntN(max - min) @@ -1001,7 +1021,11 @@ func init() { if argSize != 0 && argSize != 1 { return nil, ErrInvalidDslFunction } else if argSize == 1 { - seconds = int(args[0].(float64)) + convertedSeconds, err := toInt(args[0]) + if err != nil { + return nil, err + } + seconds = convertedSeconds } offset := time.Now().Add(time.Duration(seconds) * time.Second) @@ -1044,7 +1068,10 @@ func init() { if len(args) != 1 { return nil, ErrInvalidDslFunction } - seconds := args[0].(float64) + seconds, err := toInt(args[0]) + if err != nil { + return nil, err + } time.Sleep(time.Duration(seconds) * time.Second) return true, nil })) @@ -1153,8 +1180,9 @@ func init() { return toBool(args[0]), nil })) MustAddFunction(NewWithPositionalArgs("dec_to_hex", 1, true, func(args ...interface{}) (interface{}, error) { - if number, ok := args[0].(float64); ok { - hexNum := strconv.FormatInt(int64(number), 16) + number, err := toInt64(args[0]) + if err == nil { + hexNum := strconv.FormatInt(number, 16) return toString(hexNum), nil } return nil, fmt.Errorf("invalid number: %T", args[0]) diff --git a/dsl_test.go b/dsl_test.go index e0fd4b2..0b64190 100644 --- a/dsl_test.go +++ b/dsl_test.go @@ -633,6 +633,81 @@ func TestRandDslExpressions(t *testing.T) { } } +func TestNestedNumericDslExpressions(t *testing.T) { + tests := []struct { + expression string + assert func(t *testing.T, result interface{}) + }{ + { + expression: `rand_base(rand_int(5, 6), "abc")`, + assert: func(t *testing.T, result interface{}) { + require.Regexp(t, regexp.MustCompile(`^[abc]{5}$`), toString(result)) + }, + }, + { + expression: `rand_text_alpha(rand_int(5, 6), "abc")`, + assert: func(t *testing.T, result interface{}) { + require.Regexp(t, regexp.MustCompile(`^[A-Zd-z]{5}$`), toString(result)) + }, + }, + { + expression: `rand_text_alphanumeric(rand_int(5, 6), "ab12")`, + assert: func(t *testing.T, result interface{}) { + require.Regexp(t, regexp.MustCompile(`^[03-9A-Zc-z]{5}$`), toString(result)) + }, + }, + { + expression: `rand_text_numeric(rand_int(5, 6), 123)`, + assert: func(t *testing.T, result interface{}) { + require.Regexp(t, regexp.MustCompile(`^[0456789]{5}$`), toString(result)) + }, + }, + { + expression: `dec_to_hex(rand_int(16, 17))`, + assert: func(t *testing.T, result interface{}) { + require.Equal(t, "10", result) + }, + }, + { + expression: `date_time("%Y", to_unix_time("2022-01-13"))`, + assert: func(t *testing.T, result interface{}) { + require.Equal(t, "2022", result) + }, + }, + { + expression: `unix_time(rand_int(1, 2)) > unix_time()`, + assert: func(t *testing.T, result interface{}) { + require.Equal(t, true, result) + }, + }, + { + expression: `len(gzip_decode(gzip("hello world"), rand_int(1, 2)))`, + assert: func(t *testing.T, result interface{}) { + require.Equal(t, float64(1), result) + }, + }, + { + expression: `len(zlib_decode(zlib("hello world"), rand_int(1, 2)))`, + assert: func(t *testing.T, result interface{}) { + require.Equal(t, float64(1), result) + }, + }, + { + expression: `len(inflate(deflate("hello world"), rand_int(1, 2)))`, + assert: func(t *testing.T, result interface{}) { + require.Equal(t, float64(1), result) + }, + }, + } + + for _, test := range tests { + t.Run(test.expression, func(t *testing.T) { + result := evaluateExpression(t, test.expression) + test.assert(t, result) + }) + } +} + func TestFakerDslExpressions(t *testing.T) { zeroArgsFakerFunctions := []string{ "rand_ach_account_number", diff --git a/go.mod b/go.mod index 0db3f2d..3fdc0df 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,7 @@ require ( github.com/projectdiscovery/utils v0.10.0 github.com/sashabaranov/go-openai v1.37.0 github.com/spaolacci/murmur3 v1.1.0 + github.com/spf13/cast v1.10.0 github.com/stretchr/testify v1.11.1 github.com/vulncheck-oss/go-exploit v1.51.0 golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b diff --git a/go.sum b/go.sum index 159fa41..40c20de 100644 --- a/go.sum +++ b/go.sum @@ -54,6 +54,8 @@ github.com/dsnet/compress v0.0.2-0.20230904184137-39efe44ab707/go.mod h1:qssHWj6 github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/gofrs/uuid v3.3.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= @@ -116,11 +118,13 @@ github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxh github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/logrusorgru/aurora v2.0.3+incompatible h1:tOpm7WcpBTn4fjmVfgpQq0EfczGlG91VSDkswnjF5A8= github.com/logrusorgru/aurora v2.0.3+incompatible/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= github.com/logrusorgru/aurora/v4 v4.0.0 h1:sRjfPpun/63iADiSvGGjgA1cAYegEWMPCJdUpJYn9JA= @@ -162,6 +166,8 @@ github.com/projectdiscovery/utils v0.10.0 h1:E3nMm0h3LWt2bbnpRd8Whyj/y0DrMJKYx2z github.com/projectdiscovery/utils v0.10.0/go.mod h1:FL0cQdg3oBMtJdmbBrfLd5i73syNxpkbKO9tivQ0+rI= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rwcarlsen/goexif v0.0.0-20190401172101-9e8deecbddbd/go.mod h1:hPqNNc0+uJM6H+SuU8sEs5K5IQeKccPqeSjfgcKGgPk= github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d h1:hrujxIzL1woJ7AwssoOcM/tq5JjjG2yYOc8odClEiXA= github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d/go.mod h1:uugorj2VCxiV1x+LzaIdVa9b4S4qGAcH6cbhh4qVxOU= @@ -174,6 +180,8 @@ github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0b github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= diff --git a/util.go b/util.go index 9ecc8e9..6fb1c38 100644 --- a/util.go +++ b/util.go @@ -15,6 +15,7 @@ import ( "github.com/pkg/errors" "github.com/projectdiscovery/utils/html" randint "github.com/projectdiscovery/utils/rand" + "github.com/spf13/cast" ) const ( @@ -138,6 +139,44 @@ func toBool(data interface{}) bool { } } +func numericCastInput(data interface{}) (interface{}, error) { + switch value := data.(type) { + case int, int8, int16, int32, int64: + return value, nil + case uint, uint8, uint16, uint32, uint64: + return value, nil + case float32, float64: + return value, nil + case string: + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return nil, fmt.Errorf("invalid number: %T", data) + } + if _, err := strconv.ParseFloat(trimmed, 64); err != nil { + return nil, err + } + return trimmed, nil + default: + return nil, fmt.Errorf("invalid number: %T", data) + } +} + +func toInt(data interface{}) (int, error) { + value, err := numericCastInput(data) + if err != nil { + return 0, err + } + return cast.ToIntE(value) +} + +func toInt64(data interface{}) (int64, error) { + value, err := numericCastInput(data) + if err != nil { + return 0, err + } + return cast.ToInt64E(value) +} + func insertInto(s string, interval int, sep rune) string { var buffer bytes.Buffer before := interval - 1 @@ -206,10 +245,12 @@ func parseTimeOrNow(arguments []interface{}) (time.Time, error) { return time.Time{}, errors.New("invalid argument type") } currentTime = time.Unix(unixTime, 0) - case int64, float64: - currentTime = time.Unix(int64(inputUnixTime.(float64)), 0) default: - return time.Time{}, errors.New("invalid argument type") + unixTime, err := toInt64(inputUnixTime) + if err != nil { + return time.Time{}, errors.New("invalid argument type") + } + currentTime = time.Unix(unixTime, 0) } } else { currentTime = time.Now()