@@ -0,0 +1,2 @@ | |||
.idea | |||
dist/ |
@@ -0,0 +1,16 @@ | |||
[servers] | |||
[servers.main] | |||
addr = "127.0.0.1:6379" | |||
[servers.other] | |||
addr = "127.0.0.1:6380" | |||
[servers.third] | |||
addr = "127.0.0.1:6381" | |||
[[shovels]] | |||
src = "main" | |||
dst = "other" | |||
key = "test" | |||
[[shovels]] | |||
src = "other" | |||
dst = "third" | |||
key = "foo" | |||
dstkey = "bar" |
@@ -0,0 +1,13 @@ | |||
module reshovel | |||
go 1.17 | |||
require ( | |||
github.com/BurntSushi/toml v1.0.0 | |||
github.com/go-redis/redis/v8 v8.11.4 | |||
) | |||
require ( | |||
github.com/cespare/xxhash/v2 v2.1.2 // indirect | |||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect | |||
) |
@@ -0,0 +1,99 @@ | |||
github.com/BurntSushi/toml v1.0.0 h1:dtDWrepsVPfW9H/4y7dDgFc2MBUSeJhlaDtK13CxFlU= | |||
github.com/BurntSushi/toml v1.0.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= | |||
github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= | |||
github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= | |||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | |||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | |||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= | |||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= | |||
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= | |||
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= | |||
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= | |||
github.com/go-redis/redis/v8 v8.11.4 h1:kHoYkfZP6+pe04aFTnhDH6GDROa5yJdHJVNxV3F46Tg= | |||
github.com/go-redis/redis/v8 v8.11.4/go.mod h1:2Z2wHZXdQpCDXEGzqMockDpNyYvi2l4Pxt6RJr792+w= | |||
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE= | |||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= | |||
github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= | |||
github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= | |||
github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= | |||
github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= | |||
github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= | |||
github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= | |||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= | |||
github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= | |||
github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= | |||
github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= | |||
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | |||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | |||
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= | |||
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= | |||
github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= | |||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= | |||
github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= | |||
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= | |||
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= | |||
github.com/onsi/ginkgo v1.16.4 h1:29JGrr5oVBm5ulCWet69zQkzWipVXIol6ygQUe/EzNc= | |||
github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= | |||
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY= | |||
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo= | |||
github.com/onsi/gomega v1.16.0 h1:6gjqkI8iiRHMvdccRJM8rVKjCWk6ZIm6FTm3ddIe4/c= | |||
github.com/onsi/gomega v1.16.0/go.mod h1:HnhC7FXeEQY45zxNK3PPoIUhzk/80Xly9PcubAlGdZY= | |||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | |||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | |||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= | |||
github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= | |||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= | |||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= | |||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= | |||
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= | |||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= | |||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= | |||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= | |||
golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= | |||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= | |||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= | |||
golang.org/x/net v0.0.0-20210428140749-89ef3d95e781/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= | |||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | |||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | |||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= | |||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | |||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= | |||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||
golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||
golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da h1:b3NXsE2LusjYGGjL5bxEVZZORm/YEFFrWFjR8eFrw/c= | |||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | |||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= | |||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= | |||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | |||
golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= | |||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= | |||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= | |||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= | |||
golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= | |||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | |||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | |||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | |||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= | |||
google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= | |||
google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= | |||
google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= | |||
google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= | |||
google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= | |||
google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= | |||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= | |||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= | |||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | |||
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= | |||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= | |||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= | |||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | |||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | |||
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= | |||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= | |||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= |
@@ -0,0 +1,164 @@ | |||
package main | |||
import ( | |||
"context" | |||
"log" | |||
"os" | |||
"os/signal" | |||
"syscall" | |||
"time" | |||
"github.com/BurntSushi/toml" | |||
"github.com/go-redis/redis/v8" | |||
) | |||
// BatchSize is the maximum number of set members moved per SPOPN/SADD round trip.
const BatchSize = 10000

// RedisClients maps server names from the [servers] config table to their
// connected clients; populated once in main before workers start.
var RedisClients map[string]*redis.Client

// Config is the top-level layout of config.toml.
type Config struct {
	RedisServerConfigs map[string]RedisServerConfig `toml:"servers"`
	ShovelConfigs      []ShovelConfig               `toml:"shovels"`
}

// ShovelConfig describes one worker that moves members of a Redis set on the
// source server into a set on the destination server.
type ShovelConfig struct {
	Key    string `toml:"key"`    // source set key
	Src    string `toml:"src"`    // name of a server in the [servers] table
	Dst    string `toml:"dst"`    // name of a server in the [servers] table
	DstKey string `toml:"dstkey"` // destination set key; defaults to Key when empty
}

// RedisServerConfig mirrors the redis.Options fields that can be set from
// TOML. All backoff/timeout/age/frequency fields are expressed in seconds
// (fractional values allowed) and converted to time.Duration when building
// the client options.
type RedisServerConfig struct {
	Network            string  `toml:"network"`
	Addr               string  `toml:"addr"`
	Username           string  `toml:"username"`
	Password           string  `toml:"password"`
	DB                 int     `toml:"db"`
	MaxRetries         int     `toml:"maxretries"`
	MinRetryBackoff    float64 `toml:"minretrybackoff"`
	MaxRetryBackoff    float64 `toml:"maxretrybackoff"`
	DialTimeout        float64 `toml:"dialtimeout"`
	ReadTimeout        float64 `toml:"readtimeout"`
	WriteTimeout       float64 `toml:"writetimeout"`
	PoolFIFO           bool    `toml:"poolfifo"`
	PoolSize           int     `toml:"poolsize"`
	MinIdleConns       int     `toml:"minidleconns"`
	MaxConnAge         float64 `toml:"maxconnage"`
	PoolTimeout        float64 `toml:"pooltimeout"`
	IdleTimeout        float64 `toml:"idletimeout"`
	IdleCheckFrequency float64 `toml:"idlecheckfrequency"`
}
func RedisConfigToRedisOptions(config RedisServerConfig) *redis.Options { | |||
nano := float64(time.Second.Nanoseconds()) | |||
if config.ReadTimeout == 0 { | |||
config.ReadTimeout = 15 * time.Minute.Seconds() | |||
} | |||
return &redis.Options{ | |||
Network: config.Network, | |||
Addr: config.Addr, | |||
Username: config.Username, | |||
Password: config.Password, | |||
DB: config.DB, | |||
MaxRetries: config.MaxRetries, | |||
MinRetryBackoff: time.Duration(config.MinRetryBackoff * nano), | |||
MaxRetryBackoff: time.Duration(config.MaxRetryBackoff * nano), | |||
DialTimeout: time.Duration(config.DialTimeout * nano), | |||
ReadTimeout: time.Duration(config.ReadTimeout * nano), | |||
WriteTimeout: time.Duration(config.WriteTimeout * nano), | |||
PoolFIFO: config.PoolFIFO, | |||
PoolSize: config.PoolSize, | |||
MinIdleConns: config.MinIdleConns, | |||
MaxConnAge: time.Duration(config.MaxConnAge * nano), | |||
PoolTimeout: time.Duration(config.PoolTimeout * nano), | |||
IdleTimeout: time.Duration(config.IdleTimeout * nano), | |||
IdleCheckFrequency: time.Duration(config.IdleCheckFrequency * nano), | |||
} | |||
} | |||
// StartShovelWorker repeatedly moves up to BatchSize members of the set sk on
// the source client s into the set dk on the destination client d. It runs
// until the context c is cancelled, then closes dc so the caller can tell the
// worker has fully stopped. An empty dk means "same key as the source".
func StartShovelWorker(c context.Context, dc chan bool, s *redis.Client, d *redis.Client, sk string, dk string) {
	defer close(dc)
	if dk == "" {
		dk = sk
	}
	// m is the idle backoff between rounds, counted in whole seconds: it is
	// reset to 0 while full batches keep arriving and grows by one per round
	// (capped at 60) otherwise.
	var m time.Duration = 0
	for {
		// NOTE(review): the per-operation contexts derive from
		// context.Background(), not from c, so cancelling c does not abort an
		// in-flight Redis call — shutdown only takes effect at the select
		// below. Confirm this is the intended shutdown behavior.
		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
		items, err := s.SPopN(ctx, sk, BatchSize).Result()
		cancel()
		if err != nil {
			log.Printf("unable to spop %s: %s", sk, err)
		} else if len(items) != 0 {
			// SAdd takes variadic interface{}; repack the popped strings.
			var iitems []interface{}
			for _, item := range items {
				iitems = append(iitems, item)
			}
			ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
			err = d.SAdd(ctx, dk, iitems...).Err()
			cancel()
			if err != nil {
				log.Printf("unable to sadd %s: %s", dk, err)
				// Best effort: put the popped members back into the source set
				// so they are not lost. If this revert also fails, the batch
				// is dropped and only logged.
				ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
				err = s.SAdd(ctx, sk, iitems...).Err()
				cancel()
				if err != nil {
					log.Printf("unable to revert spop %s: %s", sk, err)
				}
			} else if len(items) >= BatchSize {
				// A full batch suggests more data is waiting; poll again
				// immediately.
				m = 0
			}
		}
		// Cancellable sleep for the current backoff; drain the timer if the
		// context wins the race so its channel is not left pending.
		t := time.NewTimer(m * time.Second)
		select {
		case <-c.Done():
			if !t.Stop() {
				<-t.C
			}
			return
		case <-t.C:
		}
		if m < 60 {
			m++
		}
	}
}
func main() { | |||
var config Config | |||
_, err := toml.DecodeFile("./config.toml", &config) | |||
if err != nil { | |||
log.Panicf("error parsing config.toml: %s", err) | |||
} | |||
RedisClients = map[string]*redis.Client{} | |||
for i, c := range config.ShovelConfigs { | |||
if _, has := config.RedisServerConfigs[c.Src]; !has { | |||
log.Panicf("invalid redis source: %s", c.Src) | |||
} | |||
if _, has := config.RedisServerConfigs[c.Dst]; !has { | |||
log.Panicf("invalid redis destination: %s", c.Dst) | |||
} | |||
if c.DstKey == "" { | |||
config.ShovelConfigs[i].DstKey = c.Key | |||
} | |||
} | |||
for n, c := range config.RedisServerConfigs { | |||
RedisClients[n] = redis.NewClient(RedisConfigToRedisOptions(c)) | |||
} | |||
ctx, cancel := context.WithCancel(context.Background()) | |||
var doneChans []chan bool | |||
for _, c := range config.ShovelConfigs { | |||
log.Printf("starting shovel worker for %s/%s -> %s/%s", c.Src, c.Key, c.Dst, c.DstKey) | |||
doneChan := make(chan bool) | |||
go StartShovelWorker(ctx, doneChan, RedisClients[c.Src], RedisClients[c.Dst], c.Key, c.DstKey) | |||
doneChans = append(doneChans, doneChan) | |||
} | |||
sc := make(chan os.Signal, 1) | |||
signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt, os.Kill) | |||
<-sc | |||
cancel() | |||
log.Printf("waiting for %d workers to shut down...", len(doneChans)) | |||
for _, c := range doneChans { | |||
<-c | |||
} | |||
} |
@@ -0,0 +1,2 @@ | |||
toml.test | |||
/toml-test |
@@ -0,0 +1 @@ | |||
Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0). |
@@ -0,0 +1,21 @@ | |||
The MIT License (MIT) | |||
Copyright (c) 2013 TOML authors | |||
Permission is hereby granted, free of charge, to any person obtaining a copy | |||
of this software and associated documentation files (the "Software"), to deal | |||
in the Software without restriction, including without limitation the rights | |||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |||
copies of the Software, and to permit persons to whom the Software is | |||
furnished to do so, subject to the following conditions: | |||
The above copyright notice and this permission notice shall be included in | |||
all copies or substantial portions of the Software. | |||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | |||
THE SOFTWARE. |
@@ -0,0 +1,211 @@ | |||
TOML stands for Tom's Obvious, Minimal Language. This Go package provides a | |||
reflection interface similar to Go's standard library `json` and `xml` | |||
packages. | |||
Compatible with TOML version [v1.0.0](https://toml.io/en/v1.0.0). | |||
Documentation: https://godocs.io/github.com/BurntSushi/toml | |||
See the [releases page](https://github.com/BurntSushi/toml/releases) for a | |||
changelog; this information is also in the git tag annotations (e.g. `git show | |||
v0.4.0`). | |||
This library requires Go 1.13 or newer; install it with: | |||
% go get github.com/BurntSushi/toml@latest | |||
It also comes with a TOML validator CLI tool: | |||
% go install github.com/BurntSushi/toml/cmd/tomlv@latest | |||
% tomlv some-toml-file.toml | |||
### Testing | |||
This package passes all tests in [toml-test] for both the decoder and the | |||
encoder. | |||
[toml-test]: https://github.com/BurntSushi/toml-test | |||
### Examples | |||
This package works similar to how the Go standard library handles XML and JSON. | |||
Namely, data is loaded into Go values via reflection. | |||
For the simplest example, consider some TOML file as just a list of keys and | |||
values: | |||
```toml | |||
Age = 25 | |||
Cats = [ "Cauchy", "Plato" ] | |||
Pi = 3.14 | |||
Perfection = [ 6, 28, 496, 8128 ] | |||
DOB = 1987-07-05T05:45:00Z | |||
``` | |||
Which could be defined in Go as: | |||
```go | |||
type Config struct { | |||
Age int | |||
Cats []string | |||
Pi float64 | |||
Perfection []int | |||
DOB time.Time // requires `import time` | |||
} | |||
``` | |||
And then decoded with: | |||
```go | |||
var conf Config | |||
err := toml.Decode(tomlData, &conf) | |||
// handle error | |||
``` | |||
You can also use struct tags if your struct field name doesn't map to a TOML | |||
key value directly: | |||
```toml | |||
some_key_NAME = "wat" | |||
``` | |||
```go | |||
type TOML struct { | |||
ObscureKey string `toml:"some_key_NAME"` | |||
} | |||
``` | |||
Beware that, like most other decoders, **only exported fields** are
considered when encoding and decoding; private fields are silently ignored.
### Using the `Marshaler` and `encoding.TextUnmarshaler` interfaces | |||
Here's an example that automatically parses duration strings into | |||
`time.Duration` values: | |||
```toml | |||
[[song]] | |||
name = "Thunder Road" | |||
duration = "4m49s" | |||
[[song]] | |||
name = "Stairway to Heaven" | |||
duration = "8m03s" | |||
``` | |||
Which can be decoded with: | |||
```go | |||
type song struct { | |||
Name string | |||
Duration duration | |||
} | |||
type songs struct { | |||
Song []song | |||
} | |||
var favorites songs | |||
if _, err := toml.Decode(blob, &favorites); err != nil { | |||
log.Fatal(err) | |||
} | |||
for _, s := range favorites.Song { | |||
fmt.Printf("%s (%s)\n", s.Name, s.Duration) | |||
} | |||
``` | |||
And you'll also need a `duration` type that satisfies the | |||
`encoding.TextUnmarshaler` interface: | |||
```go | |||
type duration struct { | |||
time.Duration | |||
} | |||
func (d *duration) UnmarshalText(text []byte) error { | |||
var err error | |||
d.Duration, err = time.ParseDuration(string(text)) | |||
return err | |||
} | |||
``` | |||
To target TOML specifically you can implement the `UnmarshalTOML` interface in
a similar way.
### More complex usage | |||
Here's an example of how to load the example from the official spec page: | |||
```toml | |||
# This is a TOML document. Boom. | |||
title = "TOML Example" | |||
[owner] | |||
name = "Tom Preston-Werner" | |||
organization = "GitHub" | |||
bio = "GitHub Cofounder & CEO\nLikes tater tots and beer." | |||
dob = 1979-05-27T07:32:00Z # First class dates? Why not? | |||
[database] | |||
server = "192.168.1.1" | |||
ports = [ 8001, 8001, 8002 ] | |||
connection_max = 5000 | |||
enabled = true | |||
[servers] | |||
# You can indent as you please. Tabs or spaces. TOML don't care. | |||
[servers.alpha] | |||
ip = "10.0.0.1" | |||
dc = "eqdc10" | |||
[servers.beta] | |||
ip = "10.0.0.2" | |||
dc = "eqdc10" | |||
[clients] | |||
data = [ ["gamma", "delta"], [1, 2] ] # just an update to make sure parsers support it | |||
# Line breaks are OK when inside arrays | |||
hosts = [ | |||
"alpha", | |||
"omega" | |||
] | |||
``` | |||
And the corresponding Go types are: | |||
```go | |||
type tomlConfig struct { | |||
Title string | |||
Owner ownerInfo | |||
DB database `toml:"database"` | |||
Servers map[string]server | |||
Clients clients | |||
} | |||
type ownerInfo struct { | |||
Name string | |||
Org string `toml:"organization"` | |||
Bio string | |||
DOB time.Time | |||
} | |||
type database struct { | |||
Server string | |||
Ports []int | |||
ConnMax int `toml:"connection_max"` | |||
Enabled bool | |||
} | |||
type server struct { | |||
IP string | |||
DC string | |||
} | |||
type clients struct { | |||
Data [][]interface{} | |||
Hosts []string | |||
} | |||
``` | |||
Note that a case insensitive match will be tried if an exact match can't be | |||
found. | |||
A working example of the above can be found in `_example/example.{go,toml}`. |
@@ -0,0 +1,560 @@ | |||
package toml | |||
import ( | |||
"encoding" | |||
"fmt" | |||
"io" | |||
"io/ioutil" | |||
"math" | |||
"os" | |||
"reflect" | |||
"strings" | |||
) | |||
// Unmarshaler is the interface implemented by objects that can unmarshal a
// TOML description of themselves.
type Unmarshaler interface {
	UnmarshalTOML(interface{}) error
}

// Unmarshal decodes the contents of `p` in TOML format into a pointer `v`.
// It is a convenience wrapper around Decode that discards the MetaData.
func Unmarshal(p []byte, v interface{}) error {
	_, err := Decode(string(p), v)
	return err
}
// Primitive is a TOML value that hasn't been decoded into a Go value.
//
// This type can be used for any value, which will cause decoding to be delayed.
// You can use the PrimitiveDecode() function to "manually" decode these values.
//
// NOTE: The underlying representation of a `Primitive` value is subject to
// change. Do not rely on it.
//
// NOTE: Primitive values are still parsed, so using them will only avoid the
// overhead of reflection. They can be useful when you don't know the exact type
// of TOML data until runtime.
type Primitive struct {
	undecoded interface{} // raw parsed value, unified into a Go value on demand
	context   Key         // key path at which the value was found
}

// The significand precision for float32 and float64 is 24 and 53 bits; this is
// the range a natural number can be stored in a float without loss of data.
const (
	maxSafeFloat32Int = 16777215         // 2^24-1
	maxSafeFloat64Int = 9007199254740991 // 2^53-1
)

// PrimitiveDecode is just like the other `Decode*` functions, except it
// decodes a TOML value that has already been parsed. Valid primitive values
// can *only* be obtained from values filled by the decoder functions,
// including this method. (i.e., `v` may contain more `Primitive`
// values.)
//
// Meta data for primitive values is included in the meta data returned by
// the `Decode*` functions with one exception: keys returned by the Undecoded
// method will only reflect keys that were decoded. Namely, any keys hidden
// behind a Primitive will be considered undecoded. Executing this method will
// update the undecoded keys in the meta data. (See the example.)
func (md *MetaData) PrimitiveDecode(primValue Primitive, v interface{}) error {
	md.context = primValue.context
	// Clear the key context again once this primitive has been decoded.
	defer func() { md.context = nil }()
	return md.unify(primValue.undecoded, rvalue(v))
}
// Decoder decodes TOML data.
//
// TOML tables correspond to Go structs or maps (dealer's choice – they can be
// used interchangeably).
//
// TOML table arrays correspond to either a slice of structs or a slice of maps.
//
// TOML datetimes correspond to Go time.Time values. Local datetimes are parsed
// in the local timezone.
//
// All other TOML types (float, string, int, bool and array) correspond to the
// obvious Go types.
//
// An exception to the above rules is if a type implements the TextUnmarshaler
// interface, in which case any primitive TOML value (floats, strings, integers,
// booleans, datetimes) will be converted to a []byte and given to the value's
// UnmarshalText method. See the Unmarshaler example for a demonstration with
// time duration strings.
//
// Key mapping
//
// TOML keys can map to either keys in a Go map or field names in a Go struct.
// The special `toml` struct tag can be used to map TOML keys to struct fields
// that don't match the key name exactly (see the example). A case insensitive
// match to struct names will be tried if an exact match can't be found.
//
// The mapping between TOML values and Go values is loose. That is, there may
// exist TOML values that cannot be placed into your representation, and there
// may be parts of your representation that do not correspond to TOML values.
// This loose mapping can be made stricter by using the IsDefined and/or
// Undecoded methods on the MetaData returned.
//
// This decoder does not handle cyclic types. Decode will not terminate if a
// cyclic type is passed.
type Decoder struct {
	r io.Reader // raw TOML input; read in full by Decode
}

// NewDecoder creates a new Decoder.
func NewDecoder(r io.Reader) *Decoder {
	return &Decoder{r: r}
}

// Interface types used to detect custom unmarshaler support via reflection.
var (
	unmarshalToml = reflect.TypeOf((*Unmarshaler)(nil)).Elem()
	unmarshalText = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
)
// Decode TOML data in to the pointer `v`.
func (dec *Decoder) Decode(v interface{}) (MetaData, error) {
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Ptr {
		// Use %v instead of %q when the type is nil so the error message
		// stays readable.
		s := "%q"
		if reflect.TypeOf(v) == nil {
			s = "%v"
		}
		return MetaData{}, e("cannot decode to non-pointer "+s, reflect.TypeOf(v))
	}
	if rv.IsNil() {
		return MetaData{}, e("cannot decode to nil value of %q", reflect.TypeOf(v))
	}

	// Check if this is a supported type: struct, map, interface{}, or something
	// that implements UnmarshalTOML or UnmarshalText.
	rv = indirect(rv)
	rt := rv.Type()
	if rv.Kind() != reflect.Struct && rv.Kind() != reflect.Map &&
		!(rv.Kind() == reflect.Interface && rv.NumMethod() == 0) &&
		!rt.Implements(unmarshalToml) && !rt.Implements(unmarshalText) {
		return MetaData{}, e("cannot decode to type %s", rt)
	}

	// TODO: parser should read from io.Reader? Or at the very least, make it
	// read from []byte rather than string
	data, err := ioutil.ReadAll(dec.r)
	if err != nil {
		return MetaData{}, err
	}
	p, err := parse(string(data))
	if err != nil {
		return MetaData{}, err
	}

	// Carry the parse results into the MetaData and unify them into rv.
	md := MetaData{
		mapping: p.mapping,
		types:   p.types,
		keys:    p.ordered,
		decoded: make(map[string]struct{}, len(p.ordered)),
		context: nil,
	}
	return md, md.unify(p.mapping, rv)
}
// Decode the TOML data in to the pointer v.
//
// See the documentation on Decoder for a description of the decoding process.
func Decode(data string, v interface{}) (MetaData, error) {
	return NewDecoder(strings.NewReader(data)).Decode(v)
}

// DecodeFile is just like Decode, except it will automatically read the
// contents of the file at path and decode it for you.
func DecodeFile(path string, v interface{}) (MetaData, error) {
	fp, err := os.Open(path)
	if err != nil {
		return MetaData{}, err
	}
	// Close the file once decoding finishes (Decode reads it in full).
	defer fp.Close()
	return NewDecoder(fp).Decode(v)
}
// unify performs a sort of type unification based on the structure of `rv`,
// which is the client representation.
//
// Any type mismatch produces an error. Finding a type that we don't know
// how to handle produces an unsupported type error.
func (md *MetaData) unify(data interface{}, rv reflect.Value) error {
	// Special case. Look for a `Primitive` value.
	// TODO: #76 would make this superfluous after implemented.
	if rv.Type() == reflect.TypeOf((*Primitive)(nil)).Elem() {
		// Save the undecoded data and the key context into the primitive
		// value.
		context := make(Key, len(md.context))
		copy(context, md.context)
		rv.Set(reflect.ValueOf(Primitive{
			undecoded: data,
			context:   context,
		}))
		return nil
	}

	// Special case. Unmarshaler Interface support.
	if rv.CanAddr() {
		if v, ok := rv.Addr().Interface().(Unmarshaler); ok {
			return v.UnmarshalTOML(data)
		}
	}

	// Special case. Look for a value satisfying the TextUnmarshaler interface.
	if v, ok := rv.Interface().(encoding.TextUnmarshaler); ok {
		return md.unifyText(data, v)
	}
	// TODO:
	// The behavior here is incorrect whenever a Go type satisfies the
	// encoding.TextUnmarshaler interface but also corresponds to a TOML hash or
	// array. In particular, the unmarshaler should only be applied to primitive
	// TOML values. But at this point, it will be applied to all kinds of values
	// and produce an incorrect error whenever those values are hashes or arrays
	// (including arrays of tables).

	k := rv.Kind()

	// laziness: every integer kind between reflect.Int and reflect.Uint64
	// shares a single handler.
	if k >= reflect.Int && k <= reflect.Uint64 {
		return md.unifyInt(data, rv)
	}
	switch k {
	case reflect.Ptr:
		// Allocate a fresh value, unify into it, then point rv at it.
		elem := reflect.New(rv.Type().Elem())
		err := md.unify(data, reflect.Indirect(elem))
		if err != nil {
			return err
		}
		rv.Set(elem)
		return nil
	case reflect.Struct:
		return md.unifyStruct(data, rv)
	case reflect.Map:
		return md.unifyMap(data, rv)
	case reflect.Array:
		return md.unifyArray(data, rv)
	case reflect.Slice:
		return md.unifySlice(data, rv)
	case reflect.String:
		return md.unifyString(data, rv)
	case reflect.Bool:
		return md.unifyBool(data, rv)
	case reflect.Interface:
		// we only support empty interfaces.
		if rv.NumMethod() > 0 {
			return e("unsupported type %s", rv.Type())
		}
		return md.unifyAnything(data, rv)
	case reflect.Float32, reflect.Float64:
		return md.unifyFloat64(data, rv)
	}
	return e("unsupported type %s", rv.Kind())
}
// unifyStruct decodes a TOML table into the fields of the struct rv.
//
// Keys are matched against the cached field list: an exact name match wins;
// otherwise the first case-insensitive match is used. A nil mapping is a
// no-op; any other non-table value is a type-mismatch error.
func (md *MetaData) unifyStruct(mapping interface{}, rv reflect.Value) error {
	tmap, ok := mapping.(map[string]interface{})
	if !ok {
		if mapping == nil {
			return nil
		}
		return e("type mismatch for %s: expected table but found %T",
			rv.Type().String(), mapping)
	}
	for key, datum := range tmap {
		var f *field
		fields := cachedTypeFields(rv.Type())
		for i := range fields {
			ff := &fields[i]
			if ff.name == key {
				// Exact match: stop looking.
				f = ff
				break
			}
			// Remember the first case-insensitive match, but keep scanning
			// in case an exact match appears later.
			if f == nil && strings.EqualFold(ff.name, key) {
				f = ff
			}
		}
		if f != nil {
			subv := rv
			// Walk down through (possibly embedded) fields to the target.
			for _, i := range f.index {
				subv = indirect(subv.Field(i))
			}
			if isUnifiable(subv) {
				// Record the fully-qualified key as decoded, and push it on
				// the context while decoding the value (for error messages).
				md.decoded[md.context.add(key).String()] = struct{}{}
				md.context = append(md.context, key)
				err := md.unify(datum, subv)
				if err != nil {
					return err
				}
				md.context = md.context[0 : len(md.context)-1]
			} else if f.name != "" {
				return e("cannot write unexported field %s.%s", rv.Type().String(), f.name)
			}
		}
	}
	return nil
}
func (md *MetaData) unifyMap(mapping interface{}, rv reflect.Value) error { | |||
if k := rv.Type().Key().Kind(); k != reflect.String { | |||
return fmt.Errorf( | |||
"toml: cannot decode to a map with non-string key type (%s in %q)", | |||
k, rv.Type()) | |||
} | |||
tmap, ok := mapping.(map[string]interface{}) | |||
if !ok { | |||
if tmap == nil { | |||
return nil | |||
} | |||
return md.badtype("map", mapping) | |||
} | |||
if rv.IsNil() { | |||
rv.Set(reflect.MakeMap(rv.Type())) | |||
} | |||
for k, v := range tmap { | |||
md.decoded[md.context.add(k).String()] = struct{}{} | |||
md.context = append(md.context, k) | |||
rvval := reflect.Indirect(reflect.New(rv.Type().Elem())) | |||
if err := md.unify(v, rvval); err != nil { | |||
return err | |||
} | |||
md.context = md.context[0 : len(md.context)-1] | |||
rvkey := indirect(reflect.New(rv.Type().Key())) | |||
rvkey.SetString(k) | |||
rv.SetMapIndex(rvkey, rvval) | |||
} | |||
return nil | |||
} | |||
func (md *MetaData) unifyArray(data interface{}, rv reflect.Value) error { | |||
datav := reflect.ValueOf(data) | |||
if datav.Kind() != reflect.Slice { | |||
if !datav.IsValid() { | |||
return nil | |||
} | |||
return md.badtype("slice", data) | |||
} | |||
if l := datav.Len(); l != rv.Len() { | |||
return e("expected array length %d; got TOML array of length %d", rv.Len(), l) | |||
} | |||
return md.unifySliceArray(datav, rv) | |||
} | |||
func (md *MetaData) unifySlice(data interface{}, rv reflect.Value) error { | |||
datav := reflect.ValueOf(data) | |||
if datav.Kind() != reflect.Slice { | |||
if !datav.IsValid() { | |||
return nil | |||
} | |||
return md.badtype("slice", data) | |||
} | |||
n := datav.Len() | |||
if rv.IsNil() || rv.Cap() < n { | |||
rv.Set(reflect.MakeSlice(rv.Type(), n, n)) | |||
} | |||
rv.SetLen(n) | |||
return md.unifySliceArray(datav, rv) | |||
} | |||
func (md *MetaData) unifySliceArray(data, rv reflect.Value) error { | |||
l := data.Len() | |||
for i := 0; i < l; i++ { | |||
err := md.unify(data.Index(i).Interface(), indirect(rv.Index(i))) | |||
if err != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
} | |||
func (md *MetaData) unifyString(data interface{}, rv reflect.Value) error { | |||
if s, ok := data.(string); ok { | |||
rv.SetString(s) | |||
return nil | |||
} | |||
return md.badtype("string", data) | |||
} | |||
func (md *MetaData) unifyFloat64(data interface{}, rv reflect.Value) error { | |||
if num, ok := data.(float64); ok { | |||
switch rv.Kind() { | |||
case reflect.Float32: | |||
if num < -math.MaxFloat32 || num > math.MaxFloat32 { | |||
return e("value %f is out of range for float32", num) | |||
} | |||
fallthrough | |||
case reflect.Float64: | |||
rv.SetFloat(num) | |||
default: | |||
panic("bug") | |||
} | |||
return nil | |||
} | |||
if num, ok := data.(int64); ok { | |||
switch rv.Kind() { | |||
case reflect.Float32: | |||
if num < -maxSafeFloat32Int || num > maxSafeFloat32Int { | |||
return e("value %d is out of range for float32", num) | |||
} | |||
fallthrough | |||
case reflect.Float64: | |||
if num < -maxSafeFloat64Int || num > maxSafeFloat64Int { | |||
return e("value %d is out of range for float64", num) | |||
} | |||
rv.SetFloat(float64(num)) | |||
default: | |||
panic("bug") | |||
} | |||
return nil | |||
} | |||
return md.badtype("float", data) | |||
} | |||
func (md *MetaData) unifyInt(data interface{}, rv reflect.Value) error { | |||
if num, ok := data.(int64); ok { | |||
if rv.Kind() >= reflect.Int && rv.Kind() <= reflect.Int64 { | |||
switch rv.Kind() { | |||
case reflect.Int, reflect.Int64: | |||
// No bounds checking necessary. | |||
case reflect.Int8: | |||
if num < math.MinInt8 || num > math.MaxInt8 { | |||
return e("value %d is out of range for int8", num) | |||
} | |||
case reflect.Int16: | |||
if num < math.MinInt16 || num > math.MaxInt16 { | |||
return e("value %d is out of range for int16", num) | |||
} | |||
case reflect.Int32: | |||
if num < math.MinInt32 || num > math.MaxInt32 { | |||
return e("value %d is out of range for int32", num) | |||
} | |||
} | |||
rv.SetInt(num) | |||
} else if rv.Kind() >= reflect.Uint && rv.Kind() <= reflect.Uint64 { | |||
unum := uint64(num) | |||
switch rv.Kind() { | |||
case reflect.Uint, reflect.Uint64: | |||
// No bounds checking necessary. | |||
case reflect.Uint8: | |||
if num < 0 || unum > math.MaxUint8 { | |||
return e("value %d is out of range for uint8", num) | |||
} | |||
case reflect.Uint16: | |||
if num < 0 || unum > math.MaxUint16 { | |||
return e("value %d is out of range for uint16", num) | |||
} | |||
case reflect.Uint32: | |||
if num < 0 || unum > math.MaxUint32 { | |||
return e("value %d is out of range for uint32", num) | |||
} | |||
} | |||
rv.SetUint(unum) | |||
} else { | |||
panic("unreachable") | |||
} | |||
return nil | |||
} | |||
return md.badtype("integer", data) | |||
} | |||
func (md *MetaData) unifyBool(data interface{}, rv reflect.Value) error { | |||
if b, ok := data.(bool); ok { | |||
rv.SetBool(b) | |||
return nil | |||
} | |||
return md.badtype("boolean", data) | |||
} | |||
func (md *MetaData) unifyAnything(data interface{}, rv reflect.Value) error { | |||
rv.Set(reflect.ValueOf(data)) | |||
return nil | |||
} | |||
// unifyText decodes a primitive TOML value into v by first rendering the
// value as text and then handing it to v's UnmarshalText.
//
// The case order of the type switch is a priority: a value implementing
// several of these interfaces is rendered by the first matching case
// (Marshaler wins over TextMarshaler, which wins over Stringer).
func (md *MetaData) unifyText(data interface{}, v encoding.TextUnmarshaler) error {
	var s string
	switch sdata := data.(type) {
	case Marshaler:
		text, err := sdata.MarshalTOML()
		if err != nil {
			return err
		}
		s = string(text)
	case TextMarshaler:
		text, err := sdata.MarshalText()
		if err != nil {
			return err
		}
		s = string(text)
	case fmt.Stringer:
		s = sdata.String()
	case string:
		s = sdata
	case bool:
		s = fmt.Sprintf("%v", sdata)
	case int64:
		s = fmt.Sprintf("%d", sdata)
	case float64:
		s = fmt.Sprintf("%f", sdata)
	default:
		// Tables and arrays have no sensible text form.
		return md.badtype("primitive (string-like)", data)
	}
	if err := v.UnmarshalText([]byte(s)); err != nil {
		return err
	}
	return nil
}
// badtype returns a consistent type-mismatch error that names the TOML key
// currently being decoded (md.context) and the offending Go/TOML types.
func (md *MetaData) badtype(dst string, data interface{}) error {
	return e("incompatible types: TOML key %q has type %T; destination has type %s", md.context, data, dst)
}
// rvalue returns a reflect.Value of `v`. All pointers are resolved, with new
// values allocated for nil pointers along the way (see indirect).
func rvalue(v interface{}) reflect.Value {
	return indirect(reflect.ValueOf(v))
}
// indirect returns the value pointed to by a pointer.
//
// Pointers are followed until the value is not a pointer. New values are
// allocated for each nil pointer.
//
// An exception to this rule is if the value satisfies an interface of interest
// to us (like encoding.TextUnmarshaler).
func indirect(v reflect.Value) reflect.Value {
	if v.Kind() != reflect.Ptr {
		if v.CanSet() {
			pv := v.Addr()
			// Prefer the pointer when it implements TextUnmarshaler, since
			// such unmarshalers typically use pointer receivers.
			if _, ok := pv.Interface().(encoding.TextUnmarshaler); ok {
				return pv
			}
		}
		return v
	}
	// Allocate through nil pointers so there is something to decode into.
	if v.IsNil() {
		v.Set(reflect.New(v.Type().Elem()))
	}
	return indirect(reflect.Indirect(v))
}
func isUnifiable(rv reflect.Value) bool { | |||
if rv.CanSet() { | |||
return true | |||
} | |||
if _, ok := rv.Interface().(encoding.TextUnmarshaler); ok { | |||
return true | |||
} | |||
return false | |||
} | |||
// e returns a formatted error with the package's "toml: " prefix prepended
// to the message.
func e(format string, args ...interface{}) error {
	return fmt.Errorf("toml: "+format, args...)
}
@@ -0,0 +1,19 @@ | |||
//go:build go1.16 | |||
// +build go1.16 | |||
package toml | |||
import ( | |||
"io/fs" | |||
) | |||
// DecodeFS is just like Decode, except it will automatically read the contents | |||
// of the file at `path` from a fs.FS instance. | |||
func DecodeFS(fsys fs.FS, path string, v interface{}) (MetaData, error) { | |||
fp, err := fsys.Open(path) | |||
if err != nil { | |||
return MetaData{}, err | |||
} | |||
defer fp.Close() | |||
return NewDecoder(fp).Decode(v) | |||
} |
@@ -0,0 +1,21 @@ | |||
package toml | |||
import ( | |||
"encoding" | |||
"io" | |||
) | |||
// TextMarshaler is an alias kept for backwards compatibility.
//
// Deprecated: use encoding.TextMarshaler
type TextMarshaler encoding.TextMarshaler
// TextUnmarshaler is an alias kept for backwards compatibility.
//
// Deprecated: use encoding.TextUnmarshaler
type TextUnmarshaler encoding.TextUnmarshaler
// PrimitiveDecode decodes a previously deferred Primitive value into v.
//
// Deprecated: use MetaData.PrimitiveDecode.
func PrimitiveDecode(primValue Primitive, v interface{}) error {
	// Decode with a fresh MetaData; the key context recorded in primValue
	// when it was first parsed is not carried over.
	md := MetaData{decoded: make(map[string]struct{})}
	return md.unify(primValue.undecoded, rvalue(v))
}
// DecodeReader reads a TOML document from r and decodes it into v.
//
// Deprecated: use NewDecoder(reader).Decode(&value).
func DecodeReader(r io.Reader, v interface{}) (MetaData, error) { return NewDecoder(r).Decode(v) }
@@ -0,0 +1,13 @@ | |||
/* | |||
Package toml implements decoding and encoding of TOML files. | |||
This package supports TOML v1.0.0, as listed on https://toml.io | |||
There is also support for delaying decoding with the Primitive type, and | |||
querying the set of keys in a TOML document with the MetaData type. | |||
The github.com/BurntSushi/toml/cmd/tomlv package implements a TOML validator, | |||
and can be used to verify if a TOML document is valid. It can also be used to
print the type of each key. | |||
*/ | |||
package toml |
@@ -0,0 +1,694 @@ | |||
package toml | |||
import ( | |||
"bufio" | |||
"encoding" | |||
"errors" | |||
"fmt" | |||
"io" | |||
"math" | |||
"reflect" | |||
"sort" | |||
"strconv" | |||
"strings" | |||
"time" | |||
"github.com/BurntSushi/toml/internal" | |||
) | |||
// tomlEncodeError wraps an error raised via encPanic so that safeEncode can
// tell encoder-originated panics apart from genuine runtime panics.
type tomlEncodeError struct{ error }
// Sentinel errors reported by the encoder (raised internally via encPanic).
var (
	errArrayNilElement = errors.New("toml: cannot encode array with nil element")
	errNonString       = errors.New("toml: cannot encode a map with non-string key type")
	errNoKey           = errors.New("toml: top-level values must be Go maps or structs")
	errAnything        = errors.New("") // used in testing
)
// dblQuotedReplacer escapes everything that cannot appear literally inside a
// TOML basic (double-quoted) string: the quote and backslash themselves, all
// control characters (using the short \b, \t, \n, \f, \r forms where TOML
// defines them and \uXXXX otherwise), and DEL (0x7f).
var dblQuotedReplacer = strings.NewReplacer(
	"\"", "\\\"",
	"\\", "\\\\",
	"\x00", `\u0000`,
	"\x01", `\u0001`,
	"\x02", `\u0002`,
	"\x03", `\u0003`,
	"\x04", `\u0004`,
	"\x05", `\u0005`,
	"\x06", `\u0006`,
	"\x07", `\u0007`,
	"\b", `\b`,
	"\t", `\t`,
	"\n", `\n`,
	"\x0b", `\u000b`,
	"\f", `\f`,
	"\r", `\r`,
	"\x0e", `\u000e`,
	"\x0f", `\u000f`,
	"\x10", `\u0010`,
	"\x11", `\u0011`,
	"\x12", `\u0012`,
	"\x13", `\u0013`,
	"\x14", `\u0014`,
	"\x15", `\u0015`,
	"\x16", `\u0016`,
	"\x17", `\u0017`,
	"\x18", `\u0018`,
	"\x19", `\u0019`,
	"\x1a", `\u001a`,
	"\x1b", `\u001b`,
	"\x1c", `\u001c`,
	"\x1d", `\u001d`,
	"\x1e", `\u001e`,
	"\x1f", `\u001f`,
	"\x7f", `\u007f`,
)
// Marshaler is the interface implemented by types that can marshal themselves
// into valid TOML.
//
// Note that when used as an array element or plain value, the encoder writes
// the returned text as a quoted TOML string (see eElement).
type Marshaler interface {
	MarshalTOML() ([]byte, error)
}
// Encoder encodes a Go value to a TOML document.
//
// The mapping between Go values and TOML values should be precisely the same as
// for the Decode* functions.
//
// The toml.Marshaler and encoding.TextMarshaler interfaces are supported to
// encode the value as custom TOML.
//
// If you want to write arbitrary binary data then you will need to use
// something like base64 since TOML does not have any binary types.
//
// When encoding TOML hashes (Go maps or structs), keys without any sub-hashes
// are encoded first.
//
// Go maps will be sorted alphabetically by key for deterministic output.
//
// Encoding Go values without a corresponding TOML representation will return an
// error. Examples of this includes maps with non-string keys, slices with nil
// elements, embedded non-struct types, and nested slices containing maps or
// structs. (e.g. [][]map[string]string is not allowed but []map[string]string
// is okay, as is []map[string][]string).
//
// NOTE: only exported keys are encoded due to the use of reflection. Unexported
// keys are silently discarded.
type Encoder struct {
	// String to use for a single indentation level; default is two spaces.
	Indent string
	w          *bufio.Writer
	hasWritten bool // written any output to w yet?
}
// NewEncoder creates a new Encoder that writes to w, using two spaces of
// indentation per level by default.
func NewEncoder(w io.Writer) *Encoder {
	return &Encoder{
		w:      bufio.NewWriter(w),
		Indent: "  ",
	}
}
// Encode writes a TOML representation of the Go value to the Encoder's writer.
//
// An error is returned if the value given cannot be encoded to a valid TOML
// document.
func (enc *Encoder) Encode(v interface{}) error {
	// Resolve pointers/interfaces first so encode() sees a concrete value.
	rv := eindirect(reflect.ValueOf(v))
	if err := enc.safeEncode(Key([]string{}), rv); err != nil {
		return err
	}
	// Output is buffered; nothing reaches the caller's writer until here.
	return enc.w.Flush()
}
// safeEncode runs enc.encode and converts panics carrying a tomlEncodeError
// (raised internally via encPanic) back into ordinary returned errors. Any
// other panic is re-raised unchanged.
func (enc *Encoder) safeEncode(key Key, rv reflect.Value) (err error) {
	defer func() {
		if r := recover(); r != nil {
			if terr, ok := r.(tomlEncodeError); ok {
				err = terr.error
				return
			}
			panic(r)
		}
	}()
	enc.encode(key, rv)
	return nil
}
// encode writes the TOML form of rv under the given key, dispatching on how
// the value must be represented: a plain key/value, an array of tables, or a
// table.
func (enc *Encoder) encode(key Key, rv reflect.Value) {
	// Special case: time needs to be in ISO8601 format.
	//
	// Special case: if we can marshal the type to text, then we used that. This
	// prevents the encoder for handling these types as generic structs (or
	// whatever the underlying type of a TextMarshaler is).
	switch t := rv.Interface().(type) {
	case time.Time, encoding.TextMarshaler, Marshaler:
		enc.writeKeyValue(key, rv, false)
		return
	// TODO: #76 would make this superfluous after implemented.
	case Primitive:
		// Re-encode the raw, still-undecoded value.
		enc.encode(key, reflect.ValueOf(t.undecoded))
		return
	}
	k := rv.Kind()
	switch k {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
		reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
		reflect.Uint64,
		reflect.Float32, reflect.Float64, reflect.String, reflect.Bool:
		enc.writeKeyValue(key, rv, false)
	case reflect.Array, reflect.Slice:
		if typeEqual(tomlArrayHash, tomlTypeOfGo(rv)) {
			// A slice of tables is written as [[key]] blocks.
			enc.eArrayOfTables(key, rv)
		} else {
			enc.writeKeyValue(key, rv, false)
		}
	case reflect.Interface:
		// Nil interfaces/maps/pointers produce no output at all.
		if rv.IsNil() {
			return
		}
		enc.encode(key, rv.Elem())
	case reflect.Map:
		if rv.IsNil() {
			return
		}
		enc.eTable(key, rv)
	case reflect.Ptr:
		if rv.IsNil() {
			return
		}
		enc.encode(key, rv.Elem())
	case reflect.Struct:
		enc.eTable(key, rv)
	default:
		encPanic(fmt.Errorf("unsupported type for key '%s': %s", key, k))
	}
}
// eElement encodes any value that can be an array element.
//
// The type switch is a priority order: time.Time wins over Marshaler, which
// wins over TextMarshaler, for values implementing more than one.
func (enc *Encoder) eElement(rv reflect.Value) {
	switch v := rv.Interface().(type) {
	case time.Time: // Using TextMarshaler adds extra quotes, which we don't want.
		format := time.RFC3339Nano
		// Local datetimes/dates/times are represented with sentinel
		// Locations from the internal package; pick the layout accordingly.
		switch v.Location() {
		case internal.LocalDatetime:
			format = "2006-01-02T15:04:05.999999999"
		case internal.LocalDate:
			format = "2006-01-02"
		case internal.LocalTime:
			format = "15:04:05.999999999"
		}
		switch v.Location() {
		default:
			enc.wf(v.Format(format))
		case internal.LocalDatetime, internal.LocalDate, internal.LocalTime:
			// Local values carry no offset; format in UTC so no zone shift
			// is applied to the wall-clock fields.
			enc.wf(v.In(time.UTC).Format(format))
		}
		return
	case Marshaler:
		s, err := v.MarshalTOML()
		if err != nil {
			encPanic(err)
		}
		enc.writeQuoted(string(s))
		return
	case encoding.TextMarshaler:
		s, err := v.MarshalText()
		if err != nil {
			encPanic(err)
		}
		enc.writeQuoted(string(s))
		return
	}
	switch rv.Kind() {
	case reflect.String:
		enc.writeQuoted(rv.String())
	case reflect.Bool:
		enc.wf(strconv.FormatBool(rv.Bool()))
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		enc.wf(strconv.FormatInt(rv.Int(), 10))
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		enc.wf(strconv.FormatUint(rv.Uint(), 10))
	case reflect.Float32:
		f := rv.Float()
		if math.IsNaN(f) {
			enc.wf("nan")
		} else if math.IsInf(f, 0) {
			// TOML spells infinities as +inf / -inf.
			enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)])
		} else {
			enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 32)))
		}
	case reflect.Float64:
		f := rv.Float()
		if math.IsNaN(f) {
			enc.wf("nan")
		} else if math.IsInf(f, 0) {
			enc.wf("%cinf", map[bool]byte{true: '-', false: '+'}[math.Signbit(f)])
		} else {
			enc.wf(floatAddDecimal(strconv.FormatFloat(f, 'f', -1, 64)))
		}
	case reflect.Array, reflect.Slice:
		enc.eArrayOrSliceElement(rv)
	case reflect.Struct:
		// Structs and maps inside arrays are written as inline tables.
		enc.eStruct(nil, rv, true)
	case reflect.Map:
		enc.eMap(nil, rv, true)
	case reflect.Interface:
		enc.eElement(rv.Elem())
	default:
		encPanic(fmt.Errorf("unexpected primitive type: %T", rv.Interface()))
	}
}
// floatAddDecimal ensures a formatted float contains a decimal point, since
// by the TOML spec all floats must have at least one digit on either side of
// one (e.g. "3" becomes "3.0").
func floatAddDecimal(fstr string) string {
	if strings.Contains(fstr, ".") {
		return fstr
	}
	return fstr + ".0"
}
// writeQuoted writes s as a TOML basic string: double-quoted, with special
// characters escaped via dblQuotedReplacer.
func (enc *Encoder) writeQuoted(s string) {
	enc.wf("\"%s\"", dblQuotedReplacer.Replace(s))
}
func (enc *Encoder) eArrayOrSliceElement(rv reflect.Value) { | |||
length := rv.Len() | |||
enc.wf("[") | |||
for i := 0; i < length; i++ { | |||
elem := rv.Index(i) | |||
enc.eElement(elem) | |||
if i != length-1 { | |||
enc.wf(", ") | |||
} | |||
} | |||
enc.wf("]") | |||
} | |||
// eArrayOfTables writes rv as a TOML array of tables: one [[key]] header per
// non-nil element. An empty (top-level) key is invalid here.
func (enc *Encoder) eArrayOfTables(key Key, rv reflect.Value) {
	if len(key) == 0 {
		encPanic(errNoKey)
	}
	for i := 0; i < rv.Len(); i++ {
		trv := rv.Index(i)
		if isNil(trv) {
			// Skip nil elements rather than emitting an empty table.
			continue
		}
		enc.newline()
		enc.wf("%s[[%s]]", enc.indentStr(key), key)
		enc.newline()
		enc.eMapOrStruct(key, trv, false)
	}
}
// eTable writes a [key] header (unless rv is the top-level value, whose key
// is empty) followed by the table body.
func (enc *Encoder) eTable(key Key, rv reflect.Value) {
	if len(key) == 1 {
		// Output an extra newline between top-level tables.
		// (The newline isn't written if nothing else has been written though.)
		enc.newline()
	}
	if len(key) > 0 {
		enc.wf("%s[%s]", enc.indentStr(key), key)
		enc.newline()
	}
	enc.eMapOrStruct(key, rv, false)
}
func (enc *Encoder) eMapOrStruct(key Key, rv reflect.Value, inline bool) { | |||
switch rv := eindirect(rv); rv.Kind() { | |||
case reflect.Map: | |||
enc.eMap(key, rv, inline) | |||
case reflect.Struct: | |||
enc.eStruct(key, rv, inline) | |||
default: | |||
// Should never happen? | |||
panic("eTable: unhandled reflect.Value Kind: " + rv.Kind().String()) | |||
} | |||
} | |||
// eMap writes a map as a table body (or inline table when inline is true).
// Keys are sorted for deterministic output, and keys whose values are not
// themselves tables are written first so they stay in this table.
func (enc *Encoder) eMap(key Key, rv reflect.Value, inline bool) {
	rt := rv.Type()
	if rt.Key().Kind() != reflect.String {
		encPanic(errNonString)
	}
	// Sort keys so that we have deterministic output. And write keys directly
	// underneath this key first, before writing sub-structs or sub-maps.
	var mapKeysDirect, mapKeysSub []string
	for _, mapKey := range rv.MapKeys() {
		k := mapKey.String()
		if typeIsTable(tomlTypeOfGo(rv.MapIndex(mapKey))) {
			mapKeysSub = append(mapKeysSub, k)
		} else {
			mapKeysDirect = append(mapKeysDirect, k)
		}
	}
	var writeMapKeys = func(mapKeys []string, trailC bool) {
		sort.Strings(mapKeys)
		for i, mapKey := range mapKeys {
			val := rv.MapIndex(reflect.ValueOf(mapKey))
			if isNil(val) {
				// Nil values are simply omitted.
				continue
			}
			if inline {
				enc.writeKeyValue(Key{mapKey}, val, true)
				// trailC bridges the two groups: the direct group keeps a
				// trailing comma when a sub group follows it.
				if trailC || i != len(mapKeys)-1 {
					enc.wf(", ")
				}
			} else {
				enc.encode(key.add(mapKey), val)
			}
		}
	}
	if inline {
		enc.wf("{")
	}
	writeMapKeys(mapKeysDirect, len(mapKeysSub) > 0)
	writeMapKeys(mapKeysSub, false)
	if inline {
		enc.wf("}")
	}
}
// is32Bit reports whether the native uint is 32 bits wide; used by eStruct
// to work around a slice-aliasing issue on 32-bit platforms (see #314).
const is32Bit = (32 << (^uint(0) >> 63)) == 32
func (enc *Encoder) eStruct(key Key, rv reflect.Value, inline bool) { | |||
// Write keys for fields directly under this key first, because if we write | |||
// a field that creates a new table then all keys under it will be in that | |||
// table (not the one we're writing here). | |||
// | |||
// Fields is a [][]int: for fieldsDirect this always has one entry (the | |||
// struct index). For fieldsSub it contains two entries: the parent field | |||
// index from tv, and the field indexes for the fields of the sub. | |||
var ( | |||
rt = rv.Type() | |||
fieldsDirect, fieldsSub [][]int | |||
addFields func(rt reflect.Type, rv reflect.Value, start []int) | |||
) | |||
addFields = func(rt reflect.Type, rv reflect.Value, start []int) { | |||
for i := 0; i < rt.NumField(); i++ { | |||
f := rt.Field(i) | |||
if f.PkgPath != "" && !f.Anonymous { /// Skip unexported fields. | |||
continue | |||
} | |||
frv := rv.Field(i) | |||
// Treat anonymous struct fields with tag names as though they are | |||
// not anonymous, like encoding/json does. | |||
// | |||
// Non-struct anonymous fields use the normal encoding logic. | |||
if f.Anonymous { | |||
t := f.Type | |||
switch t.Kind() { | |||
case reflect.Struct: | |||
if getOptions(f.Tag).name == "" { | |||
addFields(t, frv, append(start, f.Index...)) | |||
continue | |||
} | |||
case reflect.Ptr: | |||
if t.Elem().Kind() == reflect.Struct && getOptions(f.Tag).name == "" { | |||
if !frv.IsNil() { | |||
addFields(t.Elem(), frv.Elem(), append(start, f.Index...)) | |||
} | |||
continue | |||
} | |||
} | |||
} | |||
if typeIsTable(tomlTypeOfGo(frv)) { | |||
fieldsSub = append(fieldsSub, append(start, f.Index...)) | |||
} else { | |||
// Copy so it works correct on 32bit archs; not clear why this | |||
// is needed. See #314, and https://www.reddit.com/r/golang/comments/pnx8v4 | |||
// This also works fine on 64bit, but 32bit archs are somewhat | |||
// rare and this is a wee bit faster. | |||
if is32Bit { | |||
copyStart := make([]int, len(start)) | |||
copy(copyStart, start) | |||
fieldsDirect = append(fieldsDirect, append(copyStart, f.Index...)) | |||
} else { | |||
fieldsDirect = append(fieldsDirect, append(start, f.Index...)) | |||
} | |||
} | |||
} | |||
} | |||
addFields(rt, rv, nil) | |||
writeFields := func(fields [][]int) { | |||
for _, fieldIndex := range fields { | |||
fieldType := rt.FieldByIndex(fieldIndex) | |||
fieldVal := rv.FieldByIndex(fieldIndex) | |||
if isNil(fieldVal) { /// Don't write anything for nil fields. | |||
continue | |||
} | |||
opts := getOptions(fieldType.Tag) | |||
if opts.skip { | |||
continue | |||
} | |||
keyName := fieldType.Name | |||
if opts.name != "" { | |||
keyName = opts.name | |||
} | |||
if opts.omitempty && isEmpty(fieldVal) { | |||
continue | |||
} | |||
if opts.omitzero && isZero(fieldVal) { | |||
continue | |||
} | |||
if inline { | |||
enc.writeKeyValue(Key{keyName}, fieldVal, true) | |||
if fieldIndex[0] != len(fields)-1 { | |||
enc.wf(", ") | |||
} | |||
} else { | |||
enc.encode(key.add(keyName), fieldVal) | |||
} | |||
} | |||
} | |||
if inline { | |||
enc.wf("{") | |||
} | |||
writeFields(fieldsDirect) | |||
writeFields(fieldsSub) | |||
if inline { | |||
enc.wf("}") | |||
} | |||
} | |||
// tomlTypeOfGo returns the TOML type name of the Go value's type.
//
// It is used to determine whether the types of array elements are mixed (which
// is forbidden). If the Go value is nil, then it is illegal for it to be an
// array element, and valueIsNil is returned as true.
//
// The type may be `nil`, which means no concrete TOML type could be found.
func tomlTypeOfGo(rv reflect.Value) tomlType {
	if isNil(rv) || !rv.IsValid() {
		return nil
	}
	switch rv.Kind() {
	case reflect.Bool:
		return tomlBool
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
		reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
		reflect.Uint64:
		return tomlInteger
	case reflect.Float32, reflect.Float64:
		return tomlFloat
	case reflect.Array, reflect.Slice:
		// An array whose elements are tables is its own TOML type
		// (array of tables, written as [[key]]).
		if typeEqual(tomlHash, tomlArrayType(rv)) {
			return tomlArrayHash
		}
		return tomlArray
	case reflect.Ptr, reflect.Interface:
		return tomlTypeOfGo(rv.Elem())
	case reflect.String:
		return tomlString
	case reflect.Map:
		return tomlHash
	case reflect.Struct:
		if _, ok := rv.Interface().(time.Time); ok {
			return tomlDatetime
		}
		// Types with a custom marshaler are emitted as quoted strings.
		if isMarshaler(rv) {
			return tomlString
		}
		return tomlHash
	default:
		if isMarshaler(rv) {
			return tomlString
		}
		encPanic(errors.New("unsupported type: " + rv.Kind().String()))
		panic("unreachable")
	}
}
func isMarshaler(rv reflect.Value) bool { | |||
switch rv.Interface().(type) { | |||
case encoding.TextMarshaler: | |||
return true | |||
case Marshaler: | |||
return true | |||
} | |||
// Someone used a pointer receiver: we can make it work for pointer values. | |||
if rv.CanAddr() { | |||
if _, ok := rv.Addr().Interface().(encoding.TextMarshaler); ok { | |||
return true | |||
} | |||
if _, ok := rv.Addr().Interface().(Marshaler); ok { | |||
return true | |||
} | |||
} | |||
return false | |||
} | |||
// tomlArrayType returns the element type of a TOML array. The type returned | |||
// may be nil if it cannot be determined (e.g., a nil slice or a zero length | |||
// slize). This function may also panic if it finds a type that cannot be | |||
// expressed in TOML (such as nil elements, heterogeneous arrays or directly | |||
// nested arrays of tables). | |||
func tomlArrayType(rv reflect.Value) tomlType { | |||
if isNil(rv) || !rv.IsValid() || rv.Len() == 0 { | |||
return nil | |||
} | |||
/// Don't allow nil. | |||
rvlen := rv.Len() | |||
for i := 1; i < rvlen; i++ { | |||
if tomlTypeOfGo(rv.Index(i)) == nil { | |||
encPanic(errArrayNilElement) | |||
} | |||
} | |||
firstType := tomlTypeOfGo(rv.Index(0)) | |||
if firstType == nil { | |||
encPanic(errArrayNilElement) | |||
} | |||
return firstType | |||
} | |||
type tagOptions struct { | |||
skip bool // "-" | |||
name string | |||
omitempty bool | |||
omitzero bool | |||
} | |||
func getOptions(tag reflect.StructTag) tagOptions { | |||
t := tag.Get("toml") | |||
if t == "-" { | |||
return tagOptions{skip: true} | |||
} | |||
var opts tagOptions | |||
parts := strings.Split(t, ",") | |||
opts.name = parts[0] | |||
for _, s := range parts[1:] { | |||
switch s { | |||
case "omitempty": | |||
opts.omitempty = true | |||
case "omitzero": | |||
opts.omitzero = true | |||
} | |||
} | |||
return opts | |||
} | |||
func isZero(rv reflect.Value) bool { | |||
switch rv.Kind() { | |||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
return rv.Int() == 0 | |||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | |||
return rv.Uint() == 0 | |||
case reflect.Float32, reflect.Float64: | |||
return rv.Float() == 0.0 | |||
} | |||
return false | |||
} | |||
func isEmpty(rv reflect.Value) bool { | |||
switch rv.Kind() { | |||
case reflect.Array, reflect.Slice, reflect.Map, reflect.String: | |||
return rv.Len() == 0 | |||
case reflect.Bool: | |||
return !rv.Bool() | |||
} | |||
return false | |||
} | |||
// newline writes a "\n", but only once something else has been written;
// this avoids a leading blank line at the top of the document.
func (enc *Encoder) newline() {
	if enc.hasWritten {
		enc.wf("\n")
	}
}
// Write a key/value pair:
//
//   key = <any value>
//
// This is also used for "k = v" in inline tables; so something like this will
// be written in three calls:
//
//   ┌────────────────────┐
//   │      ┌───┐  ┌─────┐│
//   v      v   v  v     vv
//   key = {k = v, k2 = v2}
//
// An empty key (only possible for the top-level value) is invalid here and
// panics with errNoKey.
func (enc *Encoder) writeKeyValue(key Key, val reflect.Value, inline bool) {
	if len(key) == 0 {
		encPanic(errNoKey)
	}
	// Only the last key component is written; parents were emitted as
	// table headers. Quote it if needed.
	enc.wf("%s%s = ", enc.indentStr(key), key.maybeQuoted(len(key)-1))
	enc.eElement(val)
	if !inline {
		enc.newline()
	}
}
// wf writes a formatted string to the buffered writer; write errors abort
// encoding via encPanic (recovered into an error by safeEncode).
func (enc *Encoder) wf(format string, v ...interface{}) {
	_, err := fmt.Fprintf(enc.w, format, v...)
	if err != nil {
		encPanic(err)
	}
	enc.hasWritten = true
}
// indentStr returns the indentation prefix for a key at its depth; depth-one
// keys get no indent. Callers must not pass an empty key (strings.Repeat
// would panic on the negative count).
func (enc *Encoder) indentStr(key Key) string {
	return strings.Repeat(enc.Indent, len(key)-1)
}
// encPanic aborts encoding with err; safeEncode recovers the wrapped error
// and returns it to the caller.
func encPanic(err error) {
	panic(tomlEncodeError{err})
}
func eindirect(v reflect.Value) reflect.Value { | |||
switch v.Kind() { | |||
case reflect.Ptr, reflect.Interface: | |||
return eindirect(v.Elem()) | |||
default: | |||
return v | |||
} | |||
} | |||
func isNil(rv reflect.Value) bool { | |||
switch rv.Kind() { | |||
case reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice: | |||
return rv.IsNil() | |||
default: | |||
return false | |||
} | |||
} |
@@ -0,0 +1,229 @@ | |||
package toml | |||
import ( | |||
"fmt" | |||
"strings" | |||
) | |||
// ParseError is returned when there is an error parsing the TOML syntax.
//
// For example invalid syntax, duplicate keys, etc.
//
// In addition to the error message itself, you can also print detailed location
// information with context by using ErrorWithPosition():
//
//   toml: error: Key 'fruit' was already created and cannot be used as an array.
//
//   At line 4, column 2-7:
//
//         2 | fruit = []
//         3 |
//         4 | [[fruit]] # Not allowed
//               ^^^^^
//
// Furthermore, the ErrorWithUsage() can be used to print the above with some
// more detailed usage guidance:
//
//   toml: error: newlines not allowed within inline tables
//
//   At line 1, column 18:
//
//         1 | x = [{ key = 42 #
//                            ^
//
//   Error help:
//
//     Inline tables must always be on a single line:
//
//         table = {key = 42, second = 43}
//
//     It is invalid to split them over multiple lines like so:
//
//         # INVALID
//         table = {
//             key    = 42,
//             second = 43
//         }
//
//     Use regular tables for this:
//
//         [table]
//         key    = 42
//         second = 43
type ParseError struct {
	Message  string   // Short technical message.
	Usage    string   // Longer message with usage guidance; may be blank.
	Position Position // Position of the error
	LastKey  string   // Last parsed key, may be blank.
	Line     int      // Line the error occurred. Deprecated: use Position.
	err      error
	input    string
}
// Position of an error.
type Position struct {
	Line  int // Line number, starting at 1.
	Start int // Start of error, as byte offset starting at 0.
	Len   int // Length in bytes.
}
func (pe ParseError) Error() string { | |||
msg := pe.Message | |||
if msg == "" { // Error from errorf() | |||
msg = pe.err.Error() | |||
} | |||
if pe.LastKey == "" { | |||
return fmt.Sprintf("toml: line %d: %s", pe.Position.Line, msg) | |||
} | |||
return fmt.Sprintf("toml: line %d (last key %q): %s", | |||
pe.Position.Line, pe.LastKey, msg) | |||
} | |||
// ErrorWithPosition() returns the error with detailed location context.
//
// See the documentation on ParseError.
func (pe ParseError) ErrorWithPosition() string {
	if pe.input == "" { // Should never happen, but just in case.
		return pe.Error()
	}
	var (
		lines = strings.Split(pe.input, "\n")
		col   = pe.column(lines)
		b     = new(strings.Builder)
	)
	msg := pe.Message
	if msg == "" {
		msg = pe.err.Error()
	}
	// TODO: don't show control characters as literals? This may not show up
	// well everywhere.
	if pe.Position.Len == 1 {
		fmt.Fprintf(b, "toml: error: %s\n\nAt line %d, column %d:\n\n",
			msg, pe.Position.Line, col+1)
	} else {
		fmt.Fprintf(b, "toml: error: %s\n\nAt line %d, column %d-%d:\n\n",
			msg, pe.Position.Line, col, col+pe.Position.Len)
	}
	// Print up to two lines of leading context, then the offending line
	// itself, then a caret marker under the error span.
	if pe.Position.Line > 2 {
		fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-2, lines[pe.Position.Line-3])
	}
	if pe.Position.Line > 1 {
		fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line-1, lines[pe.Position.Line-2])
	}
	fmt.Fprintf(b, "% 7d | %s\n", pe.Position.Line, lines[pe.Position.Line-1])
	fmt.Fprintf(b, "% 10s%s%s\n", "", strings.Repeat(" ", col), strings.Repeat("^", pe.Position.Len))
	return b.String()
}
// ErrorWithUsage() returns the error with detailed location context and usage | |||
// guidance. | |||
// | |||
// See the documentation on ParseError. | |||
func (pe ParseError) ErrorWithUsage() string { | |||
m := pe.ErrorWithPosition() | |||
if u, ok := pe.err.(interface{ Usage() string }); ok && u.Usage() != "" { | |||
return m + "Error help:\n\n " + | |||
strings.ReplaceAll(strings.TrimSpace(u.Usage()), "\n", "\n ") + | |||
"\n" | |||
} | |||
return m | |||
} | |||
func (pe ParseError) column(lines []string) int { | |||
var pos, col int | |||
for i := range lines { | |||
ll := len(lines[i]) + 1 // +1 for the removed newline | |||
if pos+ll >= pe.Position.Start { | |||
col = pe.Position.Start - pos | |||
if col < 0 { // Should never happen, but just in case. | |||
col = 0 | |||
} | |||
break | |||
} | |||
pos += ll | |||
} | |||
return col | |||
} | |||
// Structured lexer errors. Each carries just enough data to render a
// message, and each implements Usage() so ErrorWithUsage can attach
// extended guidance (an empty Usage means there is none).
type (
	errLexControl       struct{ r rune }
	errLexEscape        struct{ r rune }
	errLexUTF8          struct{ b byte }
	errLexInvalidNum    struct{ v string }
	errLexInvalidDate   struct{ v string }
	errLexInlineTableNL struct{}
	errLexStringNL      struct{}
)
// Error/Usage implementations for the lexer error types above; Usage
// returns "" when no extended help text exists for that error.
func (e errLexControl) Error() string {
	return fmt.Sprintf("TOML files cannot contain control characters: '0x%02x'", e.r)
}
func (e errLexControl) Usage() string { return "" }
func (e errLexEscape) Error() string        { return fmt.Sprintf(`invalid escape in string '\%c'`, e.r) }
func (e errLexEscape) Usage() string        { return usageEscape }
func (e errLexUTF8) Error() string          { return fmt.Sprintf("invalid UTF-8 byte: 0x%02x", e.b) }
func (e errLexUTF8) Usage() string          { return "" }
func (e errLexInvalidNum) Error() string    { return fmt.Sprintf("invalid number: %q", e.v) }
func (e errLexInvalidNum) Usage() string    { return "" }
func (e errLexInvalidDate) Error() string   { return fmt.Sprintf("invalid date: %q", e.v) }
func (e errLexInvalidDate) Usage() string   { return "" }
func (e errLexInlineTableNL) Error() string { return "newlines not allowed within inline tables" }
func (e errLexInlineTableNL) Usage() string { return usageInlineNewline }
func (e errLexStringNL) Error() string      { return "strings cannot contain newlines" }
func (e errLexStringNL) Usage() string      { return usageStringNewline }
// usageEscape is the extended help printed by ErrorWithUsage for invalid
// escape sequences in basic ("-delimited) strings.
const usageEscape = `
A '\' inside a "-delimited string is interpreted as an escape character.
The following escape sequences are supported:
\b, \t, \n, \f, \r, \", \\, \uXXXX, and \UXXXXXXXX
To prevent a '\' from being recognized as an escape character, use either:
- a ' or '''-delimited string; escape characters aren't processed in them; or
- write two backslashes to get a single backslash: '\\'.
If you're trying to add a Windows path (e.g. "C:\Users\martin") then using '/'
instead of '\' will usually also work: "C:/Users/martin".
`
// usageInlineNewline is the extended help printed by ErrorWithUsage when a
// newline appears inside an inline table.
const usageInlineNewline = `
Inline tables must always be on a single line:
table = {key = 42, second = 43}
It is invalid to split them over multiple lines like so:
# INVALID
table = {
key = 42,
second = 43
}
Use regular for this:
[table]
key = 42
second = 43
`
// usageStringNewline is the extended help printed by ErrorWithUsage when a
// newline appears inside a single-line string.
const usageStringNewline = `
Strings must always be on a single line, and cannot span more than one line:
# INVALID
string = "Hello,
world!"
Instead use """ or ''' to split strings over multiple lines:
string = """Hello,
world!"""
`
@@ -0,0 +1,36 @@ | |||
package internal | |||
import "time" | |||
// Timezones used for local datetime, date, and time TOML types. | |||
// | |||
// The exact way times and dates without a timezone should be interpreted is not | |||
// well-defined in the TOML specification and left to the implementation. These | |||
// defaults to current local timezone offset of the computer, but this can be | |||
// changed by changing these variables before decoding. | |||
// | |||
// TODO: | |||
// Ideally we'd like to offer people the ability to configure the used timezone | |||
// by setting Decoder.Timezone and Encoder.Timezone; however, this is a bit | |||
// tricky: the reason we use three different variables for this is to support | |||
// round-tripping – without these specific TZ names we wouldn't know which | |||
// format to use. | |||
// | |||
// There isn't a good way to encode this right now though, and passing this sort | |||
// of information also ties in to various related issues such as string format | |||
// encoding, encoding of comments, etc. | |||
// | |||
// So, for the time being, just put this in internal until we can write a good | |||
// comprehensive API for doing all of this. | |||
// | |||
// The reason they're exported is because they're referred from in e.g. | |||
// internal/tag. | |||
// | |||
// Note that this behaviour is valid according to the TOML spec as the exact | |||
// behaviour is left up to implementations. | |||
var (
	// localOffset is the machine's local UTC offset in seconds, sampled
	// once at package initialization.
	localOffset = func() int { _, o := time.Now().Zone(); return o }()
	// LocalDatetime is the zone attached to TOML local date-time values.
	LocalDatetime = time.FixedZone("datetime-local", localOffset)
	// LocalDate is the zone attached to TOML local date values.
	LocalDate = time.FixedZone("date-local", localOffset)
	// LocalTime is the zone attached to TOML local time values.
	LocalTime = time.FixedZone("time-local", localOffset)
)
@@ -0,0 +1,120 @@ | |||
package toml | |||
import ( | |||
"strings" | |||
) | |||
// MetaData allows access to meta information about TOML data that's not
// accessible otherwise.
//
// It allows checking if a key is defined in the TOML data, whether any keys
// were undecoded, and the TOML type of a key.
type MetaData struct {
	context Key // Used only during decoding.
	mapping map[string]interface{} // Parsed TOML data as nested string-keyed maps.
	types   map[string]tomlType    // TOML type per fully-qualified key string.
	keys    []Key                  // Every key, in document order.
	decoded map[string]struct{}    // Set of key strings that have been decoded.
}
// IsDefined reports if the key exists in the TOML data. | |||
// | |||
// The key should be specified hierarchically, for example to access the TOML | |||
// key "a.b.c" you would use IsDefined("a", "b", "c"). Keys are case sensitive. | |||
// | |||
// Returns false for an empty key. | |||
func (md *MetaData) IsDefined(key ...string) bool { | |||
if len(key) == 0 { | |||
return false | |||
} | |||
var ( | |||
hash map[string]interface{} | |||
ok bool | |||
hashOrVal interface{} = md.mapping | |||
) | |||
for _, k := range key { | |||
if hash, ok = hashOrVal.(map[string]interface{}); !ok { | |||
return false | |||
} | |||
if hashOrVal, ok = hash[k]; !ok { | |||
return false | |||
} | |||
} | |||
return true | |||
} | |||
// Type returns a string representation of the type of the key specified. | |||
// | |||
// Type will return the empty string if given an empty key or a key that does | |||
// not exist. Keys are case sensitive. | |||
func (md *MetaData) Type(key ...string) string { | |||
if typ, ok := md.types[Key(key).String()]; ok { | |||
return typ.typeString() | |||
} | |||
return "" | |||
} | |||
// Keys returns a slice of every key in the TOML data, including key groups.
//
// Each key is itself a slice, where the first element is the top of the
// hierarchy and the last is the most specific. The list will have the same
// order as the keys appeared in the TOML data.
//
// All keys returned are non-empty.
//
// Note that the internal slice is returned directly; callers should treat
// it as read-only.
func (md *MetaData) Keys() []Key {
	return md.keys
}
// Undecoded returns all keys that have not been decoded in the order in which | |||
// they appear in the original TOML document. | |||
// | |||
// This includes keys that haven't been decoded because of a Primitive value. | |||
// Once the Primitive value is decoded, the keys will be considered decoded. | |||
// | |||
// Also note that decoding into an empty interface will result in no decoding, | |||
// and so no keys will be considered decoded. | |||
// | |||
// In this sense, the Undecoded keys correspond to keys in the TOML document | |||
// that do not have a concrete type in your representation. | |||
func (md *MetaData) Undecoded() []Key { | |||
undecoded := make([]Key, 0, len(md.keys)) | |||
for _, key := range md.keys { | |||
if _, ok := md.decoded[key.String()]; !ok { | |||
undecoded = append(undecoded, key) | |||
} | |||
} | |||
return undecoded | |||
} | |||
// Key represents any TOML key, including key groups. Use (MetaData).Keys to get | |||
// values of this type. | |||
type Key []string | |||
func (k Key) String() string { | |||
ss := make([]string, len(k)) | |||
for i := range k { | |||
ss[i] = k.maybeQuoted(i) | |||
} | |||
return strings.Join(ss, ".") | |||
} | |||
func (k Key) maybeQuoted(i int) string { | |||
if k[i] == "" { | |||
return `""` | |||
} | |||
for _, c := range k[i] { | |||
if !isBareKeyChar(c) { | |||
return `"` + dblQuotedReplacer.Replace(k[i]) + `"` | |||
} | |||
} | |||
return k[i] | |||
} | |||
func (k Key) add(piece string) Key { | |||
newKey := make(Key, len(k)+1) | |||
copy(newKey, k) | |||
newKey[len(k)] = piece | |||
return newKey | |||
} |
@@ -0,0 +1,763 @@ | |||
package toml | |||
import ( | |||
"fmt" | |||
"strconv" | |||
"strings" | |||
"time" | |||
"unicode/utf8" | |||
"github.com/BurntSushi/toml/internal" | |||
) | |||
// parser consumes the lexer's item stream and builds the mapping, type
// table, and ordered key list that the decoder reads from.
type parser struct {
	lx         *lexer
	context    Key      // Full key for the current hash in scope.
	currentKey string   // Base key name for everything except hashes.
	pos        Position // Current position in the TOML file.
	ordered    []Key    // List of keys in the order that they appear in the TOML data.
	mapping    map[string]interface{} // Map keyname → key value.
	types      map[string]tomlType    // Map keyname → TOML type.
	implicits  map[string]struct{}    // Record implicit keys (e.g. "key.group.names").
}
// parse lexes and parses data, returning the populated parser. Lex/parse
// failures are raised as ParseError panics and converted back into an
// error return by the deferred recover below.
func parse(data string) (p *parser, err error) {
	defer func() {
		if r := recover(); r != nil {
			if pErr, ok := r.(ParseError); ok {
				// Attach the full input so ErrorWithPosition can show context.
				pErr.input = data
				err = pErr
				return
			}
			// Not a ParseError (i.e. a genuine bug): re-raise.
			panic(r)
		}
	}()
	// Read over BOM; do this here as the lexer calls utf8.DecodeRuneInString()
	// which mangles stuff.
	// NOTE(review): these two prefixes are the UTF-16 BOMs; a UTF-8 BOM
	// ("\xef\xbb\xbf") is not skipped here — confirm whether that's intended.
	if strings.HasPrefix(data, "\xff\xfe") || strings.HasPrefix(data, "\xfe\xff") {
		data = data[2:]
	}
	// Examine first few bytes for NULL bytes; this probably means it's a UTF-16
	// file (second byte in surrogate pair being NULL). Again, do this here to
	// avoid having to deal with UTF-8/16 stuff in the lexer.
	ex := 6
	if len(data) < 6 {
		ex = len(data)
	}
	if i := strings.IndexRune(data[:ex], 0); i > -1 {
		return nil, ParseError{
			Message:  "files cannot contain NULL bytes; probably using UTF-16; TOML files must be UTF-8",
			Position: Position{Line: 1, Start: i, Len: 1},
			Line:     1,
			input:    data,
		}
	}
	p = &parser{
		mapping:   make(map[string]interface{}),
		types:     make(map[string]tomlType),
		lx:        lex(data),
		ordered:   make([]Key, 0),
		implicits: make(map[string]struct{}),
	}
	// Drain the lexer, handling one top-level construct per iteration.
	for {
		item := p.next()
		if item.typ == itemEOF {
			break
		}
		p.topLevel(item)
	}
	return p, nil
}
func (p *parser) panicItemf(it item, format string, v ...interface{}) { | |||
panic(ParseError{ | |||
Message: fmt.Sprintf(format, v...), | |||
Position: it.pos, | |||
Line: it.pos.Len, | |||
LastKey: p.current(), | |||
}) | |||
} | |||
// panicf aborts parsing by panicking with a ParseError at the parser's
// current position, with the message formatted from format and v.
func (p *parser) panicf(format string, v ...interface{}) {
	panic(ParseError{
		Message:  fmt.Sprintf(format, v...),
		Position: p.pos,
		Line:     p.pos.Line,
		LastKey:  p.current(),
	})
}
// next returns the next item from the lexer, converting lexer errors into
// ParseError panics so they unwind to parse()'s recover.
func (p *parser) next() item {
	it := p.lx.nextItem()
	//fmt.Printf("ITEM %-18s line %-3d │ %q\n", it.typ, it.line, it.val)
	if it.typ == itemError {
		if it.err != nil {
			// Structured error: keep the underlying error value so
			// ErrorWithUsage can surface its Usage() text.
			panic(ParseError{
				Position: it.pos,
				Line:     it.pos.Line,
				LastKey:  p.current(),
				err:      it.err,
			})
		}
		// String-valued lexer error: only the message text is available.
		p.panicItemf(it, "%s", it.val)
	}
	return it
}
func (p *parser) nextPos() item { | |||
it := p.next() | |||
p.pos = it.pos | |||
return it | |||
} | |||
func (p *parser) bug(format string, v ...interface{}) { | |||
panic(fmt.Sprintf("BUG: "+format+"\n\n", v...)) | |||
} | |||
func (p *parser) expect(typ itemType) item { | |||
it := p.next() | |||
p.assertEqual(typ, it.typ) | |||
return it | |||
} | |||
func (p *parser) assertEqual(expected, got itemType) { | |||
if expected != got { | |||
p.bug("Expected '%s' but got '%s'.", expected, got) | |||
} | |||
} | |||
// topLevel handles one top-level construct: a comment, a [table] header,
// an [[array-of-tables]] header, or a key = value assignment.
func (p *parser) topLevel(item item) {
	switch item.typ {
	case itemCommentStart: // # ..
		p.expect(itemText)
	case itemTableStart: // [ .. ]
		// Collect the dotted key parts up to the closing bracket.
		name := p.nextPos()
		var key Key
		for ; name.typ != itemTableEnd && name.typ != itemEOF; name = p.next() {
			key = append(key, p.keyString(name))
		}
		p.assertEqual(itemTableEnd, name.typ)
		p.addContext(key, false)
		p.setType("", tomlHash)
		p.ordered = append(p.ordered, key)
	case itemArrayTableStart: // [[ .. ]]
		name := p.nextPos()
		var key Key
		for ; name.typ != itemArrayTableEnd && name.typ != itemEOF; name = p.next() {
			key = append(key, p.keyString(name))
		}
		p.assertEqual(itemArrayTableEnd, name.typ)
		p.addContext(key, true)
		p.setType("", tomlArrayHash)
		p.ordered = append(p.ordered, key)
	case itemKeyStart: // key = ..
		outerContext := p.context
		/// Read all the key parts (e.g. 'a' and 'b' in 'a.b')
		k := p.nextPos()
		var key Key
		for ; k.typ != itemKeyEnd && k.typ != itemEOF; k = p.next() {
			key = append(key, p.keyString(k))
		}
		p.assertEqual(itemKeyEnd, k.typ)
		/// The current key is the last part.
		p.currentKey = key[len(key)-1]
		/// All the other parts (if any) are the context; need to set each part
		/// as implicit.
		context := key[:len(key)-1]
		for i := range context {
			p.addImplicitContext(append(p.context, context[i:i+1]...))
		}
		/// Set value.
		val, typ := p.value(p.next(), false)
		p.set(p.currentKey, val, typ)
		p.ordered = append(p.ordered, p.context.add(p.currentKey))
		/// Remove the context we added (preserving any context from [tbl] lines).
		p.context = outerContext
		p.currentKey = ""
	default:
		p.bug("Unexpected type at top level: %s", item.typ)
	}
}
// Gets a string for a key (or part of a key in a table name). | |||
func (p *parser) keyString(it item) string { | |||
switch it.typ { | |||
case itemText: | |||
return it.val | |||
case itemString, itemMultilineString, | |||
itemRawString, itemRawMultilineString: | |||
s, _ := p.value(it, false) | |||
return s.(string) | |||
default: | |||
p.bug("Unexpected key type: %s", it.typ) | |||
} | |||
panic("unreachable") | |||
} | |||
// datetimeRepl maps the lowercase 'z'/'t' and the space separator onto the
// uppercase 'Z'/'T' forms used by the time-package layouts in dtTypes.
var datetimeRepl = strings.NewReplacer(
	"z", "Z",
	"t", "T",
	" ", "T")
// value translates an expected value from the lexer into a Go value wrapped
// as an empty interface.
//
// parentIsArray is forwarded to inline-table handling, which needs to know
// whether its enclosing context is an array of tables.
func (p *parser) value(it item, parentIsArray bool) (interface{}, tomlType) {
	switch it.typ {
	case itemString:
		return p.replaceEscapes(it, it.val), p.typeOfPrimitive(it)
	case itemMultilineString:
		// Trailing-backslash line joins, then the leading newline, then escapes.
		return p.replaceEscapes(it, stripFirstNewline(stripEscapedNewlines(it.val))), p.typeOfPrimitive(it)
	case itemRawString:
		// Raw (')-delimited strings get no escape processing at all.
		return it.val, p.typeOfPrimitive(it)
	case itemRawMultilineString:
		return stripFirstNewline(it.val), p.typeOfPrimitive(it)
	case itemInteger:
		return p.valueInteger(it)
	case itemFloat:
		return p.valueFloat(it)
	case itemBool:
		switch it.val {
		case "true":
			return true, p.typeOfPrimitive(it)
		case "false":
			return false, p.typeOfPrimitive(it)
		default:
			p.bug("Expected boolean value, but got '%s'.", it.val)
		}
	case itemDatetime:
		return p.valueDatetime(it)
	case itemArray:
		return p.valueArray(it)
	case itemInlineTableStart:
		return p.valueInlineTable(it, parentIsArray)
	default:
		p.bug("Unexpected value type: %s", it.typ)
	}
	panic("unreachable")
}
// valueInteger parses an integer item, enforcing the TOML rules about
// underscore placement and leading zeroes that strconv alone would accept.
func (p *parser) valueInteger(it item) (interface{}, tomlType) {
	if !numUnderscoresOK(it.val) {
		p.panicItemf(it, "Invalid integer %q: underscores must be surrounded by digits", it.val)
	}
	if numHasLeadingZero(it.val) {
		p.panicItemf(it, "Invalid integer %q: cannot have leading zeroes", it.val)
	}
	num, err := strconv.ParseInt(it.val, 0, 64)
	if err != nil {
		// Distinguish integer values. Normally, it'd be a bug if the lexer
		// provides an invalid integer, but it's possible that the number is
		// out of range of valid values (which the lexer cannot determine).
		// So mark the former as a bug but the latter as a legitimate user
		// error.
		if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
			p.panicItemf(it, "Integer '%s' is out of the range of 64-bit signed integers.", it.val)
		} else {
			p.bug("Expected integer value, but got '%s'.", it.val)
		}
	}
	return num, p.typeOfPrimitive(it)
}
// valueFloat parses a float item, applying the TOML-specific validation
// (underscores, leading zeroes, '.'-placement) that strconv would not do.
func (p *parser) valueFloat(it item) (interface{}, tomlType) {
	// Check each of the mantissa/exponent sections for underscore rules.
	parts := strings.FieldsFunc(it.val, func(r rune) bool {
		switch r {
		case '.', 'e', 'E':
			return true
		}
		return false
	})
	for _, part := range parts {
		if !numUnderscoresOK(part) {
			p.panicItemf(it, "Invalid float %q: underscores must be surrounded by digits", it.val)
		}
	}
	if len(parts) > 0 && numHasLeadingZero(parts[0]) {
		p.panicItemf(it, "Invalid float %q: cannot have leading zeroes", it.val)
	}
	if !numPeriodsOK(it.val) {
		// As a special case, numbers like '123.' or '1.e2',
		// which are valid as far as Go/strconv are concerned,
		// must be rejected because TOML says that a fractional
		// part consists of '.' followed by 1+ digits.
		p.panicItemf(it, "Invalid float %q: '.' must be followed by one or more digits", it.val)
	}
	val := strings.Replace(it.val, "_", "", -1)
	if val == "+nan" || val == "-nan" { // Go doesn't support this, but TOML spec does.
		val = "nan"
	}
	num, err := strconv.ParseFloat(val, 64)
	if err != nil {
		if e, ok := err.(*strconv.NumError); ok && e.Err == strconv.ErrRange {
			p.panicItemf(it, "Float '%s' is out of the range of 64-bit IEEE-754 floating-point numbers.", it.val)
		} else {
			p.panicItemf(it, "Invalid float value: %q", it.val)
		}
	}
	return num, p.typeOfPrimitive(it)
}
// dtTypes lists the supported datetime layouts, tried in order; the zone
// records which TOML datetime flavour (offset, local datetime, local date,
// local time) a successful parse corresponds to.
var dtTypes = []struct {
	fmt  string
	zone *time.Location
}{
	{time.RFC3339Nano, time.Local},
	{"2006-01-02T15:04:05.999999999", internal.LocalDatetime},
	{"2006-01-02", internal.LocalDate},
	{"15:04:05.999999999", internal.LocalTime},
}
func (p *parser) valueDatetime(it item) (interface{}, tomlType) { | |||
it.val = datetimeRepl.Replace(it.val) | |||
var ( | |||
t time.Time | |||
ok bool | |||
err error | |||
) | |||
for _, dt := range dtTypes { | |||
t, err = time.ParseInLocation(dt.fmt, it.val, dt.zone) | |||
if err == nil { | |||
ok = true | |||
break | |||
} | |||
} | |||
if !ok { | |||
p.panicItemf(it, "Invalid TOML Datetime: %q.", it.val) | |||
} | |||
return t, p.typeOfPrimitive(it) | |||
} | |||
// valueArray parses an array value up to its closing bracket, skipping any
// comments that appear between elements.
func (p *parser) valueArray(it item) (interface{}, tomlType) {
	p.setType(p.currentKey, tomlArray)
	// p.setType(p.currentKey, typ)
	var (
		types []tomlType
		// Initialize to a non-nil empty slice. This makes it consistent with
		// how S = [] decodes into a non-nil slice inside something like struct
		// { S []string }. See #338
		array = []interface{}{}
	)
	for it = p.next(); it.typ != itemArrayEnd; it = p.next() {
		if it.typ == itemCommentStart {
			p.expect(itemText)
			continue
		}
		val, typ := p.value(it, true)
		array = append(array, val)
		types = append(types, typ)
		// XXX: types isn't used here, we need it to record the accurate type
		// information.
		//
		// Not entirely sure how to best store this; could use "key[0]",
		// "key[1]" notation, or maybe store it on the Array type?
	}
	return array, tomlArray
}
// valueInlineTable parses an inline table ({..}) into a map, registering
// its keys and types in the parser tables just like a regular table, then
// restores the surrounding context/key when done.
func (p *parser) valueInlineTable(it item, parentIsArray bool) (interface{}, tomlType) {
	var (
		hash         = make(map[string]interface{})
		outerContext = p.context
		outerKey     = p.currentKey
	)
	// Enter the inline table's own context.
	p.context = append(p.context, p.currentKey)
	prevContext := p.context
	p.currentKey = ""
	p.addImplicit(p.context)
	p.addContext(p.context, parentIsArray)
	/// Loop over all table key/value pairs.
	for it := p.next(); it.typ != itemInlineTableEnd; it = p.next() {
		if it.typ == itemCommentStart {
			p.expect(itemText)
			continue
		}
		/// Read all key parts.
		k := p.nextPos()
		var key Key
		for ; k.typ != itemKeyEnd && k.typ != itemEOF; k = p.next() {
			key = append(key, p.keyString(k))
		}
		p.assertEqual(itemKeyEnd, k.typ)
		/// The current key is the last part.
		p.currentKey = key[len(key)-1]
		/// All the other parts (if any) are the context; need to set each part
		/// as implicit.
		context := key[:len(key)-1]
		for i := range context {
			p.addImplicitContext(append(p.context, context[i:i+1]...))
		}
		/// Set the value.
		val, typ := p.value(p.next(), false)
		p.set(p.currentKey, val, typ)
		p.ordered = append(p.ordered, p.context.add(p.currentKey))
		hash[p.currentKey] = val
		/// Restore context.
		p.context = prevContext
	}
	p.context = outerContext
	p.currentKey = outerKey
	return hash, tomlHash
}
// numHasLeadingZero checks if this number has leading zeroes, allowing for
// '0', +/- signs, and the 0b/0o/0x base prefixes.
func numHasLeadingZero(s string) bool {
	switch {
	case len(s) > 1 && s[0] == '0':
		// 0b.., 0o.., and 0x.. are base prefixes, not leading zeroes.
		return s[1] != 'b' && s[1] != 'o' && s[1] != 'x'
	case len(s) > 2 && (s[0] == '-' || s[0] == '+'):
		return s[1] == '0'
	}
	return false
}
// numUnderscoresOK checks whether each underscore in s is surrounded by | |||
// characters that are not underscores. | |||
func numUnderscoresOK(s string) bool { | |||
switch s { | |||
case "nan", "+nan", "-nan", "inf", "-inf", "+inf": | |||
return true | |||
} | |||
accept := false | |||
for _, r := range s { | |||
if r == '_' { | |||
if !accept { | |||
return false | |||
} | |||
} | |||
// isHexadecimal is a superset of all the permissable characters | |||
// surrounding an underscore. | |||
accept = isHexadecimal(r) | |||
} | |||
return accept | |||
} | |||
// numPeriodsOK checks whether every period in s is followed by a digit. | |||
func numPeriodsOK(s string) bool { | |||
period := false | |||
for _, r := range s { | |||
if period && !isDigit(r) { | |||
return false | |||
} | |||
period = r == '.' | |||
} | |||
return !period | |||
} | |||
// Set the current context of the parser, where the context is either a hash or
// an array of hashes, depending on the value of the `array` parameter.
//
// Establishing the context also makes sure that the key isn't a duplicate, and
// will create implicit hashes automatically.
func (p *parser) addContext(key Key, array bool) {
	var ok bool
	// Always start at the top level and drill down for our context.
	hashContext := p.mapping
	keyContext := make(Key, 0)
	// We only need implicit hashes for key[0:-1]
	for _, k := range key[0 : len(key)-1] {
		_, ok = hashContext[k]
		keyContext = append(keyContext, k)
		// No key? Make an implicit hash and move on.
		if !ok {
			p.addImplicit(keyContext)
			hashContext[k] = make(map[string]interface{})
		}
		// If the hash context is actually an array of tables, then set
		// the hash context to the last element in that array.
		//
		// Otherwise, it better be a table, since this MUST be a key group (by
		// virtue of it not being the last element in a key).
		switch t := hashContext[k].(type) {
		case []map[string]interface{}:
			hashContext = t[len(t)-1]
		case map[string]interface{}:
			hashContext = t
		default:
			p.panicf("Key '%s' was already created as a hash.", keyContext)
		}
	}
	p.context = keyContext
	if array {
		// If this is the first element for this array, then allocate a new
		// list of tables for it.
		k := key[len(key)-1]
		if _, ok := hashContext[k]; !ok {
			hashContext[k] = make([]map[string]interface{}, 0, 4)
		}
		// Add a new table. But make sure the key hasn't already been used
		// for something else.
		if hash, ok := hashContext[k].([]map[string]interface{}); ok {
			hashContext[k] = append(hash, make(map[string]interface{}))
		} else {
			p.panicf("Key '%s' was already created and cannot be used as an array.", key)
		}
	} else {
		p.setValue(key[len(key)-1], make(map[string]interface{}))
	}
	// The last key part becomes part of the active context either way.
	p.context = append(p.context, key[len(key)-1])
}
// set calls setValue and setType, recording both the value and the TOML
// type for key in the current context.
func (p *parser) set(key string, val interface{}, typ tomlType) {
	p.setValue(key, val)
	p.setType(key, typ)
}
// setValue sets the given key to the given value in the current context.
// It will make sure that the key hasn't already been defined, account for
// implicit key groups.
func (p *parser) setValue(key string, value interface{}) {
	var (
		tmpHash    interface{}
		ok         bool
		hash       = p.mapping
		keyContext Key
	)
	// Walk down to the table the current context refers to.
	for _, k := range p.context {
		keyContext = append(keyContext, k)
		if tmpHash, ok = hash[k]; !ok {
			p.bug("Context for key '%s' has not been established.", keyContext)
		}
		switch t := tmpHash.(type) {
		case []map[string]interface{}:
			// The context is a table of hashes. Pick the most recent table
			// defined as the current hash.
			hash = t[len(t)-1]
		case map[string]interface{}:
			hash = t
		default:
			p.panicf("Key '%s' has already been defined.", keyContext)
		}
	}
	keyContext = append(keyContext, key)
	if _, ok := hash[key]; ok {
		// Normally redefining keys isn't allowed, but the key could have been
		// defined implicitly and it's allowed to be redefined concretely. (See
		// the `valid/implicit-and-explicit-after.toml` in toml-test)
		//
		// But we have to make sure to stop marking it as an implicit. (So that
		// another redefinition provokes an error.)
		//
		// Note that since it has already been defined (as a hash), we don't
		// want to overwrite it. So our business is done.
		if p.isArray(keyContext) {
			p.removeImplicit(keyContext)
			hash[key] = value
			return
		}
		if p.isImplicit(keyContext) {
			p.removeImplicit(keyContext)
			return
		}
		// Otherwise, we have a concrete key trying to override a previous
		// key, which is *always* wrong.
		p.panicf("Key '%s' has already been defined.", keyContext)
	}
	hash[key] = value
}
// setType sets the type of a particular value at a given key. It should be
// called immediately AFTER setValue.
//
// Note that if `key` is empty, then the type given will be applied to the
// current context (which is either a table or an array of tables).
func (p *parser) setType(key string, typ tomlType) {
	keyContext := make(Key, 0, len(p.context)+1)
	keyContext = append(keyContext, p.context...)
	if len(key) > 0 { // allow type setting for hashes
		keyContext = append(keyContext, key)
	}
	// Special case to make empty keys ("" = 1) work.
	// Without it it will set "" rather than `""`.
	// TODO: why is this needed? And why is this only needed here?
	if len(keyContext) == 0 {
		keyContext = Key{""}
	}
	p.types[keyContext.String()] = typ
}
// Implicit keys need to be created when tables are implied in "a.b.c.d = 1" and
// "[a.b.c]" (the "a", "b", and "c" hashes are never created explicitly).
func (p *parser) addImplicit(key Key)     { p.implicits[key.String()] = struct{}{} }
func (p *parser) removeImplicit(key Key)  { delete(p.implicits, key.String()) }
func (p *parser) isImplicit(key Key) bool { _, ok := p.implicits[key.String()]; return ok }
func (p *parser) isArray(key Key) bool    { return p.types[key.String()] == tomlArray }

// addImplicitContext marks key as implicit and establishes it as the
// current (non-array) hash context.
func (p *parser) addImplicitContext(key Key) {
	p.addImplicit(key)
	p.addContext(key, false)
}
// current returns the full key name of the current context. | |||
func (p *parser) current() string { | |||
if len(p.currentKey) == 0 { | |||
return p.context.String() | |||
} | |||
if len(p.context) == 0 { | |||
return p.currentKey | |||
} | |||
return fmt.Sprintf("%s.%s", p.context, p.currentKey) | |||
} | |||
// stripFirstNewline removes a single leading "\n" or "\r\n" from s, as
// required for the first line of multiline TOML strings.
func stripFirstNewline(s string) string {
	switch {
	case strings.HasPrefix(s, "\n"):
		return s[1:]
	case strings.HasPrefix(s, "\r\n"):
		return s[2:]
	}
	return s
}
// Remove newlines inside triple-quoted strings if a line ends with "\".
func stripEscapedNewlines(s string) string {
	split := strings.Split(s, "\n")
	if len(split) < 1 {
		return s
	}
	escNL := false // Keep track of the last non-blank line was escaped.
	for i, line := range split {
		line = strings.TrimRight(line, " \t\r")
		if len(line) == 0 || line[len(line)-1] != '\\' {
			// Line does not end in a backslash: keep its newline (lines are
			// re-joined with "" below) unless a prior line's escape is still
			// swallowing whitespace, or this is the final line.
			split[i] = strings.TrimRight(split[i], "\r")
			if !escNL && i != len(split)-1 {
				split[i] += "\n"
			}
			continue
		}
		// Count trailing backslashes to see if the final one is itself
		// escaped (an even run means the newline is NOT escaped).
		escBS := true
		for j := len(line) - 1; j >= 0 && line[j] == '\\'; j-- {
			escBS = !escBS
		}
		if escNL {
			line = strings.TrimLeft(line, " \t\r")
		}
		escNL = !escBS
		if escBS {
			split[i] += "\n"
			continue
		}
		split[i] = line[:len(line)-1] // Remove \
		if len(split)-1 > i {
			split[i+1] = strings.TrimLeft(split[i+1], " \t\r")
		}
	}
	return strings.Join(split, "")
}
func (p *parser) replaceEscapes(it item, str string) string { | |||
replaced := make([]rune, 0, len(str)) | |||
s := []byte(str) | |||
r := 0 | |||
for r < len(s) { | |||
if s[r] != '\\' { | |||
c, size := utf8.DecodeRune(s[r:]) | |||
r += size | |||
replaced = append(replaced, c) | |||
continue | |||
} | |||
r += 1 | |||
if r >= len(s) { | |||
p.bug("Escape sequence at end of string.") | |||
return "" | |||
} | |||
switch s[r] { | |||
default: | |||
p.bug("Expected valid escape code after \\, but got %q.", s[r]) | |||
return "" | |||
case ' ', '\t': | |||
p.panicItemf(it, "invalid escape: '\\%c'", s[r]) | |||
return "" | |||
case 'b': | |||
replaced = append(replaced, rune(0x0008)) | |||
r += 1 | |||
case 't': | |||
replaced = append(replaced, rune(0x0009)) | |||
r += 1 | |||
case 'n': | |||
replaced = append(replaced, rune(0x000A)) | |||
r += 1 | |||
case 'f': | |||
replaced = append(replaced, rune(0x000C)) | |||
r += 1 | |||
case 'r': | |||
replaced = append(replaced, rune(0x000D)) | |||
r += 1 | |||
case '"': | |||
replaced = append(replaced, rune(0x0022)) | |||
r += 1 | |||
case '\\': | |||
replaced = append(replaced, rune(0x005C)) | |||
r += 1 | |||
case 'u': | |||
// At this point, we know we have a Unicode escape of the form | |||
// `uXXXX` at [r, r+5). (Because the lexer guarantees this | |||
// for us.) | |||
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+5]) | |||
replaced = append(replaced, escaped) | |||
r += 5 | |||
case 'U': | |||
// At this point, we know we have a Unicode escape of the form | |||
// `uXXXX` at [r, r+9). (Because the lexer guarantees this | |||
// for us.) | |||
escaped := p.asciiEscapeToUnicode(it, s[r+1:r+9]) | |||
replaced = append(replaced, escaped) | |||
r += 9 | |||
} | |||
} | |||
return string(replaced) | |||
} | |||
func (p *parser) asciiEscapeToUnicode(it item, bs []byte) rune { | |||
s := string(bs) | |||
hex, err := strconv.ParseUint(strings.ToLower(s), 16, 32) | |||
if err != nil { | |||
p.bug("Could not parse '%s' as a hexadecimal number, but the lexer claims it's OK: %s", s, err) | |||
} | |||
if !utf8.ValidRune(rune(hex)) { | |||
p.panicItemf(it, "Escaped character '\\u%s' is not valid UTF-8.", s) | |||
} | |||
return rune(hex) | |||
} |
@@ -0,0 +1,242 @@ | |||
package toml | |||
// Struct field handling is adapted from code in encoding/json: | |||
// | |||
// Copyright 2010 The Go Authors. All rights reserved. | |||
// Use of this source code is governed by a BSD-style | |||
// license that can be found in the Go distribution. | |||
import ( | |||
"reflect" | |||
"sort" | |||
"sync" | |||
) | |||
// A field represents a single field found in a struct. It records both the
// effective TOML key name for the field and the index path needed to reach it
// through any embedded (anonymous) structs.
type field struct {
	name  string       // the name of the field (`toml` tag included)
	tag   bool         // whether field has a `toml` tag
	index []int        // represents the depth of an anonymous field
	typ   reflect.Type // the type of the field
}
// byName sorts field by name, breaking ties with depth, | |||
// then breaking ties with "name came from toml tag", then | |||
// breaking ties with index sequence. | |||
type byName []field | |||
func (x byName) Len() int { return len(x) } | |||
func (x byName) Swap(i, j int) { x[i], x[j] = x[j], x[i] } | |||
func (x byName) Less(i, j int) bool { | |||
if x[i].name != x[j].name { | |||
return x[i].name < x[j].name | |||
} | |||
if len(x[i].index) != len(x[j].index) { | |||
return len(x[i].index) < len(x[j].index) | |||
} | |||
if x[i].tag != x[j].tag { | |||
return x[i].tag | |||
} | |||
return byIndex(x).Less(i, j) | |||
} | |||
// byIndex sorts field by index sequence. | |||
type byIndex []field | |||
func (x byIndex) Len() int { return len(x) } | |||
func (x byIndex) Swap(i, j int) { x[i], x[j] = x[j], x[i] } | |||
func (x byIndex) Less(i, j int) bool { | |||
for k, xik := range x[i].index { | |||
if k >= len(x[j].index) { | |||
return false | |||
} | |||
if xik != x[j].index[k] { | |||
return xik < x[j].index[k] | |||
} | |||
} | |||
return len(x[i].index) < len(x[j].index) | |||
} | |||
// typeFields returns a list of fields that TOML should recognize for the given
// type. The algorithm is breadth-first search over the set of structs to
// include - the top struct and then any reachable anonymous structs.
func typeFields(t reflect.Type) []field {
	// Anonymous fields to explore at the current level and the next.
	current := []field{}
	next := []field{{typ: t}}

	// Count of queued names for current level and the next; used to detect
	// the same embedded type appearing more than once at one depth.
	var count map[reflect.Type]int
	var nextCount map[reflect.Type]int

	// Types already visited at an earlier level.
	visited := map[reflect.Type]bool{}

	// Fields found.
	var fields []field

	for len(next) > 0 {
		current, next = next, current[:0]
		count, nextCount = nextCount, map[reflect.Type]int{}

		for _, f := range current {
			if visited[f.typ] {
				continue
			}
			visited[f.typ] = true

			// Scan f.typ for fields to include.
			for i := 0; i < f.typ.NumField(); i++ {
				sf := f.typ.Field(i)
				if sf.PkgPath != "" && !sf.Anonymous { // unexported
					continue
				}
				opts := getOptions(sf.Tag)
				if opts.skip {
					continue
				}
				// Index path of this field: parent's path plus i.
				index := make([]int, len(f.index)+1)
				copy(index, f.index)
				index[len(f.index)] = i

				ft := sf.Type
				if ft.Name() == "" && ft.Kind() == reflect.Ptr {
					// Follow pointer.
					ft = ft.Elem()
				}

				// Record found field and index sequence.
				if opts.name != "" || !sf.Anonymous || ft.Kind() != reflect.Struct {
					tagged := opts.name != ""
					name := opts.name
					if name == "" {
						name = sf.Name
					}
					fields = append(fields, field{name, tagged, index, ft})
					if count[f.typ] > 1 {
						// If there were multiple instances, add a second,
						// so that the annihilation code will see a duplicate.
						// It only cares about the distinction between 1 or 2,
						// so don't bother generating any more copies.
						fields = append(fields, fields[len(fields)-1])
					}
					continue
				}

				// Record new anonymous struct to explore in next round.
				nextCount[ft]++
				if nextCount[ft] == 1 {
					f := field{name: ft.Name(), index: index, typ: ft}
					next = append(next, f)
				}
			}
		}
	}

	sort.Sort(byName(fields))

	// Delete all fields that are hidden by the Go rules for embedded fields,
	// except that fields with TOML tags are promoted.
	//
	// The fields are sorted in primary order of name, secondary order
	// of field index length. Loop over names; for each name, delete
	// hidden fields by choosing the one dominant field that survives.
	// Filtering happens in place, reusing fields' backing array.
	out := fields[:0]
	for advance, i := 0, 0; i < len(fields); i += advance {
		// One iteration per name.
		// Find the sequence of fields with the name of this first field.
		fi := fields[i]
		name := fi.name
		for advance = 1; i+advance < len(fields); advance++ {
			fj := fields[i+advance]
			if fj.name != name {
				break
			}
		}
		if advance == 1 { // Only one field with this name
			out = append(out, fi)
			continue
		}
		dominant, ok := dominantField(fields[i : i+advance])
		if ok {
			out = append(out, dominant)
		}
	}

	fields = out
	sort.Sort(byIndex(fields))

	return fields
}
// dominantField looks through the fields, all of which are known to | |||
// have the same name, to find the single field that dominates the | |||
// others using Go's embedding rules, modified by the presence of | |||
// TOML tags. If there are multiple top-level fields, the boolean | |||
// will be false: This condition is an error in Go and we skip all | |||
// the fields. | |||
func dominantField(fields []field) (field, bool) { | |||
// The fields are sorted in increasing index-length order. The winner | |||
// must therefore be one with the shortest index length. Drop all | |||
// longer entries, which is easy: just truncate the slice. | |||
length := len(fields[0].index) | |||
tagged := -1 // Index of first tagged field. | |||
for i, f := range fields { | |||
if len(f.index) > length { | |||
fields = fields[:i] | |||
break | |||
} | |||
if f.tag { | |||
if tagged >= 0 { | |||
// Multiple tagged fields at the same level: conflict. | |||
// Return no field. | |||
return field{}, false | |||
} | |||
tagged = i | |||
} | |||
} | |||
if tagged >= 0 { | |||
return fields[tagged], true | |||
} | |||
// All remaining fields have the same length. If there's more than one, | |||
// we have a conflict (two fields named "X" at the same level) and we | |||
// return no field. | |||
if len(fields) > 1 { | |||
return field{}, false | |||
} | |||
return fields[0], true | |||
} | |||
// fieldCache memoizes typeFields results per struct type; access is guarded
// by the embedded RWMutex.
var fieldCache struct {
	sync.RWMutex
	m map[reflect.Type][]field
}
// cachedTypeFields is like typeFields but uses a cache to avoid repeated work. | |||
func cachedTypeFields(t reflect.Type) []field { | |||
fieldCache.RLock() | |||
f := fieldCache.m[t] | |||
fieldCache.RUnlock() | |||
if f != nil { | |||
return f | |||
} | |||
// Compute fields without lock. | |||
// Might duplicate effort but won't hold other computations back. | |||
f = typeFields(t) | |||
if f == nil { | |||
f = []field{} | |||
} | |||
fieldCache.Lock() | |||
if fieldCache.m == nil { | |||
fieldCache.m = map[reflect.Type][]field{} | |||
} | |||
fieldCache.m[t] = f | |||
fieldCache.Unlock() | |||
return f | |||
} |
@@ -0,0 +1,70 @@ | |||
package toml | |||
// tomlType represents any Go type that corresponds to a TOML type.
// While the first draft of the TOML spec has a simplistic type system that
// probably doesn't need this level of sophistication, we seem to be militating
// toward adding real composite types.
type tomlType interface {
	typeString() string
}

// typeEqual accepts any two types and returns true if they are equal.
// A nil operand always compares unequal, even to another nil.
func typeEqual(t1, t2 tomlType) bool {
	if t1 == nil || t2 == nil {
		return false
	}
	return t1.typeString() == t2.typeString()
}

// typeIsTable reports whether t is a table or an array of tables.
func typeIsTable(t tomlType) bool {
	return typeEqual(t, tomlHash) || typeEqual(t, tomlArrayHash)
}

// tomlBaseType is a primitive TOML type identified purely by its name.
type tomlBaseType string

func (btype tomlBaseType) typeString() string { return string(btype) }
func (btype tomlBaseType) String() string     { return btype.typeString() }

// The base TOML types.
var (
	tomlInteger   tomlBaseType = "Integer"
	tomlFloat     tomlBaseType = "Float"
	tomlDatetime  tomlBaseType = "Datetime"
	tomlString    tomlBaseType = "String"
	tomlBool      tomlBaseType = "Bool"
	tomlArray     tomlBaseType = "Array"
	tomlHash      tomlBaseType = "Hash"
	tomlArrayHash tomlBaseType = "ArrayHash"
)
// typeOfPrimitive returns a tomlType of any primitive value in TOML. | |||
// Primitive values are: Integer, Float, Datetime, String and Bool. | |||
// | |||
// Passing a lexer item other than the following will cause a BUG message | |||
// to occur: itemString, itemBool, itemInteger, itemFloat, itemDatetime. | |||
func (p *parser) typeOfPrimitive(lexItem item) tomlType { | |||
switch lexItem.typ { | |||
case itemInteger: | |||
return tomlInteger | |||
case itemFloat: | |||
return tomlFloat | |||
case itemDatetime: | |||
return tomlDatetime | |||
case itemString: | |||
return tomlString | |||
case itemMultilineString: | |||
return tomlString | |||
case itemRawString: | |||
return tomlString | |||
case itemRawMultilineString: | |||
return tomlString | |||
case itemBool: | |||
return tomlBool | |||
} | |||
p.bug("Cannot infer primitive type of lex item '%s'.", lexItem) | |||
panic("unreachable") | |||
} |
@@ -0,0 +1,22 @@ | |||
Copyright (c) 2016 Caleb Spare | |||
MIT License | |||
Permission is hereby granted, free of charge, to any person obtaining | |||
a copy of this software and associated documentation files (the | |||
"Software"), to deal in the Software without restriction, including | |||
without limitation the rights to use, copy, modify, merge, publish, | |||
distribute, sublicense, and/or sell copies of the Software, and to | |||
permit persons to whom the Software is furnished to do so, subject to | |||
the following conditions: | |||
The above copyright notice and this permission notice shall be | |||
included in all copies or substantial portions of the Software. | |||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, | |||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF | |||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND | |||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE | |||
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION | |||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION | |||
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. |
@@ -0,0 +1,69 @@ | |||
# xxhash | |||
[![Go Reference](https://pkg.go.dev/badge/github.com/cespare/xxhash/v2.svg)](https://pkg.go.dev/github.com/cespare/xxhash/v2) | |||
[![Test](https://github.com/cespare/xxhash/actions/workflows/test.yml/badge.svg)](https://github.com/cespare/xxhash/actions/workflows/test.yml) | |||
xxhash is a Go implementation of the 64-bit | |||
[xxHash](http://cyan4973.github.io/xxHash/) algorithm, XXH64. This is a | |||
high-quality hashing algorithm that is much faster than anything in the Go | |||
standard library. | |||
This package provides a straightforward API: | |||
``` | |||
func Sum64(b []byte) uint64 | |||
func Sum64String(s string) uint64 | |||
type Digest struct{ ... } | |||
func New() *Digest | |||
``` | |||
The `Digest` type implements hash.Hash64. Its key methods are: | |||
``` | |||
func (*Digest) Write([]byte) (int, error) | |||
func (*Digest) WriteString(string) (int, error) | |||
func (*Digest) Sum64() uint64 | |||
``` | |||
This implementation provides a fast pure-Go implementation and an even faster | |||
assembly implementation for amd64. | |||
## Compatibility | |||
This package is in a module and the latest code is in version 2 of the module. | |||
You need a version of Go with at least "minimal module compatibility" to use | |||
github.com/cespare/xxhash/v2: | |||
* 1.9.7+ for Go 1.9 | |||
* 1.10.3+ for Go 1.10 | |||
* Go 1.11 or later | |||
I recommend using the latest release of Go. | |||
## Benchmarks | |||
Here are some quick benchmarks comparing the pure-Go and assembly | |||
implementations of Sum64. | |||
| input size | purego | asm | | |||
| --- | --- | --- | | |||
| 5 B | 979.66 MB/s | 1291.17 MB/s | | |||
| 100 B | 7475.26 MB/s | 7973.40 MB/s | | |||
| 4 KB | 17573.46 MB/s | 17602.65 MB/s | | |||
| 10 MB | 17131.46 MB/s | 17142.16 MB/s | | |||
These numbers were generated on Ubuntu 18.04 with an Intel i7-8700K CPU using | |||
the following commands under Go 1.11.2: | |||
``` | |||
$ go test -tags purego -benchtime 10s -bench '/xxhash,direct,bytes' | |||
$ go test -benchtime 10s -bench '/xxhash,direct,bytes' | |||
``` | |||
## Projects using this package | |||
- [InfluxDB](https://github.com/influxdata/influxdb) | |||
- [Prometheus](https://github.com/prometheus/prometheus) | |||
- [VictoriaMetrics](https://github.com/VictoriaMetrics/VictoriaMetrics) | |||
- [FreeCache](https://github.com/coocood/freecache) | |||
- [FastCache](https://github.com/VictoriaMetrics/fastcache) |
@@ -0,0 +1,235 @@ | |||
// Package xxhash implements the 64-bit variant of xxHash (XXH64) as described | |||
// at http://cyan4973.github.io/xxHash/. | |||
package xxhash | |||
import ( | |||
"encoding/binary" | |||
"errors" | |||
"math/bits" | |||
) | |||
// The five XXH64 prime constants from the xxHash specification.
const (
	prime1 uint64 = 11400714785074694791
	prime2 uint64 = 14029467366897019727
	prime3 uint64 = 1609587929392839161
	prime4 uint64 = 9650029242287828579
	prime5 uint64 = 2870177450012600261
)

// NOTE(caleb): I'm using both consts and vars of the primes. Using consts where
// possible in the Go code is worth a small (but measurable) performance boost
// by avoiding some MOVQs. Vars are needed for the asm and also are useful for
// convenience in the Go code in a few places where we need to intentionally
// avoid constant arithmetic (e.g., v1 := prime1 + prime2 fails because the
// result overflows a uint64).
var (
	prime1v = prime1
	prime2v = prime2
	prime3v = prime3
	prime4v = prime4
	prime5v = prime5
)
// Digest implements hash.Hash64.
//
// Its state is the four XXH64 accumulator lanes plus a buffer holding the
// bytes of an incomplete 32-byte block.
type Digest struct {
	v1    uint64   // accumulator lanes, seeded by Reset
	v2    uint64
	v3    uint64
	v4    uint64
	total uint64   // total number of bytes written so far
	mem   [32]byte // buffered partial block
	n     int      // how much of mem is used
}
// New creates a new Digest that computes the 64-bit xxHash algorithm. | |||
func New() *Digest { | |||
var d Digest | |||
d.Reset() | |||
return &d | |||
} | |||
// Reset clears the Digest's state so that it can be reused. | |||
func (d *Digest) Reset() { | |||
d.v1 = prime1v + prime2 | |||
d.v2 = prime2 | |||
d.v3 = 0 | |||
d.v4 = -prime1v | |||
d.total = 0 | |||
d.n = 0 | |||
} | |||
// Size always returns 8 bytes. (Part of the hash.Hash interface.)
func (d *Digest) Size() int { return 8 }

// BlockSize always returns 32 bytes. (Part of the hash.Hash interface.)
func (d *Digest) BlockSize() int { return 32 }
// Write adds more data to d. It always returns len(b), nil.
//
// Invariant: on entry and on exit, d.n < 32 — d.mem never holds a full block.
func (d *Digest) Write(b []byte) (n int, err error) {
	n = len(b)
	d.total += uint64(n)

	if d.n+n < 32 {
		// This new data doesn't even fill the current block.
		copy(d.mem[d.n:], b)
		d.n += n
		return
	}

	if d.n > 0 {
		// Finish off the partial block.
		copy(d.mem[d.n:], b)
		d.v1 = round(d.v1, u64(d.mem[0:8]))
		d.v2 = round(d.v2, u64(d.mem[8:16]))
		d.v3 = round(d.v3, u64(d.mem[16:24]))
		d.v4 = round(d.v4, u64(d.mem[24:32]))
		b = b[32-d.n:]
		d.n = 0
	}

	if len(b) >= 32 {
		// One or more full blocks left.
		nw := writeBlocks(d, b)
		b = b[nw:]
	}

	// Store any remaining partial block.
	copy(d.mem[:], b)
	d.n = len(b)

	return
}
// Sum appends the current hash to b and returns the resulting slice. | |||
func (d *Digest) Sum(b []byte) []byte { | |||
s := d.Sum64() | |||
return append( | |||
b, | |||
byte(s>>56), | |||
byte(s>>48), | |||
byte(s>>40), | |||
byte(s>>32), | |||
byte(s>>24), | |||
byte(s>>16), | |||
byte(s>>8), | |||
byte(s), | |||
) | |||
} | |||
// Sum64 returns the current hash.
func (d *Digest) Sum64() uint64 {
	var h uint64

	if d.total >= 32 {
		// At least one full block was hashed: fold the four lanes together.
		v1, v2, v3, v4 := d.v1, d.v2, d.v3, d.v4
		h = rol1(v1) + rol7(v2) + rol12(v3) + rol18(v4)
		h = mergeRound(h, v1)
		h = mergeRound(h, v2)
		h = mergeRound(h, v3)
		h = mergeRound(h, v4)
	} else {
		// The lanes were never used; d.v3 holds the seed (0 after Reset).
		h = d.v3 + prime5
	}

	h += d.total

	// Mix in the buffered tail: 8-byte words, then up to one 4-byte word,
	// then the remaining bytes one at a time.
	i, end := 0, d.n
	for ; i+8 <= end; i += 8 {
		k1 := round(0, u64(d.mem[i:i+8]))
		h ^= k1
		h = rol27(h)*prime1 + prime4
	}
	if i+4 <= end {
		h ^= uint64(u32(d.mem[i:i+4])) * prime1
		h = rol23(h)*prime2 + prime3
		i += 4
	}
	for i < end {
		h ^= uint64(d.mem[i]) * prime5
		h = rol11(h) * prime1
		i++
	}

	// Final avalanche.
	h ^= h >> 33
	h *= prime2
	h ^= h >> 29
	h *= prime3
	h ^= h >> 32

	return h
}
const (
	// magic identifies a marshaled Digest; the trailing byte acts as a
	// format version.
	magic = "xxh\x06"
	// marshaledSize is len(magic), the five uint64 state fields, and the
	// full 32-byte block buffer.
	marshaledSize = len(magic) + 8*5 + 32
)
// MarshalBinary implements the encoding.BinaryMarshaler interface. | |||
func (d *Digest) MarshalBinary() ([]byte, error) { | |||
b := make([]byte, 0, marshaledSize) | |||
b = append(b, magic...) | |||
b = appendUint64(b, d.v1) | |||
b = appendUint64(b, d.v2) | |||
b = appendUint64(b, d.v3) | |||
b = appendUint64(b, d.v4) | |||
b = appendUint64(b, d.total) | |||
b = append(b, d.mem[:d.n]...) | |||
b = b[:len(b)+len(d.mem)-d.n] | |||
return b, nil | |||
} | |||
// UnmarshalBinary implements the encoding.BinaryUnmarshaler interface. | |||
func (d *Digest) UnmarshalBinary(b []byte) error { | |||
if len(b) < len(magic) || string(b[:len(magic)]) != magic { | |||
return errors.New("xxhash: invalid hash state identifier") | |||
} | |||
if len(b) != marshaledSize { | |||
return errors.New("xxhash: invalid hash state size") | |||
} | |||
b = b[len(magic):] | |||
b, d.v1 = consumeUint64(b) | |||
b, d.v2 = consumeUint64(b) | |||
b, d.v3 = consumeUint64(b) | |||
b, d.v4 = consumeUint64(b) | |||
b, d.total = consumeUint64(b) | |||
copy(d.mem[:], b) | |||
d.n = int(d.total % uint64(len(d.mem))) | |||
return nil | |||
} | |||
// appendUint64 appends x to b in little-endian byte order.
func appendUint64(b []byte, x uint64) []byte {
	var buf [8]byte
	binary.LittleEndian.PutUint64(buf[:], x)
	return append(b, buf[:]...)
}

// consumeUint64 decodes the first 8 bytes of b (little-endian) and returns
// the remainder of b along with the value.
func consumeUint64(b []byte) ([]byte, uint64) {
	return b[8:], u64(b)
}

// u64 and u32 read little-endian integers from the front of b.
func u64(b []byte) uint64 { return binary.LittleEndian.Uint64(b) }
func u32(b []byte) uint32 { return binary.LittleEndian.Uint32(b) }
func round(acc, input uint64) uint64 { | |||
acc += input * prime2 | |||
acc = rol31(acc) | |||
acc *= prime1 | |||
return acc | |||
} | |||
func mergeRound(acc, val uint64) uint64 { | |||
val = round(0, val) | |||
acc ^= val | |||
acc = acc*prime1 + prime4 | |||
return acc | |||
} | |||
// Fixed-distance left rotations used by the block and finalization rounds.
// bits.RotateLeft64 with a constant distance is a compiler intrinsic and
// typically lowers to a single rotate instruction.
func rol1(x uint64) uint64 { return bits.RotateLeft64(x, 1) }
func rol7(x uint64) uint64 { return bits.RotateLeft64(x, 7) }
func rol11(x uint64) uint64 { return bits.RotateLeft64(x, 11) }
func rol12(x uint64) uint64 { return bits.RotateLeft64(x, 12) }
func rol18(x uint64) uint64 { return bits.RotateLeft64(x, 18) }
func rol23(x uint64) uint64 { return bits.RotateLeft64(x, 23) }
func rol27(x uint64) uint64 { return bits.RotateLeft64(x, 27) }
func rol31(x uint64) uint64 { return bits.RotateLeft64(x, 31) }
@@ -0,0 +1,13 @@ | |||
// +build !appengine | |||
// +build gc | |||
// +build !purego | |||
package xxhash | |||
// Sum64 computes the 64-bit xxHash digest of b.
//
// This declaration's body is implemented in assembly (xxhash_amd64.s).
//
//go:noescape
func Sum64(b []byte) uint64

// writeBlocks hashes as many full 32-byte blocks from the front of b as
// possible into d's accumulators and returns the number of bytes consumed.
// Implemented in assembly (xxhash_amd64.s).
//
//go:noescape
func writeBlocks(d *Digest, b []byte) int
@@ -0,0 +1,215 @@ | |||
// +build !appengine | |||
// +build gc | |||
// +build !purego | |||
#include "textflag.h"

// Register allocation:
// AX	h
// SI	pointer to advance through b
// DX	n
// BX	loop end
// R8	v1, k1
// R9	v2
// R10	v3
// R11	v4
// R12	tmp
// R13	prime1v
// R14	prime2v
// DI	prime4v

// round reads from and advances the buffer pointer in SI.
// It assumes that R13 has prime1v and R14 has prime2v.
// One accumulate step: r = rol31(r + mem[SI]*prime2) * prime1; SI += 8.
#define round(r) \
	MOVQ (SI), R12 \
	ADDQ $8, SI \
	IMULQ R14, R12 \
	ADDQ R12, r \
	ROLQ $31, r \
	IMULQ R13, r

// mergeRound applies a merge round on the two registers acc and val.
// It assumes that R13 has prime1v, R14 has prime2v, and DI has prime4v.
// Computes acc = (acc ^ round(0, val))*prime1 + prime4.
#define mergeRound(acc, val) \
	IMULQ R14, val \
	ROLQ $31, val \
	IMULQ R13, val \
	XORQ val, acc \
	IMULQ R13, acc \
	ADDQ DI, acc
// func Sum64(b []byte) uint64
TEXT ·Sum64(SB), NOSPLIT, $0-32
	// Load fixed primes.
	MOVQ ·prime1v(SB), R13
	MOVQ ·prime2v(SB), R14
	MOVQ ·prime4v(SB), DI

	// Load slice.
	MOVQ b_base+0(FP), SI
	MOVQ b_len+8(FP), DX
	LEAQ (SI)(DX*1), BX

	// The first loop limit will be len(b)-32.
	SUBQ $32, BX

	// Check whether we have at least one block.
	CMPQ DX, $32
	JLT noBlocks

	// Set up initial state (v1, v2, v3, v4).
	MOVQ R13, R8
	ADDQ R14, R8
	MOVQ R14, R9
	XORQ R10, R10
	XORQ R11, R11
	SUBQ R13, R11

	// Loop until SI > BX.
blockLoop:
	round(R8)
	round(R9)
	round(R10)
	round(R11)
	CMPQ SI, BX
	JLE blockLoop

	// Combine the four lane accumulators into h:
	// h = rol1(v1) + rol7(v2) + rol12(v3) + rol18(v4).
	MOVQ R8, AX
	ROLQ $1, AX
	MOVQ R9, R12
	ROLQ $7, R12
	ADDQ R12, AX
	MOVQ R10, R12
	ROLQ $12, R12
	ADDQ R12, AX
	MOVQ R11, R12
	ROLQ $18, R12
	ADDQ R12, AX

	mergeRound(AX, R8)
	mergeRound(AX, R9)
	mergeRound(AX, R10)
	mergeRound(AX, R11)
	JMP afterBlocks

noBlocks:
	MOVQ ·prime5v(SB), AX

afterBlocks:
	ADDQ DX, AX

	// Right now BX has len(b)-32, and we want to loop until SI > len(b)-8.
	ADDQ $24, BX
	CMPQ SI, BX
	JG fourByte

	// Tail: consume remaining 8-byte words.
wordLoop:
	// Calculate k1.
	MOVQ (SI), R8
	ADDQ $8, SI
	IMULQ R14, R8
	ROLQ $31, R8
	IMULQ R13, R8
	XORQ R8, AX
	ROLQ $27, AX
	IMULQ R13, AX
	ADDQ DI, AX
	CMPQ SI, BX
	JLE wordLoop

	// Tail: consume one final 4-byte chunk, if present.
fourByte:
	ADDQ $4, BX
	CMPQ SI, BX
	JG singles
	MOVL (SI), R8
	ADDQ $4, SI
	IMULQ R13, R8
	XORQ R8, AX
	ROLQ $23, AX
	IMULQ R14, AX
	ADDQ ·prime3v(SB), AX

	// Tail: consume the remaining 0-3 bytes one at a time.
singles:
	ADDQ $4, BX
	CMPQ SI, BX
	JGE finalize

singlesLoop:
	MOVBQZX (SI), R12
	ADDQ $1, SI
	IMULQ ·prime5v(SB), R12
	XORQ R12, AX
	ROLQ $11, AX
	IMULQ R13, AX
	CMPQ SI, BX
	JL singlesLoop

	// Final avalanche:
	// h ^= h>>33; h *= prime2; h ^= h>>29; h *= prime3; h ^= h>>32.
finalize:
	MOVQ AX, R12
	SHRQ $33, R12
	XORQ R12, AX
	IMULQ R14, AX
	MOVQ AX, R12
	SHRQ $29, R12
	XORQ R12, AX
	IMULQ ·prime3v(SB), AX
	MOVQ AX, R12
	SHRQ $32, R12
	XORQ R12, AX
	MOVQ AX, ret+24(FP)
	RET
// writeBlocks uses the same registers as above except that it uses AX to store
// the d pointer.

// func writeBlocks(d *Digest, b []byte) int
TEXT ·writeBlocks(SB), NOSPLIT, $0-40
	// Load fixed primes needed for round.
	MOVQ ·prime1v(SB), R13
	MOVQ ·prime2v(SB), R14

	// Load slice.
	MOVQ b_base+8(FP), SI
	MOVQ b_len+16(FP), DX
	LEAQ (SI)(DX*1), BX
	SUBQ $32, BX

	// Load vN from d.
	MOVQ d+0(FP), AX
	MOVQ 0(AX), R8 // v1
	MOVQ 8(AX), R9 // v2
	MOVQ 16(AX), R10 // v3
	MOVQ 24(AX), R11 // v4

	// We don't need to check the loop condition here; this function is
	// always called with at least one block of data to process.
blockLoop:
	round(R8)
	round(R9)
	round(R10)
	round(R11)
	CMPQ SI, BX
	JLE blockLoop

	// Copy vN back to d.
	MOVQ R8, 0(AX)
	MOVQ R9, 8(AX)
	MOVQ R10, 16(AX)
	MOVQ R11, 24(AX)

	// The number of bytes written is SI minus the old base pointer.
	SUBQ b_base+8(FP), SI
	MOVQ SI, ret+32(FP)
	RET
@@ -0,0 +1,76 @@ | |||
// +build !amd64 appengine !gc purego | |||
package xxhash | |||
// Sum64 computes the 64-bit xxHash digest of b. | |||
func Sum64(b []byte) uint64 { | |||
// A simpler version would be | |||
// d := New() | |||
// d.Write(b) | |||
// return d.Sum64() | |||
// but this is faster, particularly for small inputs. | |||
n := len(b) | |||
var h uint64 | |||
if n >= 32 { | |||
v1 := prime1v + prime2 | |||
v2 := prime2 | |||
v3 := uint64(0) | |||
v4 := -prime1v | |||
for len(b) >= 32 { | |||
v1 = round(v1, u64(b[0:8:len(b)])) | |||
v2 = round(v2, u64(b[8:16:len(b)])) | |||
v3 = round(v3, u64(b[16:24:len(b)])) | |||
v4 = round(v4, u64(b[24:32:len(b)])) | |||
b = b[32:len(b):len(b)] | |||
} | |||
h = rol1(v1) + rol7(v2) + rol12(v3) + rol18(v4) | |||
h = mergeRound(h, v1) | |||
h = mergeRound(h, v2) | |||
h = mergeRound(h, v3) | |||
h = mergeRound(h, v4) | |||
} else { | |||
h = prime5 | |||
} | |||
h += uint64(n) | |||
i, end := 0, len(b) | |||
for ; i+8 <= end; i += 8 { | |||
k1 := round(0, u64(b[i:i+8:len(b)])) | |||
h ^= k1 | |||
h = rol27(h)*prime1 + prime4 | |||
} | |||
if i+4 <= end { | |||
h ^= uint64(u32(b[i:i+4:len(b)])) * prime1 | |||
h = rol23(h)*prime2 + prime3 | |||
i += 4 | |||
} | |||
for ; i < end; i++ { | |||
h ^= uint64(b[i]) * prime5 | |||
h = rol11(h) * prime1 | |||
} | |||
h ^= h >> 33 | |||
h *= prime2 | |||
h ^= h >> 29 | |||
h *= prime3 | |||
h ^= h >> 32 | |||
return h | |||
} | |||
func writeBlocks(d *Digest, b []byte) int { | |||
v1, v2, v3, v4 := d.v1, d.v2, d.v3, d.v4 | |||
n := len(b) | |||
for len(b) >= 32 { | |||
v1 = round(v1, u64(b[0:8:len(b)])) | |||
v2 = round(v2, u64(b[8:16:len(b)])) | |||
v3 = round(v3, u64(b[16:24:len(b)])) | |||
v4 = round(v4, u64(b[24:32:len(b)])) | |||
b = b[32:len(b):len(b)] | |||
} | |||
d.v1, d.v2, d.v3, d.v4 = v1, v2, v3, v4 | |||
return n - len(b) | |||
} |
@@ -0,0 +1,15 @@ | |||
// +build appengine | |||
// This file contains the safe implementations of otherwise unsafe-using code. | |||
package xxhash | |||
// Sum64String computes the 64-bit xxHash digest of s.
//
// This safe variant copies s into a fresh []byte; the non-appengine build
// presumably avoids the copy via unsafe (see xxhash_unsafe.go).
func Sum64String(s string) uint64 {
	return Sum64([]byte(s))
}

// WriteString adds more data to d. It always returns len(s), nil.
func (d *Digest) WriteString(s string) (n int, err error) {
	return d.Write([]byte(s))
}
@@ -0,0 +1,57 @@ | |||
// +build !appengine | |||
// This file encapsulates usage of unsafe. | |||
// xxhash_safe.go contains the safe implementations. | |||
package xxhash | |||
import ( | |||
"unsafe" | |||
) | |||
// In the future it's possible that compiler optimizations will make these | |||
// XxxString functions unnecessary by realizing that calls such as | |||
// Sum64([]byte(s)) don't need to copy s. See https://golang.org/issue/2205. | |||
// If that happens, even if we keep these functions they can be replaced with | |||
// the trivial safe code. | |||
// NOTE: The usual way of doing an unsafe string-to-[]byte conversion is: | |||
// | |||
// var b []byte | |||
// bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) | |||
// bh.Data = (*reflect.StringHeader)(unsafe.Pointer(&s)).Data | |||
// bh.Len = len(s) | |||
// bh.Cap = len(s) | |||
// | |||
// Unfortunately, as of Go 1.15.3 the inliner's cost model assigns a high enough | |||
// weight to this sequence of expressions that any function that uses it will | |||
// not be inlined. Instead, the functions below use a different unsafe | |||
// conversion designed to minimize the inliner weight and allow both to be | |||
// inlined. There is also a test (TestInlining) which verifies that these are | |||
// inlined. | |||
// | |||
// See https://github.com/golang/go/issues/42739 for discussion. | |||
// Sum64String computes the 64-bit xxHash digest of s.
// It may be faster than Sum64([]byte(s)) by avoiding a copy.
func Sum64String(s string) uint64 {
	// Reinterpret s's bytes as a []byte without copying. The fake slice is
	// only read and does not outlive this call.
	b := *(*[]byte)(unsafe.Pointer(&sliceHeader{s, len(s)}))
	return Sum64(b)
}

// WriteString adds more data to d. It always returns len(s), nil.
// It may be faster than Write([]byte(s)) by avoiding a copy.
func (d *Digest) WriteString(s string) (n int, err error) {
	d.Write(*(*[]byte)(unsafe.Pointer(&sliceHeader{s, len(s)})))
	// d.Write always returns len(s), nil.
	// Ignoring the return output and returning these fixed values buys a
	// savings of 6 in the inliner's cost model.
	return len(s), nil
}

// sliceHeader is similar to reflect.SliceHeader, but it assumes that the layout
// of the first two words is the same as the layout of a string.
type sliceHeader struct {
	s   string
	cap int // set to len(s) by the callers above, so the fake slice has no spare capacity
}
@@ -0,0 +1,21 @@ | |||
The MIT License (MIT) | |||
Copyright (c) 2017-2020 Damian Gryski <damian@gryski.com> | |||
Permission is hereby granted, free of charge, to any person obtaining a copy | |||
of this software and associated documentation files (the "Software"), to deal | |||
in the Software without restriction, including without limitation the rights | |||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |||
copies of the Software, and to permit persons to whom the Software is | |||
furnished to do so, subject to the following conditions: | |||
The above copyright notice and this permission notice shall be included in | |||
all copies or substantial portions of the Software. | |||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | |||
THE SOFTWARE. |
@@ -0,0 +1,79 @@ | |||
package rendezvous | |||
// Rendezvous implements rendezvous (highest-random-weight) hashing:
// each node's hash is computed once up front, and a key is assigned to
// the node whose mixed key/node hash is largest.
type Rendezvous struct {
	nodes map[string]int // node name -> index into nstr/nhash
	nstr  []string       // node names
	nhash []uint64       // hash(nstr[i]), cached so lookups don't rehash nodes
	hash  Hasher
}

// Hasher maps a string to a 64-bit hash value.
type Hasher func(s string) uint64

// New builds a Rendezvous over nodes using the supplied hash function.
func New(nodes []string, hash Hasher) *Rendezvous {
	r := &Rendezvous{
		nodes: make(map[string]int, len(nodes)),
		nstr:  make([]string, len(nodes)),
		nhash: make([]uint64, len(nodes)),
		hash:  hash,
	}
	for i, n := range nodes {
		r.nodes[n] = i
		r.nstr[i] = n
		r.nhash[i] = hash(n)
	}
	return r
}

// Lookup returns the node that owns key k, or "" when no nodes exist.
func (r *Rendezvous) Lookup(k string) string {
	// short-circuit if we're empty
	if len(r.nodes) == 0 {
		return ""
	}
	khash := r.hash(k)
	var midx int
	var mhash = xorshiftMult64(khash ^ r.nhash[0])
	for i, nhash := range r.nhash[1:] {
		if h := xorshiftMult64(khash ^ nhash); h > mhash {
			midx = i + 1
			mhash = h
		}
	}
	return r.nstr[midx]
}

// Add registers a new node.
func (r *Rendezvous) Add(node string) {
	r.nodes[node] = len(r.nstr)
	r.nstr = append(r.nstr, node)
	r.nhash = append(r.nhash, r.hash(node))
}

// Remove deletes node by swapping the last element into its slot.
// NOTE(review): like the original, it assumes node was previously added;
// removing an unknown node silently corrupts index 0's mapping.
func (r *Rendezvous) Remove(node string) {
	// find index of node to remove
	nidx := r.nodes[node]

	// Move the last element into the vacated slot, then shrink.
	// BUG FIX: the last valid index is len-1, not len; the original
	// r.nstr[len(r.nstr)] always panicked with index out of range.
	l := len(r.nstr) - 1
	r.nstr[nidx] = r.nstr[l]
	r.nstr = r.nstr[:l]
	r.nhash[nidx] = r.nhash[l]
	r.nhash = r.nhash[:l]

	// update the map
	delete(r.nodes, node)
	if nidx < l {
		// Re-point the moved node at its new index. Skipped when the
		// removed node was itself the last element (nothing moved).
		moved := r.nstr[nidx]
		r.nodes[moved] = nidx
	}
}

// xorshiftMult64 is a fast 64-bit mixing function (Vigna's xorshift*).
func xorshiftMult64(x uint64) uint64 {
	x ^= x >> 12 // a
	x ^= x << 25 // b
	x ^= x >> 27 // c
	return x * 2685821657736338717
}
@@ -0,0 +1,3 @@ | |||
*.rdb | |||
testdata/*/ | |||
.idea/ |
@@ -0,0 +1,27 @@ | |||
run: | |||
concurrency: 8 | |||
deadline: 5m | |||
tests: false | |||
linters: | |||
enable-all: true | |||
disable: | |||
- funlen | |||
- gochecknoglobals | |||
- gochecknoinits | |||
- gocognit | |||
- goconst | |||
- godox | |||
- gosec | |||
- maligned | |||
- wsl | |||
- gomnd | |||
- goerr113 | |||
- exhaustive | |||
- nestif | |||
- nlreturn | |||
- exhaustivestruct | |||
- wrapcheck | |||
- errorlint | |||
- cyclop | |||
- forcetypeassert | |||
- forbidigo |
@@ -0,0 +1,4 @@ | |||
semi: false | |||
singleQuote: true | |||
proseWrap: always | |||
printWidth: 100 |
@@ -0,0 +1,149 @@ | |||
## [8.11.4](https://github.com/go-redis/redis/compare/v8.11.3...v8.11.4) (2021-10-04) | |||
### Features | |||
* add acl auth support for sentinels ([f66582f](https://github.com/go-redis/redis/commit/f66582f44f3dc3a4705a5260f982043fde4aa634)) | |||
* add Cmd.{String,Int,Float,Bool}Slice helpers and an example ([5d3d293](https://github.com/go-redis/redis/commit/5d3d293cc9c60b90871e2420602001463708ce24)) | |||
* add SetVal method for each command ([168981d](https://github.com/go-redis/redis/commit/168981da2d84ee9e07d15d3e74d738c162e264c4)) | |||
## v8.11 | |||
- Remove OpenTelemetry metrics. | |||
- Supports more redis commands and options. | |||
## v8.10 | |||
- Removed extra OpenTelemetry spans from go-redis core. Now go-redis instrumentation only adds a | |||
single span with a Redis command (instead of 4 spans). There are multiple reasons behind this | |||
decision: | |||
- Traces become smaller and less noisy. | |||
- It may be costly to process those 3 extra spans for each query. | |||
- go-redis no longer depends on OpenTelemetry. | |||
Eventually we hope to replace the information that we no longer collect with OpenTelemetry | |||
Metrics. | |||
## v8.9 | |||
- Changed `PubSub.Channel` to only rely on `Ping` result. You can now use `WithChannelSize`, | |||
`WithChannelHealthCheckInterval`, and `WithChannelSendTimeout` to override default settings. | |||
## v8.8 | |||
- To make updating easier, extra modules now have the same version as go-redis does. That means that | |||
you need to update your imports: | |||
``` | |||
github.com/go-redis/redis/extra/redisotel -> github.com/go-redis/redis/extra/redisotel/v8 | |||
github.com/go-redis/redis/extra/rediscensus -> github.com/go-redis/redis/extra/rediscensus/v8 | |||
``` | |||
## v8.5 | |||
- [knadh](https://github.com/knadh) contributed long-awaited ability to scan Redis Hash into a | |||
struct: | |||
```go | |||
err := rdb.HGetAll(ctx, "hash").Scan(&data) | |||
err := rdb.MGet(ctx, "key1", "key2").Scan(&data) | |||
``` | |||
- Please check [redismock](https://github.com/go-redis/redismock) by | |||
[monkey92t](https://github.com/monkey92t) if you are looking for mocking Redis Client. | |||
## v8 | |||
- All commands require `context.Context` as a first argument, e.g. `rdb.Ping(ctx)`. If you are not | |||
using `context.Context` yet, the simplest option is to define global package variable | |||
`var ctx = context.TODO()` and use it when `ctx` is required. | |||
- Full support for `context.Context` canceling. | |||
- Added `redis.NewFailoverClusterClient` that supports routing read-only commands to a slave node. | |||
- Added `redisext.OpenTelemetryHook` that adds
[Redis OpenTelemetry instrumentation](https://redis.uptrace.dev/tracing/). | |||
- Redis slow log support. | |||
- Ring uses Rendezvous Hashing by default which provides better distribution. You need to move | |||
existing keys to a new location or keys will be inaccessible / lost. To use old hashing scheme: | |||
```go | |||
import "github.com/golang/groupcache/consistenthash" | |||
ring := redis.NewRing(&redis.RingOptions{ | |||
NewConsistentHash: func() { | |||
return consistenthash.New(100, crc32.ChecksumIEEE) | |||
}, | |||
}) | |||
``` | |||
- `ClusterOptions.MaxRedirects` default value is changed from 8 to 3. | |||
- `Options.MaxRetries` default value is changed from 0 to 3. | |||
- `Cluster.ForEachNode` is renamed to `ForEachShard` for consistency with `Ring`. | |||
## v7.3 | |||
- New option `Options.Username` which causes client to use `AuthACL`. Be aware if your connection | |||
URL contains username. | |||
## v7.2 | |||
- Existing `HMSet` is renamed to `HSet` and old deprecated `HMSet` is restored for Redis 3 users. | |||
## v7.1 | |||
- Existing `Cmd.String` is renamed to `Cmd.Text`. New `Cmd.String` implements `fmt.Stringer` | |||
interface. | |||
## v7 | |||
- _Important_. Tx.Pipeline now returns a non-transactional pipeline. Use Tx.TxPipeline for a | |||
transactional pipeline. | |||
- WrapProcess is replaced with more convenient AddHook that has access to context.Context. | |||
- WithContext now can not be used to create a shallow copy of the client. | |||
- New methods ProcessContext, DoContext, and ExecContext. | |||
- Client respects Context.Deadline when setting net.Conn deadline. | |||
- Client listens on Context.Done while waiting for a connection from the pool and returns an error | |||
  when the context is cancelled.
- Add PubSub.ChannelWithSubscriptions that sends `*Subscription` in addition to `*Message` to allow | |||
detecting reconnections. | |||
- `time.Time` is now marshalled in RFC3339 format. `rdb.Get("foo").Time()` helper is added to parse | |||
the time. | |||
- `SetLimiter` is removed and added `Options.Limiter` instead. | |||
- `HMSet` is deprecated as of Redis v4. | |||
## v6.15 | |||
- Cluster and Ring pipelines process commands for each node in its own goroutine. | |||
## v6.14
- Added Options.MinIdleConns. | |||
- Added Options.MaxConnAge. | |||
- PoolStats.FreeConns is renamed to PoolStats.IdleConns. | |||
- Add Client.Do to simplify creating custom commands. | |||
- Add Cmd.String, Cmd.Int, Cmd.Int64, Cmd.Uint64, Cmd.Float64, and Cmd.Bool helpers. | |||
- Lower memory usage. | |||
## v6.13 | |||
- Ring got new options called `HashReplicas` and `Hash`. It is recommended to set | |||
`HashReplicas = 1000` for better keys distribution between shards. | |||
- Cluster client was optimized to use much less memory when reloading cluster state. | |||
- PubSub.ReceiveMessage is re-worked to not use ReceiveTimeout so it does not lose data when timeout | |||
  occurs. In most cases it is recommended to use PubSub.Channel instead.
- Dialer.KeepAlive is set to 5 minutes by default. | |||
## v6.12 | |||
- ClusterClient got new option called `ClusterSlots` which allows to build cluster of normal Redis | |||
Servers that don't have cluster mode enabled. See | |||
https://godoc.org/github.com/go-redis/redis#example-NewClusterClient--ManualSetup |
@@ -0,0 +1,25 @@ | |||
Copyright (c) 2013 The github.com/go-redis/redis Authors. | |||
All rights reserved. | |||
Redistribution and use in source and binary forms, with or without | |||
modification, are permitted provided that the following conditions are | |||
met: | |||
* Redistributions of source code must retain the above copyright | |||
notice, this list of conditions and the following disclaimer. | |||
* Redistributions in binary form must reproduce the above | |||
copyright notice, this list of conditions and the following disclaimer | |||
in the documentation and/or other materials provided with the | |||
distribution. | |||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS | |||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT | |||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR | |||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT | |||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, | |||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT | |||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, | |||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | |||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
@@ -0,0 +1,35 @@ | |||
# Directories of any nested Go modules (used by go_mod_tidy).
PACKAGE_DIRS := $(shell find . -mindepth 2 -type f -name 'go.mod' -exec dirname {} \; | sort)
# Run unit tests, the race detector, benchmarks, a 32-bit build, and vet.
test: testdeps
	go test ./...
	go test ./... -short -race
	go test ./... -run=NONE -bench=. -benchmem
	env GOOS=linux GOARCH=386 go test ./...
	go vet
# Tests need a locally compiled redis-server binary (see targets below).
testdeps: testdata/redis/src/redis-server
bench: testdeps
	go test ./... -test.run=NONE -test.bench=. -test.benchmem
.PHONY: all test testdeps bench
# Download and unpack the Redis source tree into testdata/redis.
testdata/redis:
	mkdir -p $@
	wget -qO- https://download.redis.io/releases/redis-6.2.5.tar.gz | tar xvz --strip-components=1 -C $@
# Build redis-server from the downloaded sources.
testdata/redis/src/redis-server: testdata/redis
	cd $< && make all
fmt:
	gofmt -w -s ./
	goimports -w -local github.com/go-redis/redis ./
# Update dependencies in the root module and every nested module.
go_mod_tidy:
	go get -u && go mod tidy
	set -e; for dir in $(PACKAGE_DIRS); do \
	  echo "go mod tidy in $${dir}"; \
	  (cd "$${dir}" && \
	    go get -u && \
	    go mod tidy); \
	done
@@ -0,0 +1,178 @@ | |||
<p align="center"> | |||
<a href="https://uptrace.dev/?utm_source=gh-redis&utm_campaign=gh-redis-banner1"> | |||
<img src="https://raw.githubusercontent.com/uptrace/roadmap/master/banner1.png" alt="All-in-one tool to optimize performance and monitor errors & logs"> | |||
</a> | |||
</p> | |||
# Redis client for Golang | |||
![build workflow](https://github.com/go-redis/redis/actions/workflows/build.yml/badge.svg) | |||
[![PkgGoDev](https://pkg.go.dev/badge/github.com/go-redis/redis/v8)](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc) | |||
[![Documentation](https://img.shields.io/badge/redis-documentation-informational)](https://redis.uptrace.dev/) | |||
[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj) | |||
- To ask questions, join [Discord](https://discord.gg/rWtp5Aj) or use | |||
[Discussions](https://github.com/go-redis/redis/discussions). | |||
- [Newsletter](https://blog.uptrace.dev/pages/newsletter.html) to get latest updates. | |||
- [Documentation](https://redis.uptrace.dev) | |||
- [Reference](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc) | |||
- [Examples](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#pkg-examples) | |||
- [RealWorld example app](https://github.com/uptrace/go-treemux-realworld-example-app) | |||
Other projects you may like: | |||
- [Bun](https://bun.uptrace.dev) - fast and simple SQL client for PostgreSQL, MySQL, and SQLite. | |||
- [treemux](https://github.com/vmihailenco/treemux) - high-speed, flexible, tree-based HTTP router | |||
for Go. | |||
## Ecosystem | |||
- [Redis Mock](https://github.com/go-redis/redismock). | |||
- [Distributed Locks](https://github.com/bsm/redislock). | |||
- [Redis Cache](https://github.com/go-redis/cache). | |||
- [Rate limiting](https://github.com/go-redis/redis_rate). | |||
## Features | |||
- Redis 3 commands except QUIT, MONITOR, and SYNC. | |||
- Automatic connection pooling with | |||
[circuit breaker](https://en.wikipedia.org/wiki/Circuit_breaker_design_pattern) support. | |||
- [Pub/Sub](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#PubSub). | |||
- [Transactions](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client-TxPipeline). | |||
- [Pipeline](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client-Pipeline) and | |||
[TxPipeline](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-Client-TxPipeline). | |||
- [Scripting](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#Script). | |||
- [Timeouts](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#Options). | |||
- [Redis Sentinel](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewFailoverClient). | |||
- [Redis Cluster](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewClusterClient). | |||
- [Cluster of Redis Servers](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#example-NewClusterClient--ManualSetup) | |||
without using cluster mode and Redis Sentinel. | |||
- [Ring](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#NewRing). | |||
- [Instrumentation](https://pkg.go.dev/github.com/go-redis/redis/v8?tab=doc#ex-package--Instrumentation). | |||
## Installation | |||
go-redis supports 2 last Go versions and requires a Go version with | |||
[modules](https://github.com/golang/go/wiki/Modules) support. So make sure to initialize a Go | |||
module: | |||
```shell | |||
go mod init github.com/my/repo | |||
``` | |||
And then install go-redis/v8 (note _v8_ in the import; omitting it is a popular mistake): | |||
```shell | |||
go get github.com/go-redis/redis/v8 | |||
``` | |||
## Quickstart | |||
```go | |||
import (
    "context"
    "fmt"

    "github.com/go-redis/redis/v8"
)
var ctx = context.Background() | |||
func ExampleClient() { | |||
rdb := redis.NewClient(&redis.Options{ | |||
Addr: "localhost:6379", | |||
Password: "", // no password set | |||
DB: 0, // use default DB | |||
}) | |||
err := rdb.Set(ctx, "key", "value", 0).Err() | |||
if err != nil { | |||
panic(err) | |||
} | |||
val, err := rdb.Get(ctx, "key").Result() | |||
if err != nil { | |||
panic(err) | |||
} | |||
fmt.Println("key", val) | |||
val2, err := rdb.Get(ctx, "key2").Result() | |||
if err == redis.Nil { | |||
fmt.Println("key2 does not exist") | |||
} else if err != nil { | |||
panic(err) | |||
} else { | |||
fmt.Println("key2", val2) | |||
} | |||
// Output: key value | |||
// key2 does not exist | |||
} | |||
``` | |||
## Look and feel | |||
Some corner cases: | |||
```go | |||
// SET key value EX 10 NX | |||
set, err := rdb.SetNX(ctx, "key", "value", 10*time.Second).Result() | |||
// SET key value keepttl NX | |||
set, err := rdb.SetNX(ctx, "key", "value", redis.KeepTTL).Result() | |||
// SORT list LIMIT 0 2 ASC | |||
vals, err := rdb.Sort(ctx, "list", &redis.Sort{Offset: 0, Count: 2, Order: "ASC"}).Result() | |||
// ZRANGEBYSCORE zset -inf +inf WITHSCORES LIMIT 0 2 | |||
vals, err := rdb.ZRangeByScoreWithScores(ctx, "zset", &redis.ZRangeBy{ | |||
Min: "-inf", | |||
Max: "+inf", | |||
Offset: 0, | |||
Count: 2, | |||
}).Result() | |||
// ZINTERSTORE out 2 zset1 zset2 WEIGHTS 2 3 AGGREGATE SUM | |||
vals, err := rdb.ZInterStore(ctx, "out", &redis.ZStore{ | |||
Keys: []string{"zset1", "zset2"}, | |||
	Weights: []int64{2, 3},
}).Result() | |||
// EVAL "return {KEYS[1],ARGV[1]}" 1 "key" "hello" | |||
vals, err := rdb.Eval(ctx, "return {KEYS[1],ARGV[1]}", []string{"key"}, "hello").Result() | |||
// custom command | |||
res, err := rdb.Do(ctx, "set", "key", "value").Result() | |||
``` | |||
## Run the test | |||
go-redis will start a redis-server and run the test cases. | |||
The paths of redis-server bin file and redis config file are defined in `main_test.go`: | |||
``` | |||
var ( | |||
redisServerBin, _ = filepath.Abs(filepath.Join("testdata", "redis", "src", "redis-server")) | |||
redisServerConf, _ = filepath.Abs(filepath.Join("testdata", "redis", "redis.conf")) | |||
) | |||
``` | |||
For local testing, you can change the variables to refer to your local files, or create a soft link | |||
to the corresponding folder for redis-server and copy the config file to `testdata/redis/`: | |||
``` | |||
ln -s /usr/bin/redis-server ./go-redis/testdata/redis/src | |||
cp ./go-redis/testdata/redis.conf ./go-redis/testdata/redis/ | |||
``` | |||
Lastly, run: | |||
``` | |||
go test | |||
``` | |||
## Contributors | |||
Thanks to all the people who already contributed! | |||
<a href="https://github.com/go-redis/redis/graphs/contributors"> | |||
<img src="https://contributors-img.web.app/image?repo=go-redis/redis" /> | |||
</a> |
@@ -0,0 +1,15 @@ | |||
# Releasing | |||
1. Run `release.sh` script which updates versions in go.mod files and pushes a new branch to GitHub: | |||
```shell | |||
TAG=v1.0.0 ./scripts/release.sh | |||
``` | |||
2. Open a pull request and wait for the build to finish. | |||
3. Merge the pull request and run `tag.sh` to create tags for packages: | |||
```shell | |||
TAG=v1.0.0 ./scripts/tag.sh | |||
``` |
@@ -0,0 +1,109 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"sync" | |||
"sync/atomic" | |||
) | |||
// DBSize returns the total number of keys across all cluster masters.
func (c *ClusterClient) DBSize(ctx context.Context) *IntCmd {
	cmd := NewIntCmd(ctx, "dbsize")
	_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error {
		// Sum DBSIZE over every master; AddInt64 keeps the sum correct
		// if ForEachMaster runs the callbacks concurrently.
		var size int64
		err := c.ForEachMaster(ctx, func(ctx context.Context, master *Client) error {
			n, err := master.DBSize(ctx).Result()
			if err != nil {
				return err
			}
			atomic.AddInt64(&size, n)
			return nil
		})
		if err != nil {
			cmd.SetErr(err)
		} else {
			cmd.val = size
		}
		return nil
	})
	return cmd
}
// ScriptLoad loads script into the script cache of every shard and
// returns the script's SHA1 digest.
func (c *ClusterClient) ScriptLoad(ctx context.Context, script string) *StringCmd {
	cmd := NewStringCmd(ctx, "script", "load", script)
	_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error {
		// mu serializes writes to cmd.val from the per-shard callbacks.
		mu := &sync.Mutex{}
		err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error {
			val, err := shard.ScriptLoad(ctx, script).Result()
			if err != nil {
				return err
			}
			mu.Lock()
			// Keep the first digest seen; the SHA1 should be identical on
			// every shard since it hashes the same script text.
			if cmd.Val() == "" {
				cmd.val = val
			}
			mu.Unlock()
			return nil
		})
		if err != nil {
			cmd.SetErr(err)
		}
		return nil
	})
	return cmd
}
func (c *ClusterClient) ScriptFlush(ctx context.Context) *StatusCmd { | |||
cmd := NewStatusCmd(ctx, "script", "flush") | |||
_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error { | |||
err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error { | |||
return shard.ScriptFlush(ctx).Err() | |||
}) | |||
if err != nil { | |||
cmd.SetErr(err) | |||
} | |||
return nil | |||
}) | |||
return cmd | |||
} | |||
// ScriptExists checks the given script hashes against every shard.
// A hash is reported as existing only if it exists on all shards.
func (c *ClusterClient) ScriptExists(ctx context.Context, hashes ...string) *BoolSliceCmd {
	args := make([]interface{}, 2+len(hashes))
	args[0] = "script"
	args[1] = "exists"
	for i, hash := range hashes {
		args[2+i] = hash
	}
	cmd := NewBoolSliceCmd(ctx, args...)
	// Start from all-true and AND each shard's answer in, so a hash ends
	// up true only when every shard reports it.
	result := make([]bool, len(hashes))
	for i := range result {
		result[i] = true
	}
	_ = c.hooks.process(ctx, cmd, func(ctx context.Context, _ Cmder) error {
		// mu serializes writes to result from the per-shard callbacks.
		mu := &sync.Mutex{}
		err := c.ForEachShard(ctx, func(ctx context.Context, shard *Client) error {
			val, err := shard.ScriptExists(ctx, hashes...).Result()
			if err != nil {
				return err
			}
			mu.Lock()
			for i, v := range val {
				result[i] = result[i] && v
			}
			mu.Unlock()
			return nil
		})
		if err != nil {
			cmd.SetErr(err)
		} else {
			cmd.val = result
		}
		return nil
	})
	return cmd
}
@@ -0,0 +1,4 @@ | |||
/* | |||
Package redis implements a Redis client. | |||
*/ | |||
package redis |
@@ -0,0 +1,144 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"io" | |||
"net" | |||
"strings" | |||
"github.com/go-redis/redis/v8/internal/pool" | |||
"github.com/go-redis/redis/v8/internal/proto" | |||
) | |||
// ErrClosed is returned by any operation performed on a closed client.
var ErrClosed = pool.ErrClosed
// Error is implemented by all error replies that originate from the
// Redis server itself, as opposed to network or client-side errors.
type Error interface {
	error
	// RedisError is a no-op function but
	// serves to distinguish types that are Redis
	// errors from ordinary errors: a type is a
	// Redis error if it has a RedisError method.
	RedisError()
}
// Compile-time assertion that proto.RedisError implements Error.
var _ Error = proto.RedisError("")
func shouldRetry(err error, retryTimeout bool) bool { | |||
switch err { | |||
case io.EOF, io.ErrUnexpectedEOF: | |||
return true | |||
case nil, context.Canceled, context.DeadlineExceeded: | |||
return false | |||
} | |||
if v, ok := err.(timeoutError); ok { | |||
if v.Timeout() { | |||
return retryTimeout | |||
} | |||
return true | |||
} | |||
s := err.Error() | |||
if s == "ERR max number of clients reached" { | |||
return true | |||
} | |||
if strings.HasPrefix(s, "LOADING ") { | |||
return true | |||
} | |||
if strings.HasPrefix(s, "READONLY ") { | |||
return true | |||
} | |||
if strings.HasPrefix(s, "CLUSTERDOWN ") { | |||
return true | |||
} | |||
if strings.HasPrefix(s, "TRYAGAIN ") { | |||
return true | |||
} | |||
return false | |||
} | |||
// isRedisError reports whether err is an error reply sent by the Redis
// server (a proto.RedisError), rather than a network or client error.
func isRedisError(err error) bool {
	_, ok := err.(proto.RedisError)
	return ok
}
// isBadConn reports whether the connection that produced err must be
// discarded rather than returned to the pool. allowTimeout keeps the
// connection alive on temporary timeouts; addr is the connection's
// remote address, used to detect a MOVED pointing back at itself.
func isBadConn(err error, allowTimeout bool, addr string) bool {
	switch err {
	case nil:
		return false
	case context.Canceled, context.DeadlineExceeded:
		return true
	}
	if isRedisError(err) {
		switch {
		case isReadOnlyError(err):
			// Close connections in read only state in case domain addr is used
			// and domain resolves to a different Redis Server. See #790.
			return true
		case isMovedSameConnAddr(err, addr):
			// Close connections when we are asked to move to the same addr
			// of the connection. Force a DNS resolution when all connections
			// of the pool are recycled
			return true
		default:
			// Any other server error reply: the connection itself is fine.
			return false
		}
	}
	if allowTimeout {
		// NOTE(review): net.Error.Temporary is deprecated since Go 1.18;
		// a temporary timeout deliberately keeps the connection.
		if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
			return !netErr.Temporary()
		}
	}
	// Unknown (likely network-level) error: drop the connection.
	return true
}
func isMovedError(err error) (moved bool, ask bool, addr string) { | |||
if !isRedisError(err) { | |||
return | |||
} | |||
s := err.Error() | |||
switch { | |||
case strings.HasPrefix(s, "MOVED "): | |||
moved = true | |||
case strings.HasPrefix(s, "ASK "): | |||
ask = true | |||
default: | |||
return | |||
} | |||
ind := strings.LastIndex(s, " ") | |||
if ind == -1 { | |||
return false, false, "" | |||
} | |||
addr = s[ind+1:] | |||
return | |||
} | |||
// isLoadingError reports whether err is a Redis "LOADING" reply, sent
// while the server is still loading its dataset into memory.
func isLoadingError(err error) bool {
	const loadingPrefix = "LOADING "
	return strings.HasPrefix(err.Error(), loadingPrefix)
}
// isReadOnlyError reports whether err is a Redis "READONLY" reply from a
// replica that refuses writes.
func isReadOnlyError(err error) bool {
	const readOnlyPrefix = "READONLY "
	return strings.HasPrefix(err.Error(), readOnlyPrefix)
}
// isMovedSameConnAddr reports whether err is a MOVED redirection whose
// target is addr — i.e. the server redirected us to the address of the
// very connection that issued the command.
func isMovedSameConnAddr(err error, addr string) bool {
	msg := err.Error()
	return strings.HasPrefix(msg, "MOVED ") && strings.HasSuffix(msg, addr)
}
//------------------------------------------------------------------------------
// timeoutError is the minimal interface used to detect timeout errors
// (net.Error satisfies it).
type timeoutError interface {
	Timeout() bool
}
@@ -0,0 +1,56 @@ | |||
package internal | |||
import ( | |||
"fmt" | |||
"strconv" | |||
"time" | |||
) | |||
func AppendArg(b []byte, v interface{}) []byte { | |||
switch v := v.(type) { | |||
case nil: | |||
return append(b, "<nil>"...) | |||
case string: | |||
return appendUTF8String(b, Bytes(v)) | |||
case []byte: | |||
return appendUTF8String(b, v) | |||
case int: | |||
return strconv.AppendInt(b, int64(v), 10) | |||
case int8: | |||
return strconv.AppendInt(b, int64(v), 10) | |||
case int16: | |||
return strconv.AppendInt(b, int64(v), 10) | |||
case int32: | |||
return strconv.AppendInt(b, int64(v), 10) | |||
case int64: | |||
return strconv.AppendInt(b, v, 10) | |||
case uint: | |||
return strconv.AppendUint(b, uint64(v), 10) | |||
case uint8: | |||
return strconv.AppendUint(b, uint64(v), 10) | |||
case uint16: | |||
return strconv.AppendUint(b, uint64(v), 10) | |||
case uint32: | |||
return strconv.AppendUint(b, uint64(v), 10) | |||
case uint64: | |||
return strconv.AppendUint(b, v, 10) | |||
case float32: | |||
return strconv.AppendFloat(b, float64(v), 'f', -1, 64) | |||
case float64: | |||
return strconv.AppendFloat(b, v, 'f', -1, 64) | |||
case bool: | |||
if v { | |||
return append(b, "true"...) | |||
} | |||
return append(b, "false"...) | |||
case time.Time: | |||
return v.AppendFormat(b, time.RFC3339Nano) | |||
default: | |||
return append(b, fmt.Sprint(v)...) | |||
} | |||
} | |||
// appendUTF8String appends src to dst verbatim; the bytes are assumed to
// already be valid UTF-8.
func appendUTF8String(dst []byte, src []byte) []byte {
	return append(dst, src...)
}
@@ -0,0 +1,78 @@ | |||
package hashtag | |||
import ( | |||
"strings" | |||
"github.com/go-redis/redis/v8/internal/rand" | |||
) | |||
const slotNumber = 16384 | |||
// CRC16 implementation according to CCITT standards. | |||
// Copyright 2001-2010 Georges Menie (www.menie.org) | |||
// Copyright 2013 The Go Authors. All rights reserved. | |||
// http://redis.io/topics/cluster-spec#appendix-a-crc16-reference-implementation-in-ansi-c | |||
var crc16tab = [256]uint16{ | |||
0x0000, 0x1021, 0x2042, 0x3063, 0x4084, 0x50a5, 0x60c6, 0x70e7, | |||
0x8108, 0x9129, 0xa14a, 0xb16b, 0xc18c, 0xd1ad, 0xe1ce, 0xf1ef, | |||
0x1231, 0x0210, 0x3273, 0x2252, 0x52b5, 0x4294, 0x72f7, 0x62d6, | |||
0x9339, 0x8318, 0xb37b, 0xa35a, 0xd3bd, 0xc39c, 0xf3ff, 0xe3de, | |||
0x2462, 0x3443, 0x0420, 0x1401, 0x64e6, 0x74c7, 0x44a4, 0x5485, | |||
0xa56a, 0xb54b, 0x8528, 0x9509, 0xe5ee, 0xf5cf, 0xc5ac, 0xd58d, | |||
0x3653, 0x2672, 0x1611, 0x0630, 0x76d7, 0x66f6, 0x5695, 0x46b4, | |||
0xb75b, 0xa77a, 0x9719, 0x8738, 0xf7df, 0xe7fe, 0xd79d, 0xc7bc, | |||
0x48c4, 0x58e5, 0x6886, 0x78a7, 0x0840, 0x1861, 0x2802, 0x3823, | |||
0xc9cc, 0xd9ed, 0xe98e, 0xf9af, 0x8948, 0x9969, 0xa90a, 0xb92b, | |||
0x5af5, 0x4ad4, 0x7ab7, 0x6a96, 0x1a71, 0x0a50, 0x3a33, 0x2a12, | |||
0xdbfd, 0xcbdc, 0xfbbf, 0xeb9e, 0x9b79, 0x8b58, 0xbb3b, 0xab1a, | |||
0x6ca6, 0x7c87, 0x4ce4, 0x5cc5, 0x2c22, 0x3c03, 0x0c60, 0x1c41, | |||
0xedae, 0xfd8f, 0xcdec, 0xddcd, 0xad2a, 0xbd0b, 0x8d68, 0x9d49, | |||
0x7e97, 0x6eb6, 0x5ed5, 0x4ef4, 0x3e13, 0x2e32, 0x1e51, 0x0e70, | |||
0xff9f, 0xefbe, 0xdfdd, 0xcffc, 0xbf1b, 0xaf3a, 0x9f59, 0x8f78, | |||
0x9188, 0x81a9, 0xb1ca, 0xa1eb, 0xd10c, 0xc12d, 0xf14e, 0xe16f, | |||
0x1080, 0x00a1, 0x30c2, 0x20e3, 0x5004, 0x4025, 0x7046, 0x6067, | |||
0x83b9, 0x9398, 0xa3fb, 0xb3da, 0xc33d, 0xd31c, 0xe37f, 0xf35e, | |||
0x02b1, 0x1290, 0x22f3, 0x32d2, 0x4235, 0x5214, 0x6277, 0x7256, | |||
0xb5ea, 0xa5cb, 0x95a8, 0x8589, 0xf56e, 0xe54f, 0xd52c, 0xc50d, | |||
0x34e2, 0x24c3, 0x14a0, 0x0481, 0x7466, 0x6447, 0x5424, 0x4405, | |||
0xa7db, 0xb7fa, 0x8799, 0x97b8, 0xe75f, 0xf77e, 0xc71d, 0xd73c, | |||
0x26d3, 0x36f2, 0x0691, 0x16b0, 0x6657, 0x7676, 0x4615, 0x5634, | |||
0xd94c, 0xc96d, 0xf90e, 0xe92f, 0x99c8, 0x89e9, 0xb98a, 0xa9ab, | |||
0x5844, 0x4865, 0x7806, 0x6827, 0x18c0, 0x08e1, 0x3882, 0x28a3, | |||
0xcb7d, 0xdb5c, 0xeb3f, 0xfb1e, 0x8bf9, 0x9bd8, 0xabbb, 0xbb9a, | |||
0x4a75, 0x5a54, 0x6a37, 0x7a16, 0x0af1, 0x1ad0, 0x2ab3, 0x3a92, | |||
0xfd2e, 0xed0f, 0xdd6c, 0xcd4d, 0xbdaa, 0xad8b, 0x9de8, 0x8dc9, | |||
0x7c26, 0x6c07, 0x5c64, 0x4c45, 0x3ca2, 0x2c83, 0x1ce0, 0x0cc1, | |||
0xef1f, 0xff3e, 0xcf5d, 0xdf7c, 0xaf9b, 0xbfba, 0x8fd9, 0x9ff8, | |||
0x6e17, 0x7e36, 0x4e55, 0x5e74, 0x2e93, 0x3eb2, 0x0ed1, 0x1ef0, | |||
} | |||
// Key extracts the hash tag from a Redis key: "{user}.following" yields
// "user". A key without a non-empty "{...}" section is hashed whole, so
// the key itself is returned.
func Key(key string) string {
	open := strings.IndexByte(key, '{')
	if open == -1 {
		return key
	}
	end := strings.IndexByte(key[open+1:], '}')
	if end <= 0 {
		// No closing brace after '{', or an empty "{}" tag.
		return key
	}
	return key[open+1 : open+end+1]
}
// RandomSlot returns a pseudo-random cluster slot in [0, slotNumber).
func RandomSlot() int {
	return rand.Intn(slotNumber)
}
// Slot returns a consistent slot number between 0 and 16383
// for any given string key.
func Slot(key string) int {
	// An empty key carries no content to hash; spread it randomly.
	if key == "" {
		return RandomSlot()
	}
	// Hash only the "{...}" tag portion, if present.
	key = Key(key)
	return int(crc16sum(key)) % slotNumber
}
func crc16sum(key string) (crc uint16) { | |||
for i := 0; i < len(key); i++ { | |||
crc = (crc << 8) ^ crc16tab[(byte(crc>>8)^key[i])&0x00ff] | |||
} | |||
return | |||
} |
@@ -0,0 +1,201 @@ | |||
package hscan | |||
import ( | |||
"errors" | |||
"fmt" | |||
"reflect" | |||
"strconv" | |||
) | |||
// decoderFunc represents decoding functions for default built-in types.
type decoderFunc func(reflect.Value, string) error
var (
	// List of built-in decoders indexed by their numeric constant values (eg: reflect.Bool = 1).
	// Kinds without an explicit entry (reflect.Invalid, reflect.Uintptr)
	// are left nil — NOTE(review): invoking a nil entry would panic;
	// presumably such fields never reach a decoder, confirm with callers.
	decoders = []decoderFunc{
		reflect.Bool:          decodeBool,
		reflect.Int:           decodeInt,
		reflect.Int8:          decodeInt8,
		reflect.Int16:         decodeInt16,
		reflect.Int32:         decodeInt32,
		reflect.Int64:         decodeInt64,
		reflect.Uint:          decodeUint,
		reflect.Uint8:         decodeUint8,
		reflect.Uint16:        decodeUint16,
		reflect.Uint32:        decodeUint32,
		reflect.Uint64:        decodeUint64,
		reflect.Float32:       decodeFloat32,
		reflect.Float64:       decodeFloat64,
		reflect.Complex64:     decodeUnsupported,
		reflect.Complex128:    decodeUnsupported,
		reflect.Array:         decodeUnsupported,
		reflect.Chan:          decodeUnsupported,
		reflect.Func:          decodeUnsupported,
		reflect.Interface:     decodeUnsupported,
		reflect.Map:           decodeUnsupported,
		reflect.Ptr:           decodeUnsupported,
		reflect.Slice:         decodeSlice,
		reflect.String:        decodeString,
		reflect.Struct:        decodeUnsupported,
		reflect.UnsafePointer: decodeUnsupported,
	}
	// Global map of struct field specs that is populated once for every new
	// struct type that is scanned. This caches the field types and the corresponding
	// decoder functions to avoid iterating through struct fields on subsequent scans.
	globalStructMap = newStructMap()
)
func Struct(dst interface{}) (StructValue, error) { | |||
v := reflect.ValueOf(dst) | |||
// The destination to scan into should be a struct pointer. | |||
if v.Kind() != reflect.Ptr || v.IsNil() { | |||
return StructValue{}, fmt.Errorf("redis.Scan(non-pointer %T)", dst) | |||
} | |||
v = v.Elem() | |||
if v.Kind() != reflect.Struct { | |||
return StructValue{}, fmt.Errorf("redis.Scan(non-struct %T)", dst) | |||
} | |||
return StructValue{ | |||
spec: globalStructMap.get(v.Type()), | |||
value: v, | |||
}, nil | |||
} | |||
// Scan scans the results from a key-value Redis map result set to a destination struct. | |||
// The Redis keys are matched to the struct's field with the `redis` tag. | |||
func Scan(dst interface{}, keys []interface{}, vals []interface{}) error { | |||
if len(keys) != len(vals) { | |||
return errors.New("args should have the same number of keys and vals") | |||
} | |||
strct, err := Struct(dst) | |||
if err != nil { | |||
return err | |||
} | |||
// Iterate through the (key, value) sequence. | |||
for i := 0; i < len(vals); i++ { | |||
key, ok := keys[i].(string) | |||
if !ok { | |||
continue | |||
} | |||
val, ok := vals[i].(string) | |||
if !ok { | |||
continue | |||
} | |||
if err := strct.Scan(key, val); err != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
} | |||
func decodeBool(f reflect.Value, s string) error { | |||
b, err := strconv.ParseBool(s) | |||
if err != nil { | |||
return err | |||
} | |||
f.SetBool(b) | |||
return nil | |||
} | |||
// decodeInt8 decodes s into an 8-bit signed integer field.
func decodeInt8(f reflect.Value, s string) error {
	return decodeNumber(f, s, 8)
}
// decodeInt16 decodes s into a 16-bit signed integer field.
func decodeInt16(f reflect.Value, s string) error {
	return decodeNumber(f, s, 16)
}
// decodeInt32 decodes s into a 32-bit signed integer field.
func decodeInt32(f reflect.Value, s string) error {
	return decodeNumber(f, s, 32)
}
// decodeInt64 decodes s into a 64-bit signed integer field.
func decodeInt64(f reflect.Value, s string) error {
	return decodeNumber(f, s, 64)
}
// decodeInt decodes s into a platform-sized signed integer field
// (bitSize 0 means "int" to strconv.ParseInt).
func decodeInt(f reflect.Value, s string) error {
	return decodeNumber(f, s, 0)
}
func decodeNumber(f reflect.Value, s string, bitSize int) error { | |||
v, err := strconv.ParseInt(s, 10, bitSize) | |||
if err != nil { | |||
return err | |||
} | |||
f.SetInt(v) | |||
return nil | |||
} | |||
// decodeUint8 decodes s into an 8-bit unsigned integer field.
func decodeUint8(f reflect.Value, s string) error {
	return decodeUnsignedNumber(f, s, 8)
}
// decodeUint16 decodes s into a 16-bit unsigned integer field.
func decodeUint16(f reflect.Value, s string) error {
	return decodeUnsignedNumber(f, s, 16)
}
// decodeUint32 decodes s into a 32-bit unsigned integer field.
func decodeUint32(f reflect.Value, s string) error {
	return decodeUnsignedNumber(f, s, 32)
}
// decodeUint64 decodes s into a 64-bit unsigned integer field.
func decodeUint64(f reflect.Value, s string) error {
	return decodeUnsignedNumber(f, s, 64)
}
// decodeUint decodes s into a platform-sized unsigned integer field
// (bitSize 0 means "uint" to strconv.ParseUint).
func decodeUint(f reflect.Value, s string) error {
	return decodeUnsignedNumber(f, s, 0)
}
func decodeUnsignedNumber(f reflect.Value, s string, bitSize int) error { | |||
v, err := strconv.ParseUint(s, 10, bitSize) | |||
if err != nil { | |||
return err | |||
} | |||
f.SetUint(v) | |||
return nil | |||
} | |||
func decodeFloat32(f reflect.Value, s string) error { | |||
v, err := strconv.ParseFloat(s, 32) | |||
if err != nil { | |||
return err | |||
} | |||
f.SetFloat(v) | |||
return nil | |||
} | |||
// although the default is float64, but we better define it. | |||
func decodeFloat64(f reflect.Value, s string) error { | |||
v, err := strconv.ParseFloat(s, 64) | |||
if err != nil { | |||
return err | |||
} | |||
f.SetFloat(v) | |||
return nil | |||
} | |||
func decodeString(f reflect.Value, s string) error { | |||
f.SetString(s) | |||
return nil | |||
} | |||
func decodeSlice(f reflect.Value, s string) error { | |||
// []byte slice ([]uint8). | |||
if f.Type().Elem().Kind() == reflect.Uint8 { | |||
f.SetBytes([]byte(s)) | |||
} | |||
return nil | |||
} | |||
func decodeUnsupported(v reflect.Value, s string) error { | |||
return fmt.Errorf("redis.Scan(unsupported %s)", v.Type()) | |||
} |
@@ -0,0 +1,93 @@ | |||
package hscan | |||
import ( | |||
"fmt" | |||
"reflect" | |||
"strings" | |||
"sync" | |||
) | |||
// structMap contains the map of struct fields for target structs
// indexed by the struct type.
type structMap struct {
	m sync.Map // maps reflect.Type -> *structSpec
}
// newStructMap returns an empty, ready-to-use struct spec cache.
func newStructMap() *structMap {
	return new(structMap)
}
func (s *structMap) get(t reflect.Type) *structSpec { | |||
if v, ok := s.m.Load(t); ok { | |||
return v.(*structSpec) | |||
} | |||
spec := newStructSpec(t, "redis") | |||
s.m.Store(t, spec) | |||
return spec | |||
} | |||
//------------------------------------------------------------------------------
// structSpec contains the list of all fields in a target struct.
type structSpec struct {
	m map[string]*structField // redis tag name -> field spec
}
// set registers the field spec sf under its redis tag name.
func (s *structSpec) set(tag string, sf *structField) {
	s.m[tag] = sf
}
func newStructSpec(t reflect.Type, fieldTag string) *structSpec { | |||
numField := t.NumField() | |||
out := &structSpec{ | |||
m: make(map[string]*structField, numField), | |||
} | |||
for i := 0; i < numField; i++ { | |||
f := t.Field(i) | |||
tag := f.Tag.Get(fieldTag) | |||
if tag == "" || tag == "-" { | |||
continue | |||
} | |||
tag = strings.Split(tag, ",")[0] | |||
if tag == "" { | |||
continue | |||
} | |||
// Use the built-in decoder. | |||
out.set(tag, &structField{index: i, fn: decoders[f.Type.Kind()]}) | |||
} | |||
return out | |||
} | |||
//------------------------------------------------------------------------------
// structField represents a single field in a target struct.
type structField struct {
	index int         // positional index of the field within the struct
	fn    decoderFunc // decoder that parses a string value into the field
}
//------------------------------------------------------------------------------
// StructValue pairs a scan-target struct value with its cached field spec.
type StructValue struct {
	spec  *structSpec
	value reflect.Value
}
func (s StructValue) Scan(key string, value string) error { | |||
field, ok := s.spec.m[key] | |||
if !ok { | |||
return nil | |||
} | |||
if err := field.fn(s.value.Field(field.index), value); err != nil { | |||
t := s.value.Type() | |||
return fmt.Errorf("cannot scan redis.result %s into struct field %s.%s of type %s, error-%s", | |||
value, t.Name(), t.Field(field.index).Name, t.Field(field.index).Type, err.Error()) | |||
} | |||
return nil | |||
} |
@@ -0,0 +1,29 @@ | |||
package internal | |||
import ( | |||
"time" | |||
"github.com/go-redis/redis/v8/internal/rand" | |||
) | |||
func RetryBackoff(retry int, minBackoff, maxBackoff time.Duration) time.Duration { | |||
if retry < 0 { | |||
panic("not reached") | |||
} | |||
if minBackoff == 0 { | |||
return 0 | |||
} | |||
d := minBackoff << uint(retry) | |||
if d < minBackoff { | |||
return maxBackoff | |||
} | |||
d = minBackoff + time.Duration(rand.Int63n(int64(d))) | |||
if d > maxBackoff || d < minBackoff { | |||
d = maxBackoff | |||
} | |||
return d | |||
} |
@@ -0,0 +1,26 @@ | |||
package internal | |||
import ( | |||
"context" | |||
"fmt" | |||
"log" | |||
"os" | |||
) | |||
// Logging is the pluggable logger interface used throughout the package.
type Logging interface {
	Printf(ctx context.Context, format string, v ...interface{})
}
// logger adapts the standard library *log.Logger to the Logging interface.
type logger struct {
	log *log.Logger
}
// Printf formats the message in the manner of fmt.Sprintf and writes it
// via the wrapped logger. The ctx argument is currently unused.
func (l *logger) Printf(ctx context.Context, format string, v ...interface{}) {
	_ = l.log.Output(2, fmt.Sprintf(format, v...)) // calldepth 2 attributes the line to Printf's caller
}
// Logger calls Output to print to the stderr.
// Arguments are handled in the manner of fmt.Print.
var Logger Logging = &logger{
	log: log.New(os.Stderr, "redis: ", log.LstdFlags|log.Lshortfile),
}
@@ -0,0 +1,60 @@ | |||
/* | |||
Copyright 2014 The Camlistore Authors | |||
Licensed under the Apache License, Version 2.0 (the "License"); | |||
you may not use this file except in compliance with the License. | |||
You may obtain a copy of the License at | |||
http://www.apache.org/licenses/LICENSE-2.0 | |||
Unless required by applicable law or agreed to in writing, software | |||
distributed under the License is distributed on an "AS IS" BASIS, | |||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
See the License for the specific language governing permissions and | |||
limitations under the License. | |||
*/ | |||
package internal | |||
import ( | |||
"sync" | |||
"sync/atomic" | |||
) | |||
// A Once will perform a successful action exactly once.
//
// Unlike a sync.Once, this Once's func returns an error
// and is re-armed on failure.
type Once struct {
	m    sync.Mutex
	done uint32 // set to 1 (atomically) after f succeeds
}

// Do invokes f unless an earlier invocation has already returned nil.
// On failure the Once stays armed, so a later Do call will try again.
// Concurrent callers serialize on the mutex; the atomic fast path skips
// the lock entirely once a call has succeeded.
//
// Do is intended for initialization that must be run exactly once. Since f
// is niladic, a function literal may be used to capture the arguments:
//	err := config.once.Do(func() error { return config.init(filename) })
func (o *Once) Do(f func() error) error {
	// Fast path: already completed successfully.
	if atomic.LoadUint32(&o.done) == 1 {
		return nil
	}
	// Slow path: serialize and re-check under the lock.
	o.m.Lock()
	defer o.m.Unlock()
	if o.done != 0 {
		// Another goroutine succeeded while we waited for the lock.
		return nil
	}
	err := f()
	if err == nil {
		atomic.StoreUint32(&o.done, 1)
	}
	return err
}
@@ -0,0 +1,121 @@ | |||
package pool | |||
import ( | |||
"bufio" | |||
"context" | |||
"net" | |||
"sync/atomic" | |||
"time" | |||
"github.com/go-redis/redis/v8/internal/proto" | |||
) | |||
// noDeadline disables I/O deadlines when passed to SetReadDeadline/SetWriteDeadline.
var noDeadline = time.Time{}
// Conn wraps a single network connection to a Redis server together with
// its protocol reader/writer and pool bookkeeping state.
type Conn struct {
	usedAt  int64 // atomic; last-use time as unix seconds
	netConn net.Conn
	rd *proto.Reader // protocol reader over netConn
	bw *bufio.Writer // buffered writer over netConn
	wr *proto.Writer // protocol writer over bw
	Inited    bool // whether connection setup has run — TODO confirm semantics with callers
	pooled    bool // if false, Put removes the conn instead of returning it to the pool
	createdAt time.Time
}
func NewConn(netConn net.Conn) *Conn { | |||
cn := &Conn{ | |||
netConn: netConn, | |||
createdAt: time.Now(), | |||
} | |||
cn.rd = proto.NewReader(netConn) | |||
cn.bw = bufio.NewWriter(netConn) | |||
cn.wr = proto.NewWriter(cn.bw) | |||
cn.SetUsedAt(time.Now()) | |||
return cn | |||
} | |||
// UsedAt returns the last-use timestamp with second precision.
func (cn *Conn) UsedAt() time.Time {
	unix := atomic.LoadInt64(&cn.usedAt)
	return time.Unix(unix, 0)
}
// SetUsedAt atomically records tm (truncated to seconds) as the last-use time.
func (cn *Conn) SetUsedAt(tm time.Time) {
	atomic.StoreInt64(&cn.usedAt, tm.Unix())
}
// SetNetConn swaps in a new underlying network connection and points the
// protocol reader and buffered writer at it.
func (cn *Conn) SetNetConn(netConn net.Conn) {
	cn.netConn = netConn
	cn.rd.Reset(netConn)
	cn.bw.Reset(netConn)
}
// Write writes b directly to the underlying connection, bypassing the
// buffered protocol writer.
func (cn *Conn) Write(b []byte) (int, error) {
	return cn.netConn.Write(b)
}
// RemoteAddr returns the remote address, or nil when no connection is set.
func (cn *Conn) RemoteAddr() net.Addr {
	if cn.netConn != nil {
		return cn.netConn.RemoteAddr()
	}
	return nil
}
// WithReader sets the read deadline derived from ctx and timeout, then
// invokes fn with the protocol reader.
func (cn *Conn) WithReader(ctx context.Context, timeout time.Duration, fn func(rd *proto.Reader) error) error {
	if err := cn.netConn.SetReadDeadline(cn.deadline(ctx, timeout)); err != nil {
		return err
	}
	return fn(cn.rd)
}
// WithWriter sets the write deadline derived from ctx and timeout, invokes
// fn with the protocol writer, and flushes the buffered output.
func (cn *Conn) WithWriter(
	ctx context.Context, timeout time.Duration, fn func(wr *proto.Writer) error,
) error {
	if err := cn.netConn.SetWriteDeadline(cn.deadline(ctx, timeout)); err != nil {
		return err
	}
	// Discard any leftover bytes from a previously failed write so they
	// are not sent ahead of this command.
	if cn.bw.Buffered() > 0 {
		cn.bw.Reset(cn.netConn)
	}
	if err := fn(cn.wr); err != nil {
		return err
	}
	return cn.bw.Flush()
}
// Close closes the underlying network connection.
func (cn *Conn) Close() error {
	return cn.netConn.Close()
}
func (cn *Conn) deadline(ctx context.Context, timeout time.Duration) time.Time { | |||
tm := time.Now() | |||
cn.SetUsedAt(tm) | |||
if timeout > 0 { | |||
tm = tm.Add(timeout) | |||
} | |||
if ctx != nil { | |||
deadline, ok := ctx.Deadline() | |||
if ok { | |||
if timeout == 0 { | |||
return deadline | |||
} | |||
if deadline.Before(tm) { | |||
return deadline | |||
} | |||
return tm | |||
} | |||
} | |||
if timeout > 0 { | |||
return tm | |||
} | |||
return noDeadline | |||
} |
@@ -0,0 +1,557 @@ | |||
package pool | |||
import ( | |||
"context" | |||
"errors" | |||
"net" | |||
"sync" | |||
"sync/atomic" | |||
"time" | |||
"github.com/go-redis/redis/v8/internal" | |||
) | |||
var (
	// ErrClosed is returned when any operation is performed on a closed client.
	ErrClosed = errors.New("redis: client is closed")
	// ErrPoolTimeout is returned when waiting for a connection from the pool times out.
	ErrPoolTimeout = errors.New("redis: connection pool timeout")
)
// timers is a pool of stopped timers reused by waitTurn, avoiding a new
// timer allocation on every pool-queue wait.
var timers = sync.Pool{
	New: func() interface{} {
		t := time.NewTimer(time.Hour)
		t.Stop()
		return t
	},
}
// Stats contains pool state information and accumulated stats.
type Stats struct {
	Hits     uint32 // number of times free connection was found in the pool
	Misses   uint32 // number of times free connection was NOT found in the pool
	Timeouts uint32 // number of times a wait timeout occurred
	TotalConns uint32 // number of total connections in the pool
	IdleConns  uint32 // number of idle connections in the pool
	StaleConns uint32 // number of stale connections removed from the pool
}
// Pooler is the connection-pool abstraction implemented by ConnPool,
// SingleConnPool and StickyConnPool.
type Pooler interface {
	NewConn(context.Context) (*Conn, error)
	CloseConn(*Conn) error
	Get(context.Context) (*Conn, error)
	Put(context.Context, *Conn)
	Remove(context.Context, *Conn, error)
	Len() int
	IdleLen() int
	Stats() *Stats
	Close() error
}
// Options configures a ConnPool.
type Options struct {
	Dialer  func(context.Context) (net.Conn, error) // establishes new connections
	OnClose func(*Conn) error                       // invoked before a connection is closed
	PoolFIFO           bool          // pop idle conns FIFO (head) instead of LIFO (tail)
	PoolSize           int           // maximum number of pooled connections / concurrent turns
	MinIdleConns       int           // idle connections kept warm by checkMinIdleConns
	MaxConnAge         time.Duration // close connections older than this
	PoolTimeout        time.Duration // how long Get waits for a free turn
	IdleTimeout        time.Duration // close connections idle longer than this
	IdleCheckFrequency time.Duration // how often the reaper runs
}
// lastDialErrorWrap boxes the last dial error so that differently-typed
// errors can share one atomic.Value slot.
type lastDialErrorWrap struct {
	err error
}
// ConnPool is the default Pooler implementation: a mutex-guarded,
// fixed-capacity pool whose buffered queue channel acts as a semaphore
// limiting concurrent connection turns to PoolSize.
type ConnPool struct {
	opt *Options
	dialErrorsNum uint32 // atomic; consecutive dial failures, reset by tryDial
	lastDialError atomic.Value // holds *lastDialErrorWrap
	queue chan struct{} // capacity PoolSize; acts as a turn semaphore
	connsMu      sync.Mutex
	conns        []*Conn // all connections managed by the pool
	idleConns    []*Conn // subset of conns currently idle
	poolSize     int     // number of pooled (counted) connections
	idleConnsLen int     // tracked length of idleConns, reported by IdleLen
	stats Stats
	_closed  uint32 // atomic; 1 after Close
	closedCh chan struct{} // closed by Close to stop the reaper
}
var _ Pooler = (*ConnPool)(nil)
// NewConnPool builds a pool from opt, pre-warms MinIdleConns idle
// connections, and starts the stale-connection reaper when both
// IdleTimeout and IdleCheckFrequency are set.
func NewConnPool(opt *Options) *ConnPool {
	p := &ConnPool{
		opt: opt,
		queue:     make(chan struct{}, opt.PoolSize),
		conns:     make([]*Conn, 0, opt.PoolSize),
		idleConns: make([]*Conn, 0, opt.PoolSize),
		closedCh:  make(chan struct{}),
	}
	// checkMinIdleConns requires connsMu to be held.
	p.connsMu.Lock()
	p.checkMinIdleConns()
	p.connsMu.Unlock()
	if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 {
		go p.reaper(opt.IdleCheckFrequency)
	}
	return p
}
// checkMinIdleConns asynchronously opens idle connections until
// MinIdleConns is satisfied, never exceeding PoolSize. The counters are
// incremented optimistically before the dial and rolled back on failure.
// Caller must hold connsMu.
func (p *ConnPool) checkMinIdleConns() {
	if p.opt.MinIdleConns == 0 {
		return
	}
	for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
		p.poolSize++
		p.idleConnsLen++
		go func() {
			err := p.addIdleConn()
			if err != nil && err != ErrClosed {
				// Dial failed: undo the optimistic reservation.
				p.connsMu.Lock()
				p.poolSize--
				p.idleConnsLen--
				p.connsMu.Unlock()
			}
		}()
	}
}
// addIdleConn dials one pooled connection and appends it to both the
// tracked and idle lists. Used only by checkMinIdleConns warmup.
func (p *ConnPool) addIdleConn() error {
	cn, err := p.dialConn(context.TODO(), true)
	if err != nil {
		return err
	}
	p.connsMu.Lock()
	defer p.connsMu.Unlock()
	// It is not allowed to add new connections to the closed connection pool.
	if p.closed() {
		_ = cn.Close()
		return ErrClosed
	}
	p.conns = append(p.conns, cn)
	p.idleConns = append(p.idleConns, cn)
	return nil
}
// NewConn opens a connection that is tracked by the pool but never
// counted against PoolSize (pooled=false).
func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
	return p.newConn(ctx, false)
}
// newConn dials and registers a connection. A pooled connection that
// would exceed PoolSize is downgraded to non-pooled so the next Put
// removes it instead of keeping it.
func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
	cn, err := p.dialConn(ctx, pooled)
	if err != nil {
		return nil, err
	}
	p.connsMu.Lock()
	defer p.connsMu.Unlock()
	// It is not allowed to add new connections to the closed connection pool.
	if p.closed() {
		_ = cn.Close()
		return nil, ErrClosed
	}
	p.conns = append(p.conns, cn)
	if pooled {
		// If pool is full remove the cn on next Put.
		if p.poolSize >= p.opt.PoolSize {
			cn.pooled = false
		} else {
			p.poolSize++
		}
	}
	return cn, nil
}
// dialConn establishes a new connection. After PoolSize consecutive dial
// failures it short-circuits with the last recorded dial error and lets
// tryDial probe the server in the background until it recovers.
func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
	if p.closed() {
		return nil, ErrClosed
	}
	if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
		return nil, p.getLastDialError()
	}
	netConn, err := p.opt.Dialer(ctx)
	if err != nil {
		p.setLastDialError(err)
		// Exactly one caller crosses the threshold and starts the prober.
		if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
			go p.tryDial()
		}
		return nil, err
	}
	cn := NewConn(netConn)
	cn.pooled = pooled
	return cn, nil
}
// tryDial probes the server once per second until a dial succeeds (then
// resets the failure counter so dialConn works again) or the pool closes.
func (p *ConnPool) tryDial() {
	for {
		if p.closed() {
			return
		}
		conn, err := p.opt.Dialer(context.Background())
		if err != nil {
			p.setLastDialError(err)
			time.Sleep(time.Second)
			continue
		}
		atomic.StoreUint32(&p.dialErrorsNum, 0)
		// The probe connection itself is not needed.
		_ = conn.Close()
		return
	}
}
// setLastDialError atomically records the most recent dial failure.
func (p *ConnPool) setLastDialError(err error) {
	p.lastDialError.Store(&lastDialErrorWrap{err: err})
}
// getLastDialError returns the most recent dial failure, or nil if none
// has been recorded yet.
func (p *ConnPool) getLastDialError() error {
	err, _ := p.lastDialError.Load().(*lastDialErrorWrap)
	if err != nil {
		return err.err
	}
	return nil
}
// Get returns an existing connection from the pool or creates a new one.
// It first acquires a turn (bounded by PoolSize), then pops idle
// connections until a non-stale one is found; if none remain it dials a
// new pooled connection. The turn is released on error, or later by
// Put/Remove once the caller is done with the connection.
func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
	if p.closed() {
		return nil, ErrClosed
	}
	if err := p.waitTurn(ctx); err != nil {
		return nil, err
	}
	for {
		p.connsMu.Lock()
		cn, err := p.popIdle()
		p.connsMu.Unlock()
		if err != nil {
			return nil, err
		}
		if cn == nil {
			// No idle connections left; fall through to dialing.
			break
		}
		if p.isStaleConn(cn) {
			_ = p.CloseConn(cn)
			continue
		}
		atomic.AddUint32(&p.stats.Hits, 1)
		return cn, nil
	}
	atomic.AddUint32(&p.stats.Misses, 1)
	newcn, err := p.newConn(ctx, true)
	if err != nil {
		p.freeTurn()
		return nil, err
	}
	return newcn, nil
}
// getTurn blocks until a pool turn (queue slot) is available.
func (p *ConnPool) getTurn() {
	p.queue <- struct{}{}
}
// waitTurn acquires a pool turn, honoring ctx cancellation and the
// configured PoolTimeout. Timers come from the shared timers pool to
// avoid an allocation per wait.
func (p *ConnPool) waitTurn(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	default:
	}
	// Fast path: a turn is free right now.
	select {
	case p.queue <- struct{}{}:
		return nil
	default:
	}
	timer := timers.Get().(*time.Timer)
	timer.Reset(p.opt.PoolTimeout)
	select {
	case <-ctx.Done():
		// Drain the timer before returning it to the pool.
		if !timer.Stop() {
			<-timer.C
		}
		timers.Put(timer)
		return ctx.Err()
	case p.queue <- struct{}{}:
		if !timer.Stop() {
			<-timer.C
		}
		timers.Put(timer)
		return nil
	case <-timer.C:
		timers.Put(timer)
		atomic.AddUint32(&p.stats.Timeouts, 1)
		return ErrPoolTimeout
	}
}
// freeTurn releases a turn previously taken with getTurn/waitTurn.
func (p *ConnPool) freeTurn() {
	<-p.queue
}
// popIdle removes and returns one idle connection, or nil when none are
// available. PoolFIFO pops from the head; the default (LIFO) pops from
// the tail. Caller must hold connsMu.
func (p *ConnPool) popIdle() (*Conn, error) {
	if p.closed() {
		return nil, ErrClosed
	}
	n := len(p.idleConns)
	if n == 0 {
		return nil, nil
	}
	var cn *Conn
	if p.opt.PoolFIFO {
		cn = p.idleConns[0]
		copy(p.idleConns, p.idleConns[1:])
		p.idleConns = p.idleConns[:n-1]
	} else {
		idx := n - 1
		cn = p.idleConns[idx]
		p.idleConns = p.idleConns[:idx]
	}
	p.idleConnsLen--
	// Keep the warm-idle invariant after taking one out.
	p.checkMinIdleConns()
	return cn, nil
}
// Put returns cn to the idle list and releases the caller's turn.
// Connections with unread buffered data are assumed desynchronized and
// removed as bad; non-pooled connections are removed rather than kept.
func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
	if cn.rd.Buffered() > 0 {
		internal.Logger.Printf(ctx, "Conn has unread data")
		p.Remove(ctx, cn, BadConnError{})
		return
	}
	if !cn.pooled {
		p.Remove(ctx, cn, nil)
		return
	}
	p.connsMu.Lock()
	p.idleConns = append(p.idleConns, cn)
	p.idleConnsLen++
	p.connsMu.Unlock()
	p.freeTurn()
}
// Remove takes cn out of the pool bookkeeping, releases its turn, and
// closes the underlying connection. The reason argument is currently unused.
func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
	p.removeConnWithLock(cn)
	p.freeTurn()
	_ = p.closeConn(cn)
}
// CloseConn unregisters cn and closes it without touching the turn queue.
func (p *ConnPool) CloseConn(cn *Conn) error {
	p.removeConnWithLock(cn)
	return p.closeConn(cn)
}
// removeConnWithLock is removeConn wrapped in the pool mutex.
func (p *ConnPool) removeConnWithLock(cn *Conn) {
	p.connsMu.Lock()
	p.removeConn(cn)
	p.connsMu.Unlock()
}
// removeConn unlinks cn from p.conns and, for pooled connections, frees
// its slot and tops up the warm-idle set. Caller must hold connsMu.
func (p *ConnPool) removeConn(cn *Conn) {
	for i, c := range p.conns {
		if c == cn {
			p.conns = append(p.conns[:i], p.conns[i+1:]...)
			if cn.pooled {
				p.poolSize--
				p.checkMinIdleConns()
			}
			return
		}
	}
}
// closeConn runs the OnClose hook (if any) and closes the connection.
func (p *ConnPool) closeConn(cn *Conn) error {
	if p.opt.OnClose != nil {
		_ = p.opt.OnClose(cn)
	}
	return cn.Close()
}
// Len returns total number of connections.
func (p *ConnPool) Len() int {
	p.connsMu.Lock()
	n := len(p.conns)
	p.connsMu.Unlock()
	return n
}
// IdleLen returns number of idle connections.
func (p *ConnPool) IdleLen() int {
	p.connsMu.Lock()
	n := p.idleConnsLen
	p.connsMu.Unlock()
	return n
}
// Stats returns a snapshot of the accumulated pool statistics.
func (p *ConnPool) Stats() *Stats {
	idleLen := p.IdleLen()
	return &Stats{
		Hits:     atomic.LoadUint32(&p.stats.Hits),
		Misses:   atomic.LoadUint32(&p.stats.Misses),
		Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
		TotalConns: uint32(p.Len()),
		IdleConns:  uint32(idleLen),
		StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
	}
}
// closed reports whether Close has been called.
func (p *ConnPool) closed() bool {
	return atomic.LoadUint32(&p._closed) == 1
}
// Filter closes every connection for which fn returns true and returns
// the first close error encountered.
// NOTE(review): closed connections are NOT removed from p.conns or
// p.idleConns here — confirm callers only use this during teardown.
func (p *ConnPool) Filter(fn func(*Conn) bool) error {
	p.connsMu.Lock()
	defer p.connsMu.Unlock()
	var firstErr error
	for _, cn := range p.conns {
		if fn(cn) {
			if err := p.closeConn(cn); err != nil && firstErr == nil {
				firstErr = err
			}
		}
	}
	return firstErr
}
// Close marks the pool closed (idempotently), stops the reaper via
// closedCh, and closes every tracked connection. A second call returns
// ErrClosed.
func (p *ConnPool) Close() error {
	if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
		return ErrClosed
	}
	close(p.closedCh)
	var firstErr error
	p.connsMu.Lock()
	for _, cn := range p.conns {
		if err := p.closeConn(cn); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	p.conns = nil
	p.poolSize = 0
	p.idleConns = nil
	p.idleConnsLen = 0
	p.connsMu.Unlock()
	return firstErr
}
// reaper periodically removes stale idle connections until the pool is
// closed.
func (p *ConnPool) reaper(frequency time.Duration) {
	ticker := time.NewTicker(frequency)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			// It is possible that ticker and closedCh arrive together,
			// and select pseudo-randomly pick ticker case, we double
			// check here to prevent being executed after closed.
			if p.closed() {
				return
			}
			_, err := p.ReapStaleConns()
			if err != nil {
				internal.Logger.Printf(context.Background(), "ReapStaleConns failed: %s", err)
				continue
			}
		case <-p.closedCh:
			return
		}
	}
}
// ReapStaleConns closes stale idle connections one at a time, taking a
// pool turn for each so reaping never competes with more than PoolSize
// users. It returns the number of connections reaped.
func (p *ConnPool) ReapStaleConns() (int, error) {
	var n int
	for {
		p.getTurn()
		p.connsMu.Lock()
		cn := p.reapStaleConn()
		p.connsMu.Unlock()
		p.freeTurn()
		if cn != nil {
			_ = p.closeConn(cn)
			n++
		} else {
			break
		}
	}
	atomic.AddUint32(&p.stats.StaleConns, uint32(n))
	return n, nil
}
// reapStaleConn pops the idle-list head if it is stale, or returns nil.
// Only the head is examined per call; ReapStaleConns loops until a fresh
// head is found. Caller must hold connsMu.
func (p *ConnPool) reapStaleConn() *Conn {
	if len(p.idleConns) == 0 {
		return nil
	}
	cn := p.idleConns[0]
	if !p.isStaleConn(cn) {
		return nil
	}
	p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...)
	p.idleConnsLen--
	p.removeConn(cn)
	return cn
}
func (p *ConnPool) isStaleConn(cn *Conn) bool { | |||
if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 { | |||
return false | |||
} | |||
now := time.Now() | |||
if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout { | |||
return true | |||
} | |||
if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge { | |||
return true | |||
} | |||
return false | |||
} |
@@ -0,0 +1,58 @@ | |||
package pool | |||
import "context" | |||
// SingleConnPool is a Pooler that always hands out one fixed connection,
// pinning commands to that connection.
type SingleConnPool struct {
	pool      Pooler // the underlying pool the connection came from
	cn        *Conn  // the pinned connection
	stickyErr error  // error returned by Get after Remove/Close recorded one
}
var _ Pooler = (*SingleConnPool)(nil)
// NewSingleConnPool pins cn (owned by pool) in a single-connection pool.
func NewSingleConnPool(pool Pooler, cn *Conn) *SingleConnPool {
	return &SingleConnPool{
		pool: pool,
		cn:   cn,
	}
}
// NewConn delegates to the underlying pool.
func (p *SingleConnPool) NewConn(ctx context.Context) (*Conn, error) {
	return p.pool.NewConn(ctx)
}
// CloseConn delegates to the underlying pool.
func (p *SingleConnPool) CloseConn(cn *Conn) error {
	return p.pool.CloseConn(cn)
}
func (p *SingleConnPool) Get(ctx context.Context) (*Conn, error) { | |||
if p.stickyErr != nil { | |||
return nil, p.stickyErr | |||
} | |||
return p.cn, nil | |||
} | |||
// Put is a no-op: the connection stays pinned until Remove or Close.
func (p *SingleConnPool) Put(ctx context.Context, cn *Conn) {}
// Remove drops the pinned connection and records reason as the sticky
// error for subsequent Get calls.
func (p *SingleConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
	p.cn = nil
	p.stickyErr = reason
}
// Close detaches the connection and makes future Gets fail with ErrClosed.
// The underlying connection itself is not closed here.
func (p *SingleConnPool) Close() error {
	p.cn = nil
	p.stickyErr = ErrClosed
	return nil
}
// Len always reports zero; the pinned conn is not tracked as pool state.
func (p *SingleConnPool) Len() int {
	return 0
}
// IdleLen always reports zero.
func (p *SingleConnPool) IdleLen() int {
	return 0
}
// Stats returns empty statistics.
func (p *SingleConnPool) Stats() *Stats {
	return &Stats{}
}
@@ -0,0 +1,201 @@ | |||
package pool | |||
import ( | |||
"context" | |||
"errors" | |||
"fmt" | |||
"sync/atomic" | |||
) | |||
const (
	stateDefault = 0 // no connection acquired yet
	stateInited  = 1 // a connection is held (in use or parked in ch)
	stateClosed  = 2 // pool has been closed
)
// BadConnError wraps the error that put a sticky connection into a bad state.
type BadConnError struct {
	wrapped error
}
var _ error = (*BadConnError)(nil)
// Error implements the error interface, appending the cause if present.
func (e BadConnError) Error() string {
	s := "redis: Conn is in a bad state"
	if e.wrapped != nil {
		s += ": " + e.wrapped.Error()
	}
	return s
}
// Unwrap exposes the underlying cause for errors.Is / errors.As.
func (e BadConnError) Unwrap() error {
	return e.wrapped
}
//------------------------------------------------------------------------------
// StickyConnPool pins a single connection from another Pooler. It is
// reference-counted: wrapping an existing sticky pool increments shared,
// and Close only tears down when the count drops to zero.
type StickyConnPool struct {
	pool   Pooler
	shared int32 // atomic; reference count of users
	state uint32 // atomic; stateDefault/stateInited/stateClosed
	ch    chan *Conn // capacity-1 parking spot for the pinned conn
	_badConnError atomic.Value // holds a BadConnError recorded by Remove
}
var _ Pooler = (*StickyConnPool)(nil)
// NewStickyConnPool returns pool itself (with an extra reference) when it
// is already a sticky pool, or wraps it in a new one otherwise.
func NewStickyConnPool(pool Pooler) *StickyConnPool {
	p, ok := pool.(*StickyConnPool)
	if !ok {
		p = &StickyConnPool{
			pool: pool,
			ch:   make(chan *Conn, 1),
		}
	}
	atomic.AddInt32(&p.shared, 1)
	return p
}
// NewConn delegates to the underlying pool.
func (p *StickyConnPool) NewConn(ctx context.Context) (*Conn, error) {
	return p.pool.NewConn(ctx)
}
// CloseConn delegates to the underlying pool.
func (p *StickyConnPool) CloseConn(cn *Conn) error {
	return p.pool.CloseConn(cn)
}
// Get acquires the pinned connection. In stateDefault it pulls a conn
// from the underlying pool and transitions to stateInited; in stateInited
// it waits for the conn to be parked back via Put. The attempt limit
// bounds the CAS retry loop when racing with Close.
// In worst case this races with Close which is not a very common operation.
func (p *StickyConnPool) Get(ctx context.Context) (*Conn, error) {
	for i := 0; i < 1000; i++ {
		switch atomic.LoadUint32(&p.state) {
		case stateDefault:
			cn, err := p.pool.Get(ctx)
			if err != nil {
				return nil, err
			}
			if atomic.CompareAndSwapUint32(&p.state, stateDefault, stateInited) {
				return cn, nil
			}
			// Lost the race (e.g. against Close); give the conn back.
			p.pool.Remove(ctx, cn, ErrClosed)
		case stateInited:
			if err := p.badConnError(); err != nil {
				return nil, err
			}
			cn, ok := <-p.ch
			if !ok {
				// Channel was closed by Close.
				return nil, ErrClosed
			}
			return cn, nil
		case stateClosed:
			return nil, ErrClosed
		default:
			panic("not reached")
		}
	}
	return nil, fmt.Errorf("redis: StickyConnPool.Get: infinite loop")
}
// Put parks cn in the channel for the next Get. If the channel was closed
// by a concurrent Close, the send panics and the recover hands the conn
// back to the underlying pool instead.
func (p *StickyConnPool) Put(ctx context.Context, cn *Conn) {
	defer func() {
		if recover() != nil {
			p.freeConn(ctx, cn)
		}
	}()
	p.ch <- cn
}
// freeConn returns cn to the underlying pool, removing it when a bad-conn
// error has been recorded.
func (p *StickyConnPool) freeConn(ctx context.Context, cn *Conn) {
	if err := p.badConnError(); err != nil {
		p.pool.Remove(ctx, cn, err)
	} else {
		p.pool.Put(ctx, cn)
	}
}
// Remove records reason as a bad-conn error and parks cn. If a concurrent
// Close already closed the channel, the send panics and the recover
// removes the conn from the underlying pool instead.
func (p *StickyConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
	defer func() {
		if recover() != nil {
			p.pool.Remove(ctx, cn, ErrClosed)
		}
	}()
	p._badConnError.Store(BadConnError{wrapped: reason})
	p.ch <- cn
}
// Close decrements the reference count and, once it reaches zero, marks
// the pool closed, closes the parking channel, and releases any parked
// connection back to the underlying pool.
func (p *StickyConnPool) Close() error {
	if shared := atomic.AddInt32(&p.shared, -1); shared > 0 {
		return nil
	}
	// Bounded CAS loop against concurrent state transitions.
	for i := 0; i < 1000; i++ {
		state := atomic.LoadUint32(&p.state)
		if state == stateClosed {
			return ErrClosed
		}
		if atomic.CompareAndSwapUint32(&p.state, state, stateClosed) {
			close(p.ch)
			// Drain a conn that was parked at close time, if any.
			cn, ok := <-p.ch
			if ok {
				p.freeConn(context.TODO(), cn)
			}
			return nil
		}
	}
	return errors.New("redis: StickyConnPool.Close: infinite loop")
}
// Reset clears a recorded bad-conn error: the parked bad connection is
// removed from the underlying pool and the state machine is rearmed so
// the next Get acquires a fresh connection.
func (p *StickyConnPool) Reset(ctx context.Context) error {
	if p.badConnError() == nil {
		return nil
	}
	select {
	case cn, ok := <-p.ch:
		if !ok {
			return ErrClosed
		}
		p.pool.Remove(ctx, cn, ErrClosed)
		// Clear the recorded error by storing an empty wrapper.
		p._badConnError.Store(BadConnError{wrapped: nil})
	default:
		return errors.New("redis: StickyConnPool does not have a Conn")
	}
	if !atomic.CompareAndSwapUint32(&p.state, stateInited, stateDefault) {
		state := atomic.LoadUint32(&p.state)
		return fmt.Errorf("redis: invalid StickyConnPool state: %d", state)
	}
	return nil
}
func (p *StickyConnPool) badConnError() error { | |||
if v := p._badConnError.Load(); v != nil { | |||
if err := v.(BadConnError); err.wrapped != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
} | |||
// Len reports 1 while a connection is pinned (stateInited), 0 otherwise.
func (p *StickyConnPool) Len() int {
	switch atomic.LoadUint32(&p.state) {
	case stateDefault:
		return 0
	case stateInited:
		return 1
	case stateClosed:
		return 0
	default:
		panic("not reached")
	}
}
// IdleLen reports whether the pinned conn is currently parked (0 or 1).
func (p *StickyConnPool) IdleLen() int {
	return len(p.ch)
}
// Stats returns empty statistics.
func (p *StickyConnPool) Stats() *Stats {
	return &Stats{}
}
@@ -0,0 +1,332 @@ | |||
package proto | |||
import ( | |||
"bufio" | |||
"fmt" | |||
"io" | |||
"github.com/go-redis/redis/v8/internal/util" | |||
) | |||
// redis resp protocol data type.
const (
	ErrorReply  = '-'
	StatusReply = '+'
	IntReply    = ':'
	StringReply = '$'
	ArrayReply  = '*'
)
//------------------------------------------------------------------------------
// Nil is the sentinel error returned for RESP nil replies.
const Nil = RedisError("redis: nil") // nolint:errname
// RedisError represents an error reply sent by the Redis server itself.
type RedisError string
func (e RedisError) Error() string { return string(e) }
// RedisError is a marker method distinguishing server errors — TODO confirm how callers use it.
func (RedisError) RedisError() {}
//------------------------------------------------------------------------------
// MultiBulkParse parses the n elements of an array reply from the reader.
type MultiBulkParse func(*Reader, int64) (interface{}, error)
// Reader reads RESP replies from a buffered stream.
type Reader struct {
	rd   *bufio.Reader
	_buf []byte // reusable scratch buffer — presumably used by fixed-size reads elsewhere in this file
}
// NewReader returns a RESP reader wrapping rd with a 64-byte scratch buffer.
func NewReader(rd io.Reader) *Reader {
	return &Reader{
		rd:   bufio.NewReader(rd),
		_buf: make([]byte, 64),
	}
}
// Buffered reports the number of bytes buffered but not yet consumed.
func (r *Reader) Buffered() int {
	return r.rd.Buffered()
}
// Peek returns the next n bytes without consuming them.
func (r *Reader) Peek(n int) ([]byte, error) {
	return r.rd.Peek(n)
}
// Reset discards buffered state and switches to reading from rd.
func (r *Reader) Reset(rd io.Reader) {
	r.rd.Reset(rd)
}
func (r *Reader) ReadLine() ([]byte, error) { | |||
line, err := r.readLine() | |||
if err != nil { | |||
return nil, err | |||
} | |||
if isNilReply(line) { | |||
return nil, Nil | |||
} | |||
return line, nil | |||
} | |||
// readLine that returns an error if:
// - there is a pending read error;
// - or line does not end with \r\n.
func (r *Reader) readLine() ([]byte, error) {
	b, err := r.rd.ReadSlice('\n')
	if err != nil {
		if err != bufio.ErrBufferFull {
			return nil, err
		}
		// The line is longer than bufio's buffer: keep the partial data
		// (ReadSlice's result is invalidated by the next read) and read
		// the remainder with ReadBytes, which allocates as needed.
		full := make([]byte, len(b))
		copy(full, b)
		b, err = r.rd.ReadBytes('\n')
		if err != nil {
			return nil, err
		}
		full = append(full, b...) //nolint:makezero
		b = full
	}
	// A valid line carries at least one payload byte plus "\r\n";
	// anything else (including a bare "\r\n") is a protocol violation.
	if len(b) <= 2 || b[len(b)-1] != '\n' || b[len(b)-2] != '\r' {
		return nil, fmt.Errorf("redis: invalid reply: %q", b)
	}
	// Strip the trailing CRLF.
	return b[:len(b)-2], nil
}
func (r *Reader) ReadReply(m MultiBulkParse) (interface{}, error) { | |||
line, err := r.ReadLine() | |||
if err != nil { | |||
return nil, err | |||
} | |||
switch line[0] { | |||
case ErrorReply: | |||
return nil, ParseErrorReply(line) | |||
case StatusReply: | |||
return string(line[1:]), nil | |||
case IntReply: | |||
return util.ParseInt(line[1:], 10, 64) | |||
case StringReply: | |||
return r.readStringReply(line) | |||
case ArrayReply: | |||
n, err := parseArrayLen(line) | |||
if err != nil { | |||
return nil, err | |||
} | |||
if m == nil { | |||
err := fmt.Errorf("redis: got %.100q, but multi bulk parser is nil", line) | |||
return nil, err | |||
} | |||
return m(r, n) | |||
} | |||
return nil, fmt.Errorf("redis: can't parse %.100q", line) | |||
} | |||
func (r *Reader) ReadIntReply() (int64, error) { | |||
line, err := r.ReadLine() | |||
if err != nil { | |||
return 0, err | |||
} | |||
switch line[0] { | |||
case ErrorReply: | |||
return 0, ParseErrorReply(line) | |||
case IntReply: | |||
return util.ParseInt(line[1:], 10, 64) | |||
default: | |||
return 0, fmt.Errorf("redis: can't parse int reply: %.100q", line) | |||
} | |||
} | |||
func (r *Reader) ReadString() (string, error) { | |||
line, err := r.ReadLine() | |||
if err != nil { | |||
return "", err | |||
} | |||
switch line[0] { | |||
case ErrorReply: | |||
return "", ParseErrorReply(line) | |||
case StringReply: | |||
return r.readStringReply(line) | |||
case StatusReply: | |||
return string(line[1:]), nil | |||
case IntReply: | |||
return string(line[1:]), nil | |||
default: | |||
return "", fmt.Errorf("redis: can't parse reply=%.100q reading string", line) | |||
} | |||
} | |||
func (r *Reader) readStringReply(line []byte) (string, error) { | |||
if isNilReply(line) { | |||
return "", Nil | |||
} | |||
replyLen, err := util.Atoi(line[1:]) | |||
if err != nil { | |||
return "", err | |||
} | |||
b := make([]byte, replyLen+2) | |||
_, err = io.ReadFull(r.rd, b) | |||
if err != nil { | |||
return "", err | |||
} | |||
return util.BytesToString(b[:replyLen]), nil | |||
} | |||
func (r *Reader) ReadArrayReply(m MultiBulkParse) (interface{}, error) { | |||
line, err := r.ReadLine() | |||
if err != nil { | |||
return nil, err | |||
} | |||
switch line[0] { | |||
case ErrorReply: | |||
return nil, ParseErrorReply(line) | |||
case ArrayReply: | |||
n, err := parseArrayLen(line) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return m(r, n) | |||
default: | |||
return nil, fmt.Errorf("redis: can't parse array reply: %.100q", line) | |||
} | |||
} | |||
func (r *Reader) ReadArrayLen() (int, error) { | |||
line, err := r.ReadLine() | |||
if err != nil { | |||
return 0, err | |||
} | |||
switch line[0] { | |||
case ErrorReply: | |||
return 0, ParseErrorReply(line) | |||
case ArrayReply: | |||
n, err := parseArrayLen(line) | |||
if err != nil { | |||
return 0, err | |||
} | |||
return int(n), nil | |||
default: | |||
return 0, fmt.Errorf("redis: can't parse array reply: %.100q", line) | |||
} | |||
} | |||
func (r *Reader) ReadScanReply() ([]string, uint64, error) { | |||
n, err := r.ReadArrayLen() | |||
if err != nil { | |||
return nil, 0, err | |||
} | |||
if n != 2 { | |||
return nil, 0, fmt.Errorf("redis: got %d elements in scan reply, expected 2", n) | |||
} | |||
cursor, err := r.ReadUint() | |||
if err != nil { | |||
return nil, 0, err | |||
} | |||
n, err = r.ReadArrayLen() | |||
if err != nil { | |||
return nil, 0, err | |||
} | |||
keys := make([]string, n) | |||
for i := 0; i < n; i++ { | |||
key, err := r.ReadString() | |||
if err != nil { | |||
return nil, 0, err | |||
} | |||
keys[i] = key | |||
} | |||
return keys, cursor, err | |||
} | |||
// ReadInt reads a bulk-string or status reply and parses it as a signed
// 64-bit integer.
func (r *Reader) ReadInt() (int64, error) {
	b, err := r.readTmpBytesReply()
	if err != nil {
		return 0, err
	}
	return util.ParseInt(b, 10, 64)
}

// ReadUint reads a bulk-string or status reply and parses it as an
// unsigned 64-bit integer (used e.g. for SCAN cursors).
func (r *Reader) ReadUint() (uint64, error) {
	b, err := r.readTmpBytesReply()
	if err != nil {
		return 0, err
	}
	return util.ParseUint(b, 10, 64)
}

// ReadFloatReply reads a bulk-string or status reply and parses it as a
// float64.
func (r *Reader) ReadFloatReply() (float64, error) {
	b, err := r.readTmpBytesReply()
	if err != nil {
		return 0, err
	}
	return util.ParseFloat(b, 64)
}
func (r *Reader) readTmpBytesReply() ([]byte, error) { | |||
line, err := r.ReadLine() | |||
if err != nil { | |||
return nil, err | |||
} | |||
switch line[0] { | |||
case ErrorReply: | |||
return nil, ParseErrorReply(line) | |||
case StringReply: | |||
return r._readTmpBytesReply(line) | |||
case StatusReply: | |||
return line[1:], nil | |||
default: | |||
return nil, fmt.Errorf("redis: can't parse string reply: %.100q", line) | |||
} | |||
} | |||
func (r *Reader) _readTmpBytesReply(line []byte) ([]byte, error) { | |||
if isNilReply(line) { | |||
return nil, Nil | |||
} | |||
replyLen, err := util.Atoi(line[1:]) | |||
if err != nil { | |||
return nil, err | |||
} | |||
buf := r.buf(replyLen + 2) | |||
_, err = io.ReadFull(r.rd, buf) | |||
if err != nil { | |||
return nil, err | |||
} | |||
return buf[:replyLen], nil | |||
} | |||
func (r *Reader) buf(n int) []byte { | |||
if n <= cap(r._buf) { | |||
return r._buf[:n] | |||
} | |||
d := n - cap(r._buf) | |||
r._buf = append(r._buf, make([]byte, d)...) | |||
return r._buf | |||
} | |||
func isNilReply(b []byte) bool { | |||
return len(b) == 3 && | |||
(b[0] == StringReply || b[0] == ArrayReply) && | |||
b[1] == '-' && b[2] == '1' | |||
} | |||
func ParseErrorReply(line []byte) error { | |||
return RedisError(string(line[1:])) | |||
} | |||
// parseArrayLen extracts the element count from an array header line
// ("*<n>"), mapping the nil array "*-1" to the Nil sentinel.
func parseArrayLen(line []byte) (int64, error) {
	if isNilReply(line) {
		return 0, Nil
	}
	return util.ParseInt(line[1:], 10, 64)
}
@@ -0,0 +1,172 @@ | |||
package proto | |||
import ( | |||
"encoding" | |||
"fmt" | |||
"reflect" | |||
"time" | |||
"github.com/go-redis/redis/v8/internal/util" | |||
) | |||
// Scan parses bytes `b` to `v` with appropriate type. | |||
func Scan(b []byte, v interface{}) error { | |||
switch v := v.(type) { | |||
case nil: | |||
return fmt.Errorf("redis: Scan(nil)") | |||
case *string: | |||
*v = util.BytesToString(b) | |||
return nil | |||
case *[]byte: | |||
*v = b | |||
return nil | |||
case *int: | |||
var err error | |||
*v, err = util.Atoi(b) | |||
return err | |||
case *int8: | |||
n, err := util.ParseInt(b, 10, 8) | |||
if err != nil { | |||
return err | |||
} | |||
*v = int8(n) | |||
return nil | |||
case *int16: | |||
n, err := util.ParseInt(b, 10, 16) | |||
if err != nil { | |||
return err | |||
} | |||
*v = int16(n) | |||
return nil | |||
case *int32: | |||
n, err := util.ParseInt(b, 10, 32) | |||
if err != nil { | |||
return err | |||
} | |||
*v = int32(n) | |||
return nil | |||
case *int64: | |||
n, err := util.ParseInt(b, 10, 64) | |||
if err != nil { | |||
return err | |||
} | |||
*v = n | |||
return nil | |||
case *uint: | |||
n, err := util.ParseUint(b, 10, 64) | |||
if err != nil { | |||
return err | |||
} | |||
*v = uint(n) | |||
return nil | |||
case *uint8: | |||
n, err := util.ParseUint(b, 10, 8) | |||
if err != nil { | |||
return err | |||
} | |||
*v = uint8(n) | |||
return nil | |||
case *uint16: | |||
n, err := util.ParseUint(b, 10, 16) | |||
if err != nil { | |||
return err | |||
} | |||
*v = uint16(n) | |||
return nil | |||
case *uint32: | |||
n, err := util.ParseUint(b, 10, 32) | |||
if err != nil { | |||
return err | |||
} | |||
*v = uint32(n) | |||
return nil | |||
case *uint64: | |||
n, err := util.ParseUint(b, 10, 64) | |||
if err != nil { | |||
return err | |||
} | |||
*v = n | |||
return nil | |||
case *float32: | |||
n, err := util.ParseFloat(b, 32) | |||
if err != nil { | |||
return err | |||
} | |||
*v = float32(n) | |||
return err | |||
case *float64: | |||
var err error | |||
*v, err = util.ParseFloat(b, 64) | |||
return err | |||
case *bool: | |||
*v = len(b) == 1 && b[0] == '1' | |||
return nil | |||
case *time.Time: | |||
var err error | |||
*v, err = time.Parse(time.RFC3339Nano, util.BytesToString(b)) | |||
return err | |||
case encoding.BinaryUnmarshaler: | |||
return v.UnmarshalBinary(b) | |||
default: | |||
return fmt.Errorf( | |||
"redis: can't unmarshal %T (consider implementing BinaryUnmarshaler)", v) | |||
} | |||
} | |||
func ScanSlice(data []string, slice interface{}) error { | |||
v := reflect.ValueOf(slice) | |||
if !v.IsValid() { | |||
return fmt.Errorf("redis: ScanSlice(nil)") | |||
} | |||
if v.Kind() != reflect.Ptr { | |||
return fmt.Errorf("redis: ScanSlice(non-pointer %T)", slice) | |||
} | |||
v = v.Elem() | |||
if v.Kind() != reflect.Slice { | |||
return fmt.Errorf("redis: ScanSlice(non-slice %T)", slice) | |||
} | |||
next := makeSliceNextElemFunc(v) | |||
for i, s := range data { | |||
elem := next() | |||
if err := Scan([]byte(s), elem.Addr().Interface()); err != nil { | |||
err = fmt.Errorf("redis: ScanSlice index=%d value=%q failed: %w", i, s, err) | |||
return err | |||
} | |||
} | |||
return nil | |||
} | |||
// makeSliceNextElemFunc returns a closure that yields the next writable
// element of the slice v on each call, growing v as needed.
//
// For a slice of pointers ([]*T) the closure returns the pointed-to
// value (allocating a fresh T when a reused slot holds nil), so callers
// can always scan into elem.Addr().
func makeSliceNextElemFunc(v reflect.Value) func() reflect.Value {
	elemType := v.Type().Elem()
	if elemType.Kind() == reflect.Ptr {
		elemType = elemType.Elem()
		return func() reflect.Value {
			// Reuse a slot hiding behind len < cap before appending.
			if v.Len() < v.Cap() {
				v.Set(v.Slice(0, v.Len()+1))
				elem := v.Index(v.Len() - 1)
				if elem.IsNil() {
					elem.Set(reflect.New(elemType))
				}
				return elem.Elem()
			}
			elem := reflect.New(elemType)
			v.Set(reflect.Append(v, elem))
			return elem.Elem()
		}
	}
	zero := reflect.Zero(elemType)
	return func() reflect.Value {
		// Same reuse-then-append strategy for value element types.
		if v.Len() < v.Cap() {
			v.Set(v.Slice(0, v.Len()+1))
			return v.Index(v.Len() - 1)
		}
		v.Set(reflect.Append(v, zero))
		return v.Index(v.Len() - 1)
	}
}
@@ -0,0 +1,153 @@ | |||
package proto | |||
import ( | |||
"encoding" | |||
"fmt" | |||
"io" | |||
"strconv" | |||
"time" | |||
"github.com/go-redis/redis/v8/internal/util" | |||
) | |||
// writer is the subset of buffered-writer behavior the protocol writer
// needs; satisfied by *bufio.Writer.
type writer interface {
	io.Writer
	io.ByteWriter
	// io.StringWriter
	WriteString(s string) (n int, err error)
}

// Writer serializes commands into the RESP wire format on top of the
// embedded writer.
type Writer struct {
	writer
	lenBuf []byte // scratch for length headers
	numBuf []byte // scratch for numeric and time arguments
}

// NewWriter wraps wr in a RESP protocol writer with 64-byte scratch buffers.
func NewWriter(wr writer) *Writer {
	return &Writer{
		writer: wr,
		lenBuf: make([]byte, 64),
		numBuf: make([]byte, 64),
	}
}
func (w *Writer) WriteArgs(args []interface{}) error { | |||
if err := w.WriteByte(ArrayReply); err != nil { | |||
return err | |||
} | |||
if err := w.writeLen(len(args)); err != nil { | |||
return err | |||
} | |||
for _, arg := range args { | |||
if err := w.WriteArg(arg); err != nil { | |||
return err | |||
} | |||
} | |||
return nil | |||
} | |||
// writeLen emits the decimal length n followed by CRLF, reusing lenBuf
// to avoid an allocation per header.
func (w *Writer) writeLen(n int) error {
	w.lenBuf = strconv.AppendUint(w.lenBuf[:0], uint64(n), 10)
	w.lenBuf = append(w.lenBuf, '\r', '\n')
	_, err := w.Write(w.lenBuf)
	return err
}
// WriteArg serializes a single command argument as a RESP bulk string.
// Supported types: nil (written as an empty string), string, []byte,
// all fixed-size ints/uints, floats, bool (1/0), time.Time (RFC 3339
// with nanoseconds), and encoding.BinaryMarshaler.
func (w *Writer) WriteArg(v interface{}) error {
	switch v := v.(type) {
	case nil:
		return w.string("")
	case string:
		return w.string(v)
	case []byte:
		return w.bytes(v)
	case int:
		return w.int(int64(v))
	case int8:
		return w.int(int64(v))
	case int16:
		return w.int(int64(v))
	case int32:
		return w.int(int64(v))
	case int64:
		return w.int(v)
	case uint:
		return w.uint(uint64(v))
	case uint8:
		return w.uint(uint64(v))
	case uint16:
		return w.uint(uint64(v))
	case uint32:
		return w.uint(uint64(v))
	case uint64:
		return w.uint(v)
	case float32:
		// NOTE: widened to float64, so it is formatted with 64-bit
		// precision (e.g. float32(0.1) becomes "0.10000000149011612").
		return w.float(float64(v))
	case float64:
		return w.float(v)
	case bool:
		// Booleans are written as "1"/"0".
		if v {
			return w.int(1)
		}
		return w.int(0)
	case time.Time:
		// Must precede encoding.BinaryMarshaler: time.Time implements
		// MarshalBinary, but the wire format here is RFC 3339.
		w.numBuf = v.AppendFormat(w.numBuf[:0], time.RFC3339Nano)
		return w.bytes(w.numBuf)
	case encoding.BinaryMarshaler:
		b, err := v.MarshalBinary()
		if err != nil {
			return err
		}
		return w.bytes(b)
	default:
		return fmt.Errorf(
			"redis: can't marshal %T (implement encoding.BinaryMarshaler)", v)
	}
}
func (w *Writer) bytes(b []byte) error { | |||
if err := w.WriteByte(StringReply); err != nil { | |||
return err | |||
} | |||
if err := w.writeLen(len(b)); err != nil { | |||
return err | |||
} | |||
if _, err := w.Write(b); err != nil { | |||
return err | |||
} | |||
return w.crlf() | |||
} | |||
// string writes s as a bulk string without copying its bytes.
func (w *Writer) string(s string) error {
	return w.bytes(util.StringToBytes(s))
}

// uint writes n in decimal as a bulk string, reusing numBuf.
func (w *Writer) uint(n uint64) error {
	w.numBuf = strconv.AppendUint(w.numBuf[:0], n, 10)
	return w.bytes(w.numBuf)
}

// int writes n in decimal as a bulk string, reusing numBuf.
func (w *Writer) int(n int64) error {
	w.numBuf = strconv.AppendInt(w.numBuf[:0], n, 10)
	return w.bytes(w.numBuf)
}

// float writes f as a bulk string using the shortest 'f'-format 64-bit
// representation.
func (w *Writer) float(f float64) error {
	w.numBuf = strconv.AppendFloat(w.numBuf[:0], f, 'f', -1, 64)
	return w.bytes(w.numBuf)
}
func (w *Writer) crlf() error { | |||
if err := w.WriteByte('\r'); err != nil { | |||
return err | |||
} | |||
return w.WriteByte('\n') | |||
} |
@@ -0,0 +1,50 @@ | |||
package rand | |||
import ( | |||
"math/rand" | |||
"sync" | |||
) | |||
// Int returns a non-negative pseudo-random int.
func Int() int { return pseudo.Int() }

// Intn returns, as an int, a non-negative pseudo-random number in [0,n).
// It panics if n <= 0.
func Intn(n int) int { return pseudo.Intn(n) }

// Int63n returns, as an int64, a non-negative pseudo-random number in [0,n).
// It panics if n <= 0.
func Int63n(n int64) int64 { return pseudo.Int63n(n) }

// Perm returns, as a slice of n ints, a pseudo-random permutation of the integers [0,n).
func Perm(n int) []int { return pseudo.Perm(n) }

// Seed uses the provided seed value to initialize the default Source to a
// deterministic state. If Seed is not called, the generator behaves as if
// seeded by Seed(1).
func Seed(n int64) { pseudo.Seed(n) }

// pseudo is the package-level generator. Its underlying source carries
// a mutex, making the shared generator safe for concurrent use (a bare
// rand.Rand is not).
var pseudo = rand.New(&source{src: rand.NewSource(1)})
type source struct { | |||
src rand.Source | |||
mu sync.Mutex | |||
} | |||
func (s *source) Int63() int64 { | |||
s.mu.Lock() | |||
n := s.src.Int63() | |||
s.mu.Unlock() | |||
return n | |||
} | |||
func (s *source) Seed(seed int64) { | |||
s.mu.Lock() | |||
s.src.Seed(seed) | |||
s.mu.Unlock() | |||
} | |||
// Shuffle pseudo-randomizes the order of elements.
// n is the number of elements.
// swap swaps the elements with indexes i and j.
// It panics if n < 0.
func Shuffle(n int, swap func(i, j int)) { pseudo.Shuffle(n, swap) }
@@ -0,0 +1,12 @@ | |||
//go:build appengine | |||
// +build appengine | |||
package internal | |||
// String converts a byte slice to a string (copying; appengine-safe build).
func String(b []byte) string {
	return string(b)
}

// Bytes converts a string to a byte slice (copying; appengine-safe build).
func Bytes(s string) []byte {
	return []byte(s)
}
@@ -0,0 +1,21 @@ | |||
//go:build !appengine | |||
// +build !appengine | |||
package internal | |||
import "unsafe" | |||
// String converts byte slice to string.
// The returned string shares b's backing memory: the caller must not
// mutate b afterwards.
func String(b []byte) string {
	return *(*string)(unsafe.Pointer(&b))
}

// Bytes converts string to byte slice.
// The returned slice aliases the string's (immutable) data; writing to
// it is undefined behavior.
func Bytes(s string) []byte {
	return *(*[]byte)(unsafe.Pointer(
		&struct {
			string
			Cap int
		}{s, len(s)},
	))
}
@@ -0,0 +1,46 @@ | |||
package internal | |||
import ( | |||
"context" | |||
"time" | |||
"github.com/go-redis/redis/v8/internal/util" | |||
) | |||
func Sleep(ctx context.Context, dur time.Duration) error { | |||
t := time.NewTimer(dur) | |||
defer t.Stop() | |||
select { | |||
case <-t.C: | |||
return nil | |||
case <-ctx.Done(): | |||
return ctx.Err() | |||
} | |||
} | |||
func ToLower(s string) string { | |||
if isLower(s) { | |||
return s | |||
} | |||
b := make([]byte, len(s)) | |||
for i := range b { | |||
c := s[i] | |||
if c >= 'A' && c <= 'Z' { | |||
c += 'a' - 'A' | |||
} | |||
b[i] = c | |||
} | |||
return util.BytesToString(b) | |||
} | |||
// isLower reports whether s contains no ASCII uppercase letters.
func isLower(s string) bool {
	for _, c := range []byte(s) {
		if c >= 'A' && c <= 'Z' {
			return false
		}
	}
	return true
}
@@ -0,0 +1,12 @@ | |||
//go:build appengine | |||
// +build appengine | |||
package util | |||
// BytesToString converts a byte slice to a string (copying; appengine-safe build).
func BytesToString(b []byte) string {
	return string(b)
}

// StringToBytes converts a string to a byte slice (copying; appengine-safe build).
func StringToBytes(s string) []byte {
	return []byte(s)
}
@@ -0,0 +1,19 @@ | |||
package util | |||
import "strconv" | |||
// Atoi is strconv.Atoi for a byte slice, going through BytesToString
// (zero-copy on non-appengine builds).
func Atoi(b []byte) (int, error) {
	return strconv.Atoi(BytesToString(b))
}

// ParseInt is strconv.ParseInt for a byte slice.
func ParseInt(b []byte, base int, bitSize int) (int64, error) {
	return strconv.ParseInt(BytesToString(b), base, bitSize)
}

// ParseUint is strconv.ParseUint for a byte slice.
func ParseUint(b []byte, base int, bitSize int) (uint64, error) {
	return strconv.ParseUint(BytesToString(b), base, bitSize)
}

// ParseFloat is strconv.ParseFloat for a byte slice.
func ParseFloat(b []byte, bitSize int) (float64, error) {
	return strconv.ParseFloat(BytesToString(b), bitSize)
}
@@ -0,0 +1,23 @@ | |||
//go:build !appengine | |||
// +build !appengine | |||
package util | |||
import ( | |||
"unsafe" | |||
) | |||
// BytesToString converts byte slice to string.
// The returned string shares b's backing memory: the caller must not
// mutate b afterwards.
func BytesToString(b []byte) string {
	return *(*string)(unsafe.Pointer(&b))
}

// StringToBytes converts string to byte slice.
// The returned slice aliases the string's (immutable) data; writing to
// it is undefined behavior.
func StringToBytes(s string) []byte {
	return *(*[]byte)(unsafe.Pointer(
		&struct {
			string
			Cap int
		}{s, len(s)},
	))
}
@@ -0,0 +1,77 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"sync" | |||
) | |||
// ScanIterator is used to incrementally iterate over a collection of elements.
// It's safe for concurrent use by multiple goroutines.
type ScanIterator struct {
	mu  sync.Mutex // protects cmd and pos (the comment previously named a non-existent "Scanner")
	cmd *ScanCmd   // last executed SCAN command; carries page, cursor, and args
	pos int        // 1-based position within cmd.page; 0 before the first Next
}
// Err returns the last iterator error, if any. | |||
func (it *ScanIterator) Err() error { | |||
it.mu.Lock() | |||
err := it.cmd.Err() | |||
it.mu.Unlock() | |||
return err | |||
} | |||
// Next advances the cursor and returns true if more values can be read.
// It fetches further pages from the server as the current one is
// exhausted, stopping when the server reports cursor 0 or an error occurs.
func (it *ScanIterator) Next(ctx context.Context) bool {
	it.mu.Lock()
	defer it.mu.Unlock()
	// Instantly return on errors.
	if it.cmd.Err() != nil {
		return false
	}
	// Advance cursor, check if we are still within range.
	if it.pos < len(it.cmd.page) {
		it.pos++
		return true
	}
	for {
		// Return if there is no more data to fetch.
		if it.cmd.cursor == 0 {
			return false
		}
		// Fetch next page. The cursor argument index depends on the
		// command: SCAN/QSCAN take it as args[1], the other SCAN-family
		// commands as args[2] — presumably because a key precedes it;
		// confirm against the command builders.
		switch it.cmd.args[0] {
		case "scan", "qscan":
			it.cmd.args[1] = it.cmd.cursor
		default:
			it.cmd.args[2] = it.cmd.cursor
		}
		err := it.cmd.process(ctx, it.cmd)
		if err != nil {
			return false
		}
		it.pos = 1
		// Redis can occasionally return empty page.
		if len(it.cmd.page) > 0 {
			return true
		}
	}
}
// Val returns the key/field at the current cursor position. | |||
func (it *ScanIterator) Val() string { | |||
var v string | |||
it.mu.Lock() | |||
if it.cmd.Err() == nil && it.pos > 0 && it.pos <= len(it.cmd.page) { | |||
v = it.cmd.page[it.pos-1] | |||
} | |||
it.mu.Unlock() | |||
return v | |||
} |
@@ -0,0 +1,429 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"crypto/tls" | |||
"errors" | |||
"fmt" | |||
"net" | |||
"net/url" | |||
"runtime" | |||
"sort" | |||
"strconv" | |||
"strings" | |||
"time" | |||
"github.com/go-redis/redis/v8/internal/pool" | |||
) | |||
// Limiter is the interface of a rate limiter or a circuit breaker.
// NOTE(review): implementations are presumably invoked concurrently by
// the client — confirm and document thread-safety requirements.
type Limiter interface {
	// Allow returns nil if operation is allowed or an error otherwise.
	// If operation is allowed client must ReportResult of the operation
	// whether it is a success or a failure.
	Allow() error
	// ReportResult reports the result of the previously allowed operation.
	// nil indicates a success, non-nil error usually indicates a failure.
	ReportResult(result error)
}
// Options keeps the settings to setup redis connection.
//
// The zero value is usable: init() fills in a default for every field
// left unset, and -1 acts as the "disabled" sentinel for the numeric
// fields documented below.
type Options struct {
	// The network type, either tcp or unix.
	// Default is tcp.
	Network string
	// host:port address.
	Addr string
	// Dialer creates new network connection and has priority over
	// Network and Addr options.
	Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
	// Hook that is called when new connection is established.
	OnConnect func(ctx context.Context, cn *Conn) error
	// Use the specified Username to authenticate the current connection
	// with one of the connections defined in the ACL list when connecting
	// to a Redis 6.0 instance, or greater, that is using the Redis ACL system.
	Username string
	// Optional password. Must match the password specified in the
	// requirepass server configuration option (if connecting to a Redis 5.0 instance, or lower),
	// or the User Password when connecting to a Redis 6.0 instance, or greater,
	// that is using the Redis ACL system.
	Password string
	// Database to be selected after connecting to the server.
	DB int
	// Maximum number of retries before giving up.
	// Default is 3 retries; -1 (not 0) disables retries.
	MaxRetries int
	// Minimum backoff between each retry.
	// Default is 8 milliseconds; -1 disables backoff.
	MinRetryBackoff time.Duration
	// Maximum backoff between each retry.
	// Default is 512 milliseconds; -1 disables backoff.
	MaxRetryBackoff time.Duration
	// Dial timeout for establishing new connections.
	// Default is 5 seconds.
	DialTimeout time.Duration
	// Timeout for socket reads. If reached, commands will fail
	// with a timeout instead of blocking. Use value -1 for no timeout and 0 for default.
	// Default is 3 seconds.
	ReadTimeout time.Duration
	// Timeout for socket writes. If reached, commands will fail
	// with a timeout instead of blocking.
	// Default is ReadTimeout.
	WriteTimeout time.Duration
	// Type of connection pool.
	// true for FIFO pool, false for LIFO pool.
	// Note that fifo has higher overhead compared to lifo.
	PoolFIFO bool
	// Maximum number of socket connections.
	// Default is 10 connections per every available CPU as reported by runtime.GOMAXPROCS.
	PoolSize int
	// Minimum number of idle connections which is useful when establishing
	// new connection is slow.
	MinIdleConns int
	// Connection age at which client retires (closes) the connection.
	// Default is to not close aged connections.
	MaxConnAge time.Duration
	// Amount of time client waits for connection if all connections
	// are busy before returning an error.
	// Default is ReadTimeout + 1 second.
	PoolTimeout time.Duration
	// Amount of time after which client closes idle connections.
	// Should be less than server's timeout.
	// Default is 5 minutes. -1 disables idle timeout check.
	IdleTimeout time.Duration
	// Frequency of idle checks made by idle connections reaper.
	// Default is 1 minute. -1 disables idle connections reaper,
	// but idle connections are still discarded by the client
	// if IdleTimeout is set.
	IdleCheckFrequency time.Duration
	// Enables read only queries on slave nodes.
	// NOTE(review): unexported, so presumably set internally by
	// cluster/failover code — confirm against callers.
	readOnly bool
	// TLS Config to use. When set TLS will be negotiated.
	TLSConfig *tls.Config
	// Limiter interface used to implemented circuit breaker or rate limiter.
	Limiter Limiter
}
// init fills in a default for every Options field that was left at its
// zero value, and collapses the -1 "disabled" sentinels to 0 so the
// rest of the client can treat 0 as "no limit".
func (opt *Options) init() {
	if opt.Addr == "" {
		opt.Addr = "localhost:6379"
	}
	if opt.Network == "" {
		// A leading slash in the address means a unix socket path.
		if strings.HasPrefix(opt.Addr, "/") {
			opt.Network = "unix"
		} else {
			opt.Network = "tcp"
		}
	}
	if opt.DialTimeout == 0 {
		opt.DialTimeout = 5 * time.Second
	}
	if opt.Dialer == nil {
		// Default dialer honors DialTimeout and upgrades to TLS when
		// TLSConfig is set.
		opt.Dialer = func(ctx context.Context, network, addr string) (net.Conn, error) {
			netDialer := &net.Dialer{
				Timeout:   opt.DialTimeout,
				KeepAlive: 5 * time.Minute,
			}
			if opt.TLSConfig == nil {
				return netDialer.DialContext(ctx, network, addr)
			}
			return tls.DialWithDialer(netDialer, network, addr, opt.TLSConfig)
		}
	}
	if opt.PoolSize == 0 {
		opt.PoolSize = 10 * runtime.GOMAXPROCS(0)
	}
	// For the timeouts below, -1 means "no timeout" and 0 selects the default.
	switch opt.ReadTimeout {
	case -1:
		opt.ReadTimeout = 0
	case 0:
		opt.ReadTimeout = 3 * time.Second
	}
	switch opt.WriteTimeout {
	case -1:
		opt.WriteTimeout = 0
	case 0:
		opt.WriteTimeout = opt.ReadTimeout
	}
	// PoolTimeout depends on the normalized ReadTimeout above.
	if opt.PoolTimeout == 0 {
		opt.PoolTimeout = opt.ReadTimeout + time.Second
	}
	if opt.IdleTimeout == 0 {
		opt.IdleTimeout = 5 * time.Minute
	}
	if opt.IdleCheckFrequency == 0 {
		opt.IdleCheckFrequency = time.Minute
	}
	// -1 disables retries entirely; 0 selects the default of 3.
	if opt.MaxRetries == -1 {
		opt.MaxRetries = 0
	} else if opt.MaxRetries == 0 {
		opt.MaxRetries = 3
	}
	switch opt.MinRetryBackoff {
	case -1:
		opt.MinRetryBackoff = 0
	case 0:
		opt.MinRetryBackoff = 8 * time.Millisecond
	}
	switch opt.MaxRetryBackoff {
	case -1:
		opt.MaxRetryBackoff = 0
	case 0:
		opt.MaxRetryBackoff = 512 * time.Millisecond
	}
}
func (opt *Options) clone() *Options { | |||
clone := *opt | |||
return &clone | |||
} | |||
// ParseURL parses an URL into Options that can be used to connect to Redis.
// Scheme is required.
// There are two connection types: by tcp socket and by unix socket.
// Tcp connection:
//		redis://<user>:<password>@<host>:<port>/<db_number>
// Unix connection:
//		unix://<user>:<password>@</path/to/redis.sock>?db=<db_number>
// Most Option fields can be set using query parameters, with the following restrictions:
//	- field names are mapped using snake-case conversion: to set MaxRetries, use max_retries
//	- only scalar type fields are supported (bool, int, time.Duration)
//	- for time.Duration fields, values must be a valid input for time.ParseDuration();
//	  additionally a plain integer as value (i.e. without unit) is interpreted as seconds
//	- to disable a duration field, use value less than or equal to 0; to use the default
//	  value, leave the value blank or remove the parameter
//	- only the last value is interpreted if a parameter is given multiple times
//	- fields "network", "addr", "username" and "password" can only be set using other
//	  URL attributes (scheme, host, userinfo, resp.), query parameters using these
//	  names will be treated as unknown parameters
//	- unknown parameter names will result in an error
// Examples:
//		redis://user:password@localhost:6789/3?dial_timeout=3&db=1&read_timeout=6s&max_retries=2
//		is equivalent to:
//		&Options{
//			Network:     "tcp",
//			Addr:        "localhost:6789",
//			DB:          1,               // path "/3" was overridden by "&db=1"
//			DialTimeout: 3 * time.Second, // no time unit = seconds
//			ReadTimeout: 6 * time.Second,
//			MaxRetries:  2,
//		}
func ParseURL(redisURL string) (*Options, error) {
	u, err := url.Parse(redisURL)
	if err != nil {
		return nil, err
	}

	// "rediss" is "redis" over TLS.
	switch u.Scheme {
	case "redis", "rediss":
		return setupTCPConn(u)
	case "unix":
		return setupUnixConn(u)
	default:
		return nil, fmt.Errorf("redis: invalid URL scheme: %s", u.Scheme)
	}
}
func setupTCPConn(u *url.URL) (*Options, error) { | |||
o := &Options{Network: "tcp"} | |||
o.Username, o.Password = getUserPassword(u) | |||
h, p, err := net.SplitHostPort(u.Host) | |||
if err != nil { | |||
h = u.Host | |||
} | |||
if h == "" { | |||
h = "localhost" | |||
} | |||
if p == "" { | |||
p = "6379" | |||
} | |||
o.Addr = net.JoinHostPort(h, p) | |||
f := strings.FieldsFunc(u.Path, func(r rune) bool { | |||
return r == '/' | |||
}) | |||
switch len(f) { | |||
case 0: | |||
o.DB = 0 | |||
case 1: | |||
if o.DB, err = strconv.Atoi(f[0]); err != nil { | |||
return nil, fmt.Errorf("redis: invalid database number: %q", f[0]) | |||
} | |||
default: | |||
return nil, fmt.Errorf("redis: invalid URL path: %s", u.Path) | |||
} | |||
if u.Scheme == "rediss" { | |||
o.TLSConfig = &tls.Config{ServerName: h} | |||
} | |||
return setupConnParams(u, o) | |||
} | |||
func setupUnixConn(u *url.URL) (*Options, error) { | |||
o := &Options{ | |||
Network: "unix", | |||
} | |||
if strings.TrimSpace(u.Path) == "" { // path is required with unix connection | |||
return nil, errors.New("redis: empty unix socket path") | |||
} | |||
o.Addr = u.Path | |||
o.Username, o.Password = getUserPassword(u) | |||
return setupConnParams(u, o) | |||
} | |||
// queryOptions consumes URL query parameters one by one, recording the
// first conversion error and leaving unknown keys behind for remaining().
type queryOptions struct {
	q   url.Values
	err error // first parse error encountered; checked once at the end
}
func (o *queryOptions) string(name string) string { | |||
vs := o.q[name] | |||
if len(vs) == 0 { | |||
return "" | |||
} | |||
delete(o.q, name) // enable detection of unknown parameters | |||
return vs[len(vs)-1] | |||
} | |||
func (o *queryOptions) int(name string) int { | |||
s := o.string(name) | |||
if s == "" { | |||
return 0 | |||
} | |||
i, err := strconv.Atoi(s) | |||
if err == nil { | |||
return i | |||
} | |||
if o.err == nil { | |||
o.err = fmt.Errorf("redis: invalid %s number: %s", name, err) | |||
} | |||
return 0 | |||
} | |||
func (o *queryOptions) duration(name string) time.Duration { | |||
s := o.string(name) | |||
if s == "" { | |||
return 0 | |||
} | |||
// try plain number first | |||
if i, err := strconv.Atoi(s); err == nil { | |||
if i <= 0 { | |||
// disable timeouts | |||
return -1 | |||
} | |||
return time.Duration(i) * time.Second | |||
} | |||
dur, err := time.ParseDuration(s) | |||
if err == nil { | |||
return dur | |||
} | |||
if o.err == nil { | |||
o.err = fmt.Errorf("redis: invalid %s duration: %w", name, err) | |||
} | |||
return 0 | |||
} | |||
func (o *queryOptions) bool(name string) bool { | |||
switch s := o.string(name); s { | |||
case "true", "1": | |||
return true | |||
case "false", "0", "": | |||
return false | |||
default: | |||
if o.err == nil { | |||
o.err = fmt.Errorf("redis: invalid %s boolean: expected true/false/1/0 or an empty string, got %q", name, s) | |||
} | |||
return false | |||
} | |||
} | |||
func (o *queryOptions) remaining() []string { | |||
if len(o.q) == 0 { | |||
return nil | |||
} | |||
keys := make([]string, 0, len(o.q)) | |||
for k := range o.q { | |||
keys = append(keys, k) | |||
} | |||
sort.Strings(keys) | |||
return keys | |||
} | |||
// setupConnParams converts query parameters in u to option value in o.
// It fails on the first malformed value and on any unrecognized
// parameter name.
func setupConnParams(u *url.URL, o *Options) (*Options, error) {
	q := queryOptions{q: u.Query()}

	// compat: a future major release may use q.int("db")
	if tmp := q.string("db"); tmp != "" {
		db, err := strconv.Atoi(tmp)
		if err != nil {
			return nil, fmt.Errorf("redis: invalid database number: %w", err)
		}
		o.DB = db
	}

	o.MaxRetries = q.int("max_retries")
	o.MinRetryBackoff = q.duration("min_retry_backoff")
	o.MaxRetryBackoff = q.duration("max_retry_backoff")
	o.DialTimeout = q.duration("dial_timeout")
	o.ReadTimeout = q.duration("read_timeout")
	o.WriteTimeout = q.duration("write_timeout")
	o.PoolFIFO = q.bool("pool_fifo")
	o.PoolSize = q.int("pool_size")
	o.MinIdleConns = q.int("min_idle_conns")
	o.MaxConnAge = q.duration("max_conn_age")
	o.PoolTimeout = q.duration("pool_timeout")
	o.IdleTimeout = q.duration("idle_timeout")
	o.IdleCheckFrequency = q.duration("idle_check_frequency")
	if q.err != nil {
		return nil, q.err
	}

	// any parameters left?
	if r := q.remaining(); len(r) > 0 {
		return nil, fmt.Errorf("redis: unexpected option: %s", strings.Join(r, ", "))
	}

	return o, nil
}
func getUserPassword(u *url.URL) (string, string) { | |||
var user, password string | |||
if u.User != nil { | |||
user = u.User.Username() | |||
if p, ok := u.User.Password(); ok { | |||
password = p | |||
} | |||
} | |||
return user, password | |||
} | |||
// newConnPool builds the connection pool for opt, wiring the client's
// dialer and copying the pool-related settings through verbatim.
func newConnPool(opt *Options) *pool.ConnPool {
	return pool.NewConnPool(&pool.Options{
		// Each new pooled connection is dialed with the client's
		// configured network and address.
		Dialer: func(ctx context.Context) (net.Conn, error) {
			return opt.Dialer(ctx, opt.Network, opt.Addr)
		},
		PoolFIFO: opt.PoolFIFO,
		PoolSize: opt.PoolSize,
		MinIdleConns: opt.MinIdleConns,
		MaxConnAge: opt.MaxConnAge,
		PoolTimeout: opt.PoolTimeout,
		IdleTimeout: opt.IdleTimeout,
		IdleCheckFrequency: opt.IdleCheckFrequency,
	})
}
@@ -0,0 +1,8 @@ | |||
{ | |||
"name": "redis", | |||
"version": "8.11.4", | |||
"main": "index.js", | |||
"repository": "git@github.com:go-redis/redis.git", | |||
"author": "Vladimir Mihailenco <vladimir.webdev@gmail.com>", | |||
"license": "BSD-2-clause" | |||
} |
@@ -0,0 +1,137 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"sync" | |||
"github.com/go-redis/redis/v8/internal/pool" | |||
) | |||
// pipelineExecer sends a batch of queued commands and reads their replies.
type pipelineExecer func(context.Context, []Cmder) error

// Pipeliner is a mechanism to realise the Redis Pipeline technique.
//
// Pipelining is a technique to extremely speed up processing by packing
// operations into batches, sending them at once to Redis and reading
// the replies in a single step.
// See https://redis.io/topics/pipelining
//
// Pay attention, that Pipeline is not a transaction, so you can get unexpected
// results in case of big pipelines and small read/write timeouts.
// Redis client has retransmission logic in case of timeouts, so a pipeline
// can be retransmitted and commands can be executed more than once.
// To avoid this: it is a good idea to use reasonably bigger read/write
// timeouts depending on your batch size, and/or use TxPipeline.
type Pipeliner interface {
	StatefulCmdable
	Do(ctx context.Context, args ...interface{}) *Cmd
	Process(ctx context.Context, cmd Cmder) error
	Close() error
	Discard() error
	Exec(ctx context.Context) ([]Cmder, error)
}
// Compile-time check that *Pipeline satisfies the Pipeliner interface.
var _ Pipeliner = (*Pipeline)(nil)

// Pipeline implements pipelining as described in
// http://redis.io/topics/pipelining. It's safe for concurrent use
// by multiple goroutines.
type Pipeline struct {
	cmdable
	statefulCmdable
	ctx context.Context // NOTE(review): not referenced in this chunk — confirm use elsewhere
	exec pipelineExecer // sends the queued commands in one roundtrip
	mu sync.Mutex // guards cmds and closed
	cmds []Cmder
	closed bool
}

// init wires the generated command builders to Process, so every
// command constructed on the pipeline is queued instead of executed.
func (c *Pipeline) init() {
	c.cmdable = c.Process
	c.statefulCmdable = c.Process
}
func (c *Pipeline) Do(ctx context.Context, args ...interface{}) *Cmd { | |||
cmd := NewCmd(ctx, args...) | |||
_ = c.Process(ctx, cmd) | |||
return cmd | |||
} | |||
// Process queues the cmd for later execution. | |||
func (c *Pipeline) Process(ctx context.Context, cmd Cmder) error { | |||
c.mu.Lock() | |||
c.cmds = append(c.cmds, cmd) | |||
c.mu.Unlock() | |||
return nil | |||
} | |||
// Close closes the pipeline, releasing any open resources. | |||
func (c *Pipeline) Close() error { | |||
c.mu.Lock() | |||
_ = c.discard() | |||
c.closed = true | |||
c.mu.Unlock() | |||
return nil | |||
} | |||
// Discard resets the pipeline and discards queued commands. | |||
func (c *Pipeline) Discard() error { | |||
c.mu.Lock() | |||
err := c.discard() | |||
c.mu.Unlock() | |||
return err | |||
} | |||
func (c *Pipeline) discard() error { | |||
if c.closed { | |||
return pool.ErrClosed | |||
} | |||
c.cmds = c.cmds[:0] | |||
return nil | |||
} | |||
// Exec executes all previously queued commands using one | |||
// client-server roundtrip. | |||
// | |||
// Exec always returns list of commands and error of the first failed | |||
// command if any. | |||
func (c *Pipeline) Exec(ctx context.Context) ([]Cmder, error) { | |||
c.mu.Lock() | |||
defer c.mu.Unlock() | |||
if c.closed { | |||
return nil, pool.ErrClosed | |||
} | |||
if len(c.cmds) == 0 { | |||
return nil, nil | |||
} | |||
cmds := c.cmds | |||
c.cmds = nil | |||
return cmds, c.exec(ctx, cmds) | |||
} | |||
func (c *Pipeline) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||
if err := fn(c); err != nil { | |||
return nil, err | |||
} | |||
cmds, err := c.Exec(ctx) | |||
_ = c.Close() | |||
return cmds, err | |||
} | |||
func (c *Pipeline) Pipeline() Pipeliner { | |||
return c | |||
} | |||
func (c *Pipeline) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) { | |||
return c.Pipelined(ctx, fn) | |||
} | |||
func (c *Pipeline) TxPipeline() Pipeliner { | |||
return c | |||
} |
@@ -0,0 +1,668 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"fmt" | |||
"strings" | |||
"sync" | |||
"time" | |||
"github.com/go-redis/redis/v8/internal" | |||
"github.com/go-redis/redis/v8/internal/pool" | |||
"github.com/go-redis/redis/v8/internal/proto" | |||
) | |||
// PubSub implements Pub/Sub commands as described in
// http://redis.io/topics/pubsub. Message receiving is NOT safe
// for concurrent use by multiple goroutines.
//
// PubSub automatically reconnects to Redis Server and resubscribes
// to the channels in case of network errors.
type PubSub struct {
	opt *Options
	newConn func(ctx context.Context, channels []string) (*pool.Conn, error) // dials a dedicated pub/sub connection
	closeConn func(*pool.Conn) error
	mu sync.Mutex // guards cn, channels, patterns, closed
	cn *pool.Conn
	channels map[string]struct{} // channels to (re)subscribe to
	patterns map[string]struct{} // patterns to (re)psubscribe to
	closed bool
	exit chan struct{} // closed by Close to stop background goroutines
	cmd *Cmd // scratch command reused by ReceiveTimeout
	chOnce sync.Once // ensures only one of msgCh/allCh is created
	msgCh *channel
	allCh *channel
}

// init prepares the exit channel used to signal shutdown.
func (c *PubSub) init() {
	c.exit = make(chan struct{})
}
func (c *PubSub) String() string { | |||
channels := mapKeys(c.channels) | |||
channels = append(channels, mapKeys(c.patterns)...) | |||
return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", ")) | |||
} | |||
func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) { | |||
c.mu.Lock() | |||
cn, err := c.conn(ctx, nil) | |||
c.mu.Unlock() | |||
return cn, err | |||
} | |||
// conn returns the active connection, dialing a new one and replaying
// all registered subscriptions when needed. newChannels are passed to
// the dialer for connection setup. The caller must hold c.mu.
func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
	if c.closed {
		return nil, pool.ErrClosed
	}
	if c.cn != nil {
		return c.cn, nil
	}
	channels := mapKeys(c.channels)
	channels = append(channels, newChannels...)
	cn, err := c.newConn(ctx, channels)
	if err != nil {
		return nil, err
	}
	// A fresh connection must carry all existing subscriptions; if
	// replaying them fails, throw the connection away.
	if err := c.resubscribe(ctx, cn); err != nil {
		_ = c.closeConn(cn)
		return nil, err
	}
	c.cn = cn
	return cn, nil
}

// writeCmd writes cmd on cn using the configured write timeout.
func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
	return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
		return writeCmd(wr, cmd)
	})
}
func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error { | |||
var firstErr error | |||
if len(c.channels) > 0 { | |||
firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels)) | |||
} | |||
if len(c.patterns) > 0 { | |||
err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns)) | |||
if err != nil && firstErr == nil { | |||
firstErr = err | |||
} | |||
} | |||
return firstErr | |||
} | |||
// mapKeys returns the keys of m in unspecified order.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	return keys
}
func (c *PubSub) _subscribe( | |||
ctx context.Context, cn *pool.Conn, redisCmd string, channels []string, | |||
) error { | |||
args := make([]interface{}, 0, 1+len(channels)) | |||
args = append(args, redisCmd) | |||
for _, channel := range channels { | |||
args = append(args, channel) | |||
} | |||
cmd := NewSliceCmd(ctx, args...) | |||
return c.writeCmd(ctx, cn, cmd) | |||
} | |||
func (c *PubSub) releaseConnWithLock( | |||
ctx context.Context, | |||
cn *pool.Conn, | |||
err error, | |||
allowTimeout bool, | |||
) { | |||
c.mu.Lock() | |||
c.releaseConn(ctx, cn, err, allowTimeout) | |||
c.mu.Unlock() | |||
} | |||
// releaseConn reconnects when err marks cn unusable; it is a no-op
// when cn is no longer the active connection. Caller must hold c.mu.
func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
	if c.cn != cn {
		return
	}
	if isBadConn(err, allowTimeout, c.opt.Addr) {
		c.reconnect(ctx, err)
	}
}

// reconnect drops the current connection and immediately dials a new
// one, replaying subscriptions. Caller must hold c.mu.
func (c *PubSub) reconnect(ctx context.Context, reason error) {
	_ = c.closeTheCn(reason)
	_, _ = c.conn(ctx, nil) // best effort; the next operation retries
}

// closeTheCn closes the active connection, logging the reason unless
// the PubSub itself is shutting down. Caller must hold c.mu.
func (c *PubSub) closeTheCn(reason error) error {
	if c.cn == nil {
		return nil
	}
	if !c.closed {
		internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
	}
	err := c.closeConn(c.cn)
	c.cn = nil
	return err
}

// Close tears down the PubSub: it signals background goroutines via
// the exit channel and closes the connection. A second Close returns
// pool.ErrClosed.
func (c *PubSub) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return pool.ErrClosed
	}
	c.closed = true
	close(c.exit)
	return c.closeTheCn(pool.ErrClosed)
}
// Subscribe the client to the specified channels. It returns
// empty subscription if there are no channels.
func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	err := c.subscribe(ctx, "subscribe", channels...)
	// Channels are recorded even when the write failed, so that the
	// automatic reconnect can resubscribe to them later.
	if c.channels == nil {
		c.channels = make(map[string]struct{})
	}
	for _, s := range channels {
		c.channels[s] = struct{}{}
	}
	return err
}

// PSubscribe the client to the given patterns. It returns
// empty subscription if there are no patterns.
func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
	c.mu.Lock()
	defer c.mu.Unlock()
	err := c.subscribe(ctx, "psubscribe", patterns...)
	// Patterns are recorded even on error; see Subscribe.
	if c.patterns == nil {
		c.patterns = make(map[string]struct{})
	}
	for _, s := range patterns {
		c.patterns[s] = struct{}{}
	}
	return err
}
// Unsubscribe the client from the given channels, or from all of | |||
// them if none is given. | |||
func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error { | |||
c.mu.Lock() | |||
defer c.mu.Unlock() | |||
for _, channel := range channels { | |||
delete(c.channels, channel) | |||
} | |||
err := c.subscribe(ctx, "unsubscribe", channels...) | |||
return err | |||
} | |||
// PUnsubscribe the client from the given patterns, or from all of | |||
// them if none is given. | |||
func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error { | |||
c.mu.Lock() | |||
defer c.mu.Unlock() | |||
for _, pattern := range patterns { | |||
delete(c.patterns, pattern) | |||
} | |||
err := c.subscribe(ctx, "punsubscribe", patterns...) | |||
return err | |||
} | |||
func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error { | |||
cn, err := c.conn(ctx, channels) | |||
if err != nil { | |||
return err | |||
} | |||
err = c._subscribe(ctx, cn, redisCmd, channels) | |||
c.releaseConn(ctx, cn, err, false) | |||
return err | |||
} | |||
// Ping sends a PING, with at most one optional payload, on the pub/sub
// connection to verify that it is still alive.
func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
	args := []interface{}{"ping"}
	if len(payload) == 1 {
		args = append(args, payload[0])
	}
	cmd := NewCmd(ctx, args...)
	c.mu.Lock()
	defer c.mu.Unlock()
	cn, err := c.conn(ctx, nil)
	if err != nil {
		return err
	}
	// The reply ("pong") arrives through the normal receive path.
	err = c.writeCmd(ctx, cn, cmd)
	c.releaseConn(ctx, cn, err, false)
	return err
}
// Subscription received after a successful subscription to channel.
type Subscription struct {
	// Can be "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
	Kind string
	// Channel name we have subscribed to.
	Channel string
	// Number of channels we are currently subscribed to.
	Count int
}

// String renders the subscription as "<kind>: <channel>".
func (s *Subscription) String() string {
	return fmt.Sprintf("%s: %s", s.Kind, s.Channel)
}
// Message received as result of a PUBLISH command issued by another client.
type Message struct {
	Channel string
	Pattern string
	Payload string
	PayloadSlice []string
}

// String renders the message's channel and payload for debugging.
func (m *Message) String() string {
	return fmt.Sprintf("Message<%s: %s>", m.Channel, m.Payload)
}

// Pong received as result of a PING command issued by another client.
type Pong struct {
	Payload string
}

// String renders the pong, including its payload when present.
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return fmt.Sprintf("Pong<%s>", p.Payload)
}
// newMessage converts a raw protocol reply into one of *Pong,
// *Subscription or *Message, or returns an error for unknown shapes.
// NOTE(review): the reply element type assertions (reply[0].(string),
// reply[2].(int64), ...) will panic rather than error if the server
// sends an unexpected element type — confirm this matches the
// protocol guarantees relied on elsewhere.
func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
	switch reply := reply.(type) {
	case string:
		// A bare string reply is the response to PING without payload.
		return &Pong{
			Payload: reply,
		}, nil
	case []interface{}:
		switch kind := reply[0].(string); kind {
		case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
			// Can be nil in case of "unsubscribe".
			channel, _ := reply[1].(string)
			return &Subscription{
				Kind: kind,
				Channel: channel,
				Count: int(reply[2].(int64)),
			}, nil
		case "message":
			// Payload may be a single string or a slice of strings.
			switch payload := reply[2].(type) {
			case string:
				return &Message{
					Channel: reply[1].(string),
					Payload: payload,
				}, nil
			case []interface{}:
				ss := make([]string, len(payload))
				for i, s := range payload {
					ss[i] = s.(string)
				}
				return &Message{
					Channel: reply[1].(string),
					PayloadSlice: ss,
				}, nil
			default:
				return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload)
			}
		case "pmessage":
			// Pattern-matched publish: pattern, channel, payload.
			return &Message{
				Pattern: reply[1].(string),
				Channel: reply[2].(string),
				Payload: reply[3].(string),
			}, nil
		case "pong":
			return &Pong{
				Payload: reply[1].(string),
			}, nil
		default:
			return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
		}
	default:
		return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
	}
}
// ReceiveTimeout acts like Receive but returns an error if message
// is not received in time. This is low-level API and in most cases
// Channel should be used instead.
func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
	if c.cmd == nil {
		// The Cmd is created once and reused as a scratch reply holder.
		c.cmd = NewCmd(ctx)
	}
	// Don't hold the lock to allow subscriptions and pings.
	cn, err := c.connWithLock(ctx)
	if err != nil {
		return nil, err
	}
	err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
		return c.cmd.readReply(rd)
	})
	// A timeout error is tolerated (allowTimeout) only when a timeout
	// was actually requested.
	c.releaseConnWithLock(ctx, cn, err, timeout > 0)
	if err != nil {
		return nil, err
	}
	return c.newMessage(c.cmd.Val())
}

// Receive returns a message as a Subscription, Message, Pong or error.
// See PubSub example for details. This is low-level API and in most cases
// Channel should be used instead.
func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
	return c.ReceiveTimeout(ctx, 0)
}
// ReceiveMessage returns a Message or error ignoring Subscription and Pong | |||
// messages. This is low-level API and in most cases Channel should be used | |||
// instead. | |||
func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) { | |||
for { | |||
msg, err := c.Receive(ctx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
switch msg := msg.(type) { | |||
case *Subscription: | |||
// Ignore. | |||
case *Pong: | |||
// Ignore. | |||
case *Message: | |||
return msg, nil | |||
default: | |||
err := fmt.Errorf("redis: unknown message: %T", msg) | |||
return nil, err | |||
} | |||
} | |||
} | |||
func (c *PubSub) getContext() context.Context { | |||
if c.cmd != nil { | |||
return c.cmd.ctx | |||
} | |||
return context.Background() | |||
} | |||
//------------------------------------------------------------------------------

// Channel returns a Go channel for concurrently receiving messages.
// The channel is closed together with the PubSub. If the Go channel
// is blocked full for the send timeout (60 seconds by default, see
// WithChannelSendTimeout) the message is dropped.
// Receive* APIs can not be used after channel is created.
//
// go-redis periodically sends ping messages to test connection health
// and re-subscribes if a ping reply cannot be received within the
// health check interval (3 seconds by default).
func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
	c.chOnce.Do(func() {
		c.msgCh = newChannel(c, opts...)
		c.msgCh.initMsgChan()
	})
	if c.msgCh == nil {
		// chOnce already fired for ChannelWithSubscriptions.
		err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
		panic(err)
	}
	return c.msgCh.msgCh
}

// ChannelSize is like Channel, but creates a Go channel
// with specified buffer size.
//
// Deprecated: use Channel(WithChannelSize(size)), remove in v9.
func (c *PubSub) ChannelSize(size int) <-chan *Message {
	return c.Channel(WithChannelSize(size))
}

// ChannelWithSubscriptions is like Channel, but message type can be either
// *Subscription or *Message. Subscription messages can be used to detect
// reconnections.
//
// ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} {
	c.chOnce.Do(func() {
		c.allCh = newChannel(c, WithChannelSize(size))
		c.allCh.initAllChan()
	})
	if c.allCh == nil {
		// chOnce already fired for Channel/ChannelSize.
		err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
		panic(err)
	}
	return c.allCh.allCh
}
// ChannelOption configures the channel created by Channel and friends.
type ChannelOption func(c *channel)

// WithChannelSize specifies the Go chan size that is used to buffer incoming messages.
//
// The default is 100 messages.
func WithChannelSize(size int) ChannelOption {
	return func(c *channel) {
		c.chanSize = size
	}
}

// WithChannelHealthCheckInterval specifies the health check interval.
// PubSub will ping Redis Server if it does not receive any messages within the interval.
// To disable health check, use zero interval.
//
// The default is 3 seconds.
func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
	return func(c *channel) {
		c.checkInterval = d
	}
}

// WithChannelSendTimeout specifies the channel send timeout after which
// the message is dropped.
//
// The default is 60 seconds.
func WithChannelSendTimeout(d time.Duration) ChannelOption {
	return func(c *channel) {
		c.chanSendTimeout = d
	}
}
// channel is the machinery behind Channel/ChannelWithSubscriptions: a
// buffered Go channel fed by a receiver goroutine, plus an optional
// health-check goroutine that pings the server when traffic is idle.
type channel struct {
	pubSub *PubSub
	msgCh chan *Message // filled by initMsgChan
	allCh chan interface{} // filled by initAllChan
	ping chan struct{} // signals "traffic seen" to the health checker
	chanSize int
	chanSendTimeout time.Duration
	checkInterval time.Duration
}

// newChannel applies opts over the defaults (100-message buffer, 60s
// send timeout, 3s health check) and starts the health-check goroutine
// unless the interval was set to zero.
func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
	c := &channel{
		pubSub: pubSub,
		chanSize: 100,
		chanSendTimeout: time.Minute,
		checkInterval: 3 * time.Second,
	}
	for _, opt := range opts {
		opt(c)
	}
	if c.checkInterval > 0 {
		c.initHealthCheck()
	}
	return c
}
// initHealthCheck starts a goroutine that pings the server whenever no
// message has been received for checkInterval, reconnecting the PubSub
// when the ping fails. It exits when the PubSub is closed.
func (c *channel) initHealthCheck() {
	ctx := context.TODO()
	c.ping = make(chan struct{}, 1)
	go func() {
		// One reusable timer; created stopped and re-armed per loop.
		timer := time.NewTimer(time.Minute)
		timer.Stop()
		for {
			timer.Reset(c.checkInterval)
			select {
			case <-c.ping:
				// Traffic arrived in time; stop and drain the timer.
				if !timer.Stop() {
					<-timer.C
				}
			case <-timer.C:
				// Idle for a full interval: probe the connection.
				if pingErr := c.pubSub.Ping(ctx); pingErr != nil {
					c.pubSub.mu.Lock()
					c.pubSub.reconnect(ctx, pingErr)
					c.pubSub.mu.Unlock()
				}
			case <-c.pubSub.exit:
				return
			}
		}
	}()
}
// initMsgChan must be in sync with initAllChan.
// It starts the goroutine that pumps received *Message values into
// msgCh, dropping a message if the channel stays full for
// chanSendTimeout, and closing msgCh when the PubSub is closed.
func (c *channel) initMsgChan() {
	ctx := context.TODO()
	c.msgCh = make(chan *Message, c.chanSize)
	go func() {
		// Reusable timer bounding the channel send below.
		timer := time.NewTimer(time.Minute)
		timer.Stop()
		var errCount int
		for {
			msg, err := c.pubSub.Receive(ctx)
			if err != nil {
				if err == pool.ErrClosed {
					// PubSub closed: terminate the pump.
					close(c.msgCh)
					return
				}
				if errCount > 0 {
					// Back off briefly after consecutive errors.
					time.Sleep(100 * time.Millisecond)
				}
				errCount++
				continue
			}
			errCount = 0
			// Any message is as good as a ping.
			select {
			case c.ping <- struct{}{}:
			default:
			}
			switch msg := msg.(type) {
			case *Subscription:
				// Ignore.
			case *Pong:
				// Ignore.
			case *Message:
				// Bounded send: drop the message if the consumer does
				// not drain msgCh within chanSendTimeout.
				timer.Reset(c.chanSendTimeout)
				select {
				case c.msgCh <- msg:
					if !timer.Stop() {
						// Timer fired concurrently; drain it.
						<-timer.C
					}
				case <-timer.C:
					internal.Logger.Printf(
						ctx, "redis: %s channel is full for %s (message is dropped)",
						c, c.chanSendTimeout)
				}
			default:
				internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
			}
		}
	}()
}
// initAllChan must be in sync with initMsgChan.
// It behaves like initMsgChan but also forwards *Subscription values
// on allCh, so consumers can observe reconnections.
func (c *channel) initAllChan() {
	ctx := context.TODO()
	c.allCh = make(chan interface{}, c.chanSize)
	go func() {
		// Reusable timer bounding the channel send below.
		timer := time.NewTimer(time.Minute)
		timer.Stop()
		var errCount int
		for {
			msg, err := c.pubSub.Receive(ctx)
			if err != nil {
				if err == pool.ErrClosed {
					// PubSub closed: terminate the pump.
					close(c.allCh)
					return
				}
				if errCount > 0 {
					// Back off briefly after consecutive errors.
					time.Sleep(100 * time.Millisecond)
				}
				errCount++
				continue
			}
			errCount = 0
			// Any message is as good as a ping.
			select {
			case c.ping <- struct{}{}:
			default:
			}
			switch msg := msg.(type) {
			case *Pong:
				// Ignore.
			case *Subscription, *Message:
				// Bounded send: drop if the consumer does not drain
				// allCh within chanSendTimeout.
				timer.Reset(c.chanSendTimeout)
				select {
				case c.allCh <- msg:
					if !timer.Stop() {
						// Timer fired concurrently; drain it.
						<-timer.C
					}
				case <-timer.C:
					internal.Logger.Printf(
						ctx, "redis: %s channel is full for %s (message is dropped)",
						c, c.chanSendTimeout)
				}
			default:
				internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
			}
		}
	}()
}
@@ -0,0 +1,773 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"errors" | |||
"fmt" | |||
"sync/atomic" | |||
"time" | |||
"github.com/go-redis/redis/v8/internal" | |||
"github.com/go-redis/redis/v8/internal/pool" | |||
"github.com/go-redis/redis/v8/internal/proto" | |||
) | |||
// Nil reply returned by Redis when key does not exist.
const Nil = proto.Nil

// SetLogger replaces the package-level logger used for internal
// diagnostics.
func SetLogger(logger internal.Logging) {
	internal.Logger = logger
}
//------------------------------------------------------------------------------ | |||
// Hook observes command processing: Before* may replace the context,
// After* runs once the command (or pipeline) has completed.
type Hook interface {
	BeforeProcess(ctx context.Context, cmd Cmder) (context.Context, error)
	AfterProcess(ctx context.Context, cmd Cmder) error
	BeforeProcessPipeline(ctx context.Context, cmds []Cmder) (context.Context, error)
	AfterProcessPipeline(ctx context.Context, cmds []Cmder) error
}

// hooks holds the registered hook chain.
type hooks struct {
	hooks []Hook
}

// lock clamps the slice capacity to its length (full slice expression)
// so a later AddHook on a clone reallocates instead of mutating the
// shared backing array.
func (hs *hooks) lock() {
	hs.hooks = hs.hooks[:len(hs.hooks):len(hs.hooks)]
}

// clone returns a copy whose hook slice is safe to append to.
func (hs hooks) clone() hooks {
	clone := hs
	clone.lock()
	return clone
}

// AddHook appends hook to the chain.
func (hs *hooks) AddHook(hook Hook) {
	hs.hooks = append(hs.hooks, hook)
}
// process runs cmd through the hook chain: BeforeProcess hooks in
// registration order, then fn, then AfterProcess for exactly those
// hooks whose Before stage ran, in reverse order.
func (hs hooks) process(
	ctx context.Context, cmd Cmder, fn func(context.Context, Cmder) error,
) error {
	if len(hs.hooks) == 0 {
		err := fn(ctx, cmd)
		cmd.SetErr(err)
		return err
	}
	var hookIndex int
	var retErr error
	// Before stage: stop at the first hook error.
	for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
		ctx, retErr = hs.hooks[hookIndex].BeforeProcess(ctx, cmd)
		if retErr != nil {
			cmd.SetErr(retErr)
		}
	}
	if retErr == nil {
		retErr = fn(ctx, cmd)
		cmd.SetErr(retErr)
	}
	// Unwind: only hooks whose Before ran get an After call; the last
	// error observed wins.
	for hookIndex--; hookIndex >= 0; hookIndex-- {
		if err := hs.hooks[hookIndex].AfterProcess(ctx, cmd); err != nil {
			retErr = err
			cmd.SetErr(retErr)
		}
	}
	return retErr
}
// processPipeline is the pipeline analogue of process: Before hooks in
// order, fn, then After hooks in reverse for those whose Before ran.
func (hs hooks) processPipeline(
	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
	if len(hs.hooks) == 0 {
		err := fn(ctx, cmds)
		return err
	}
	var hookIndex int
	var retErr error
	// Before stage: stop at the first hook error.
	for ; hookIndex < len(hs.hooks) && retErr == nil; hookIndex++ {
		ctx, retErr = hs.hooks[hookIndex].BeforeProcessPipeline(ctx, cmds)
		if retErr != nil {
			setCmdsErr(cmds, retErr)
		}
	}
	if retErr == nil {
		retErr = fn(ctx, cmds)
	}
	// Unwind in reverse; the last error observed wins.
	for hookIndex--; hookIndex >= 0; hookIndex-- {
		if err := hs.hooks[hookIndex].AfterProcessPipeline(ctx, cmds); err != nil {
			retErr = err
			setCmdsErr(cmds, retErr)
		}
	}
	return retErr
}

// processTxPipeline wraps cmds in MULTI/EXEC before running them
// through the regular pipeline hook chain.
func (hs hooks) processTxPipeline(
	ctx context.Context, cmds []Cmder, fn func(context.Context, []Cmder) error,
) error {
	cmds = wrapMultiExec(ctx, cmds)
	return hs.processPipeline(ctx, cmds, fn)
}
//------------------------------------------------------------------------------ | |||
// baseClient holds the state shared by the concrete client types:
// options and the connection pool.
type baseClient struct {
	opt *Options
	connPool pool.Pooler
	onClose func() error // hook called when client is closed
}

// newBaseClient builds a baseClient over an existing pool.
func newBaseClient(opt *Options, connPool pool.Pooler) *baseClient {
	return &baseClient{
		opt: opt,
		connPool: connPool,
	}
}

// clone returns a shallow copy; opt and connPool are shared.
func (c *baseClient) clone() *baseClient {
	clone := *c
	return &clone
}
func (c *baseClient) withTimeout(timeout time.Duration) *baseClient { | |||
opt := c.opt.clone() | |||
opt.ReadTimeout = timeout | |||
opt.WriteTimeout = timeout | |||
clone := c.clone() | |||
clone.opt = opt | |||
return clone | |||
} | |||
func (c *baseClient) String() string { | |||
return fmt.Sprintf("Redis<%s db:%d>", c.getAddr(), c.opt.DB) | |||
} | |||
func (c *baseClient) newConn(ctx context.Context) (*pool.Conn, error) { | |||
cn, err := c.connPool.NewConn(ctx) | |||
if err != nil { | |||
return nil, err | |||
} | |||
err = c.initConn(ctx, cn) | |||
if err != nil { | |||
_ = c.connPool.CloseConn(cn) | |||
return nil, err | |||
} | |||
return cn, nil | |||
} | |||
func (c *baseClient) getConn(ctx context.Context) (*pool.Conn, error) { | |||
if c.opt.Limiter != nil { | |||
err := c.opt.Limiter.Allow() | |||
if err != nil { | |||
return nil, err | |||
} | |||
} | |||
cn, err := c._getConn(ctx) | |||
if err != nil { | |||
if c.opt.Limiter != nil { | |||
c.opt.Limiter.ReportResult(err) | |||
} | |||
return nil, err | |||
} | |||
return cn, nil | |||
} | |||
// _getConn fetches a connection from the pool and initializes it on
// first use, removing it from the pool when initialization fails.
func (c *baseClient) _getConn(ctx context.Context) (*pool.Conn, error) {
	cn, err := c.connPool.Get(ctx)
	if err != nil {
		return nil, err
	}
	if cn.Inited {
		return cn, nil
	}
	if err := c.initConn(ctx, cn); err != nil {
		c.connPool.Remove(ctx, cn, err)
		// Prefer the wrapped cause when there is one; otherwise
		// return the error as-is.
		if err := errors.Unwrap(err); err != nil {
			return nil, err
		}
		return nil, err
	}
	return cn, nil
}
// initConn performs the per-connection handshake: AUTH, SELECT,
// READONLY and the user's OnConnect callback, all skipped when none
// apply. It is idempotent via cn.Inited.
func (c *baseClient) initConn(ctx context.Context, cn *pool.Conn) error {
	if cn.Inited {
		return nil
	}
	// Marked before the handshake so re-entrant calls are no-ops.
	cn.Inited = true
	if c.opt.Password == "" &&
		c.opt.DB == 0 &&
		!c.opt.readOnly &&
		c.opt.OnConnect == nil {
		// Nothing to do for this connection.
		return nil
	}
	// Run the handshake commands over a single-connection pool so they
	// all use exactly this connection.
	connPool := pool.NewSingleConnPool(c.connPool, cn)
	conn := newConn(ctx, c.opt, connPool)
	_, err := conn.Pipelined(ctx, func(pipe Pipeliner) error {
		if c.opt.Password != "" {
			if c.opt.Username != "" {
				pipe.AuthACL(ctx, c.opt.Username, c.opt.Password)
			} else {
				pipe.Auth(ctx, c.opt.Password)
			}
		}
		if c.opt.DB > 0 {
			pipe.Select(ctx, c.opt.DB)
		}
		if c.opt.readOnly {
			pipe.ReadOnly(ctx)
		}
		return nil
	})
	if err != nil {
		return err
	}
	if c.opt.OnConnect != nil {
		return c.opt.OnConnect(ctx, conn)
	}
	return nil
}
func (c *baseClient) releaseConn(ctx context.Context, cn *pool.Conn, err error) { | |||
if c.opt.Limiter != nil { | |||
c.opt.Limiter.ReportResult(err) | |||
} | |||
if isBadConn(err, false, c.opt.Addr) { | |||
c.connPool.Remove(ctx, cn, err) | |||
} else { | |||
c.connPool.Put(ctx, cn) | |||
} | |||
} | |||
// withConn runs fn with a pooled connection and releases it when done.
// When the context is cancellable, fn runs in a goroutine so that
// cancellation can interrupt it by closing the connection.
func (c *baseClient) withConn(
	ctx context.Context, fn func(context.Context, *pool.Conn) error,
) error {
	cn, err := c.getConn(ctx)
	if err != nil {
		return err
	}
	defer func() {
		// err is the named outcome assigned below.
		c.releaseConn(ctx, cn, err)
	}()
	done := ctx.Done() //nolint:ifshort
	if done == nil {
		// Non-cancellable context: run fn inline.
		err = fn(ctx, cn)
		return err
	}
	errc := make(chan error, 1)
	go func() { errc <- fn(ctx, cn) }()
	select {
	case <-done:
		// Closing the connection unblocks fn's pending I/O.
		_ = cn.Close()
		// Wait for the goroutine to finish and send something.
		<-errc
		err = ctx.Err()
		return err
	case err = <-errc:
		return err
	}
}
func (c *baseClient) process(ctx context.Context, cmd Cmder) error { | |||
var lastErr error | |||
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { | |||
attempt := attempt | |||
retry, err := c._process(ctx, cmd, attempt) | |||
if err == nil || !retry { | |||
return err | |||
} | |||
lastErr = err | |||
} | |||
return lastErr | |||
} | |||
func (c *baseClient) _process(ctx context.Context, cmd Cmder, attempt int) (bool, error) { | |||
if attempt > 0 { | |||
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { | |||
return false, err | |||
} | |||
} | |||
retryTimeout := uint32(1) | |||
err := c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { | |||
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { | |||
return writeCmd(wr, cmd) | |||
}) | |||
if err != nil { | |||
return err | |||
} | |||
err = cn.WithReader(ctx, c.cmdTimeout(cmd), cmd.readReply) | |||
if err != nil { | |||
if cmd.readTimeout() == nil { | |||
atomic.StoreUint32(&retryTimeout, 1) | |||
} | |||
return err | |||
} | |||
return nil | |||
}) | |||
if err == nil { | |||
return false, nil | |||
} | |||
retry := shouldRetry(err, atomic.LoadUint32(&retryTimeout) == 1) | |||
return retry, err | |||
} | |||
func (c *baseClient) retryBackoff(attempt int) time.Duration { | |||
return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff) | |||
} | |||
func (c *baseClient) cmdTimeout(cmd Cmder) time.Duration { | |||
if timeout := cmd.readTimeout(); timeout != nil { | |||
t := *timeout | |||
if t == 0 { | |||
return 0 | |||
} | |||
return t + 10*time.Second | |||
} | |||
return c.opt.ReadTimeout | |||
} | |||
// Close closes the client, releasing any open resources. | |||
// | |||
// It is rare to Close a Client, as the Client is meant to be | |||
// long-lived and shared between many goroutines. | |||
func (c *baseClient) Close() error { | |||
var firstErr error | |||
if c.onClose != nil { | |||
if err := c.onClose(); err != nil { | |||
firstErr = err | |||
} | |||
} | |||
if err := c.connPool.Close(); err != nil && firstErr == nil { | |||
firstErr = err | |||
} | |||
return firstErr | |||
} | |||
// getAddr returns the server address the client was configured with.
func (c *baseClient) getAddr() string {
	return c.opt.Addr
}
// processPipeline executes cmds as a plain (non-transactional) pipeline.
func (c *baseClient) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.generalProcessPipeline(ctx, cmds, c.pipelineProcessCmds)
}
// processTxPipeline executes cmds as a MULTI/EXEC transactional pipeline.
func (c *baseClient) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.generalProcessPipeline(ctx, cmds, c.txPipelineProcessCmds)
}
type pipelineProcessor func(context.Context, *pool.Conn, []Cmder) (bool, error) | |||
func (c *baseClient) generalProcessPipeline( | |||
ctx context.Context, cmds []Cmder, p pipelineProcessor, | |||
) error { | |||
err := c._generalProcessPipeline(ctx, cmds, p) | |||
if err != nil { | |||
setCmdsErr(cmds, err) | |||
return err | |||
} | |||
return cmdsFirstErr(cmds) | |||
} | |||
func (c *baseClient) _generalProcessPipeline( | |||
ctx context.Context, cmds []Cmder, p pipelineProcessor, | |||
) error { | |||
var lastErr error | |||
for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ { | |||
if attempt > 0 { | |||
if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil { | |||
return err | |||
} | |||
} | |||
var canRetry bool | |||
lastErr = c.withConn(ctx, func(ctx context.Context, cn *pool.Conn) error { | |||
var err error | |||
canRetry, err = p(ctx, cn, cmds) | |||
return err | |||
}) | |||
if lastErr == nil || !canRetry || !shouldRetry(lastErr, true) { | |||
return lastErr | |||
} | |||
} | |||
return lastErr | |||
} | |||
func (c *baseClient) pipelineProcessCmds( | |||
ctx context.Context, cn *pool.Conn, cmds []Cmder, | |||
) (bool, error) { | |||
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { | |||
return writeCmds(wr, cmds) | |||
}) | |||
if err != nil { | |||
return true, err | |||
} | |||
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { | |||
return pipelineReadCmds(rd, cmds) | |||
}) | |||
return true, err | |||
} | |||
func pipelineReadCmds(rd *proto.Reader, cmds []Cmder) error { | |||
for _, cmd := range cmds { | |||
err := cmd.readReply(rd) | |||
cmd.SetErr(err) | |||
if err != nil && !isRedisError(err) { | |||
return err | |||
} | |||
} | |||
return nil | |||
} | |||
func (c *baseClient) txPipelineProcessCmds( | |||
ctx context.Context, cn *pool.Conn, cmds []Cmder, | |||
) (bool, error) { | |||
err := cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error { | |||
return writeCmds(wr, cmds) | |||
}) | |||
if err != nil { | |||
return true, err | |||
} | |||
err = cn.WithReader(ctx, c.opt.ReadTimeout, func(rd *proto.Reader) error { | |||
statusCmd := cmds[0].(*StatusCmd) | |||
// Trim multi and exec. | |||
cmds = cmds[1 : len(cmds)-1] | |||
err := txPipelineReadQueued(rd, statusCmd, cmds) | |||
if err != nil { | |||
return err | |||
} | |||
return pipelineReadCmds(rd, cmds) | |||
}) | |||
return false, err | |||
} | |||
func wrapMultiExec(ctx context.Context, cmds []Cmder) []Cmder { | |||
if len(cmds) == 0 { | |||
panic("not reached") | |||
} | |||
cmdCopy := make([]Cmder, len(cmds)+2) | |||
cmdCopy[0] = NewStatusCmd(ctx, "multi") | |||
copy(cmdCopy[1:], cmds) | |||
cmdCopy[len(cmdCopy)-1] = NewSliceCmd(ctx, "exec") | |||
return cmdCopy | |||
} | |||
func txPipelineReadQueued(rd *proto.Reader, statusCmd *StatusCmd, cmds []Cmder) error { | |||
// Parse queued replies. | |||
if err := statusCmd.readReply(rd); err != nil { | |||
return err | |||
} | |||
for range cmds { | |||
if err := statusCmd.readReply(rd); err != nil && !isRedisError(err) { | |||
return err | |||
} | |||
} | |||
// Parse number of replies. | |||
line, err := rd.ReadLine() | |||
if err != nil { | |||
if err == Nil { | |||
err = TxFailedErr | |||
} | |||
return err | |||
} | |||
switch line[0] { | |||
case proto.ErrorReply: | |||
return proto.ParseErrorReply(line) | |||
case proto.ArrayReply: | |||
// ok | |||
default: | |||
err := fmt.Errorf("redis: expected '*', but got line %q", line) | |||
return err | |||
} | |||
return nil | |||
} | |||
//------------------------------------------------------------------------------ | |||
// Client is a Redis client representing a pool of zero or more
// underlying connections. It's safe for concurrent use by multiple
// goroutines.
type Client struct {
	// baseClient provides the connection pool and command execution.
	*baseClient
	// cmdable supplies the generated command methods, bound to Process.
	cmdable
	// hooks holds user-installed process/pipeline hooks.
	hooks
	// ctx is the client's base context, returned by Context().
	ctx context.Context
}
// NewClient returns a client to the Redis Server specified by Options. | |||
func NewClient(opt *Options) *Client { | |||
opt.init() | |||
c := Client{ | |||
baseClient: newBaseClient(opt, newConnPool(opt)), | |||
ctx: context.Background(), | |||
} | |||
c.cmdable = c.Process | |||
return &c | |||
} | |||
func (c *Client) clone() *Client { | |||
clone := *c | |||
clone.cmdable = clone.Process | |||
clone.hooks.lock() | |||
return &clone | |||
} | |||
func (c *Client) WithTimeout(timeout time.Duration) *Client { | |||
clone := c.clone() | |||
clone.baseClient = c.baseClient.withTimeout(timeout) | |||
return clone | |||
} | |||
// Context returns the client's base context set via WithContext
// (context.Background() by default).
func (c *Client) Context() context.Context {
	return c.ctx
}
func (c *Client) WithContext(ctx context.Context) *Client { | |||
if ctx == nil { | |||
panic("nil context") | |||
} | |||
clone := c.clone() | |||
clone.ctx = ctx | |||
return clone | |||
} | |||
// Conn returns a Conn backed by a single sticky connection drawn from
// the client's pool.
func (c *Client) Conn(ctx context.Context) *Conn {
	return newConn(ctx, c.opt, pool.NewStickyConnPool(c.connPool))
}
// Do creates a Cmd from the args and processes the cmd. | |||
func (c *Client) Do(ctx context.Context, args ...interface{}) *Cmd { | |||
cmd := NewCmd(ctx, args...) | |||
_ = c.Process(ctx, cmd) | |||
return cmd | |||
} | |||
// Process runs cmd through the hook chain and then the base client.
func (c *Client) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.baseClient.process)
}
// processPipeline runs cmds through the pipeline hook chain and then
// the base client's plain pipeline path.
func (c *Client) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}
// processTxPipeline runs cmds through the tx-pipeline hook chain and
// then the base client's MULTI/EXEC pipeline path.
func (c *Client) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}
// Options returns read-only Options that were used to create the client.
func (c *Client) Options() *Options {
	return c.opt
}
type PoolStats pool.Stats | |||
// PoolStats returns connection pool stats. | |||
func (c *Client) PoolStats() *PoolStats { | |||
stats := c.connPool.Stats() | |||
return (*PoolStats)(stats) | |||
} | |||
// Pipelined collects the commands queued by fn into a pipeline and
// executes them in one round trip.
func (c *Client) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}
func (c *Client) Pipeline() Pipeliner { | |||
pipe := Pipeline{ | |||
ctx: c.ctx, | |||
exec: c.processPipeline, | |||
} | |||
pipe.init() | |||
return &pipe | |||
} | |||
// TxPipelined collects the commands queued by fn and executes them
// wrapped in MULTI/EXEC.
func (c *Client) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. | |||
func (c *Client) TxPipeline() Pipeliner { | |||
pipe := Pipeline{ | |||
ctx: c.ctx, | |||
exec: c.processTxPipeline, | |||
} | |||
pipe.init() | |||
return &pipe | |||
} | |||
func (c *Client) pubSub() *PubSub { | |||
pubsub := &PubSub{ | |||
opt: c.opt, | |||
newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) { | |||
return c.newConn(ctx) | |||
}, | |||
closeConn: c.connPool.CloseConn, | |||
} | |||
pubsub.init() | |||
return pubsub | |||
} | |||
// Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create empty subscription.
// Note that this method does not wait on a response from Redis, so the
// subscription may not be active immediately. To force the connection to wait,
// you may call the Receive() method on the returned *PubSub like so:
//
//    sub := client.Subscribe(queryResp)
//    iface, err := sub.Receive()
//    if err != nil {
//        // handle error
//    }
//
//    // Should be *Subscription, but others are possible if other actions have been
//    // taken on sub since it was created.
//    switch iface.(type) {
//    case *Subscription:
//        // subscribe succeeded
//    case *Message:
//        // received first message
//    case *Pong:
//        // pong received
//    default:
//        // handle error
//    }
//
//    ch := sub.Channel()
func (c *Client) Subscribe(ctx context.Context, channels ...string) *PubSub {
	pubsub := c.pubSub()
	if len(channels) > 0 {
		// Best-effort initial subscribe; errors surface via Receive/Channel.
		_ = pubsub.Subscribe(ctx, channels...)
	}
	return pubsub
}
// PSubscribe subscribes the client to the given patterns. | |||
// Patterns can be omitted to create empty subscription. | |||
func (c *Client) PSubscribe(ctx context.Context, channels ...string) *PubSub { | |||
pubsub := c.pubSub() | |||
if len(channels) > 0 { | |||
_ = pubsub.PSubscribe(ctx, channels...) | |||
} | |||
return pubsub | |||
} | |||
//------------------------------------------------------------------------------ | |||
// conn is the shared implementation behind Conn: a baseClient plus the
// full (including stateful) command set.
type conn struct {
	baseClient
	cmdable
	statefulCmdable
	hooks // TODO: inherit hooks
}
// Conn represents a single Redis connection rather than a pool of connections.
// Prefer running commands from Client unless there is a specific need
// for a continuous single Redis connection.
type Conn struct {
	*conn
	// ctx is the connection's base context, used by Pipeline/TxPipeline.
	ctx context.Context
}
func newConn(ctx context.Context, opt *Options, connPool pool.Pooler) *Conn { | |||
c := Conn{ | |||
conn: &conn{ | |||
baseClient: baseClient{ | |||
opt: opt, | |||
connPool: connPool, | |||
}, | |||
}, | |||
ctx: ctx, | |||
} | |||
c.cmdable = c.Process | |||
c.statefulCmdable = c.Process | |||
return &c | |||
} | |||
// Process runs cmd through the hook chain and then the base client.
func (c *Conn) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.baseClient.process)
}
// processPipeline runs cmds through the pipeline hook chain and then
// the base client's plain pipeline path.
func (c *Conn) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline)
}
// processTxPipeline runs cmds through the tx-pipeline hook chain and
// then the base client's MULTI/EXEC pipeline path.
func (c *Conn) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline)
}
// Pipelined collects the commands queued by fn into a pipeline and
// executes them in one round trip on this connection.
func (c *Conn) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}
func (c *Conn) Pipeline() Pipeliner { | |||
pipe := Pipeline{ | |||
ctx: c.ctx, | |||
exec: c.processPipeline, | |||
} | |||
pipe.init() | |||
return &pipe | |||
} | |||
// TxPipelined collects the commands queued by fn and executes them
// wrapped in MULTI/EXEC on this connection.
func (c *Conn) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}
// TxPipeline acts like Pipeline, but wraps queued commands with MULTI/EXEC. | |||
func (c *Conn) TxPipeline() Pipeliner { | |||
pipe := Pipeline{ | |||
ctx: c.ctx, | |||
exec: c.processTxPipeline, | |||
} | |||
pipe.init() | |||
return &pipe | |||
} |
@@ -0,0 +1,180 @@ | |||
package redis | |||
import "time" | |||
// NewCmdResult returns a Cmd initialised with val and err for testing. | |||
func NewCmdResult(val interface{}, err error) *Cmd { | |||
var cmd Cmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewSliceResult returns a SliceCmd initialised with val and err for testing. | |||
func NewSliceResult(val []interface{}, err error) *SliceCmd { | |||
var cmd SliceCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewStatusResult returns a StatusCmd initialised with val and err for testing. | |||
func NewStatusResult(val string, err error) *StatusCmd { | |||
var cmd StatusCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewIntResult returns an IntCmd initialised with val and err for testing. | |||
func NewIntResult(val int64, err error) *IntCmd { | |||
var cmd IntCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewDurationResult returns a DurationCmd initialised with val and err for testing. | |||
func NewDurationResult(val time.Duration, err error) *DurationCmd { | |||
var cmd DurationCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewBoolResult returns a BoolCmd initialised with val and err for testing. | |||
func NewBoolResult(val bool, err error) *BoolCmd { | |||
var cmd BoolCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewStringResult returns a StringCmd initialised with val and err for testing. | |||
func NewStringResult(val string, err error) *StringCmd { | |||
var cmd StringCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewFloatResult returns a FloatCmd initialised with val and err for testing. | |||
func NewFloatResult(val float64, err error) *FloatCmd { | |||
var cmd FloatCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewStringSliceResult returns a StringSliceCmd initialised with val and err for testing. | |||
func NewStringSliceResult(val []string, err error) *StringSliceCmd { | |||
var cmd StringSliceCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewBoolSliceResult returns a BoolSliceCmd initialised with val and err for testing. | |||
func NewBoolSliceResult(val []bool, err error) *BoolSliceCmd { | |||
var cmd BoolSliceCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewStringStringMapResult returns a StringStringMapCmd initialised with val and err for testing. | |||
func NewStringStringMapResult(val map[string]string, err error) *StringStringMapCmd { | |||
var cmd StringStringMapCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewStringIntMapCmdResult returns a StringIntMapCmd initialised with val and err for testing. | |||
func NewStringIntMapCmdResult(val map[string]int64, err error) *StringIntMapCmd { | |||
var cmd StringIntMapCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewTimeCmdResult returns a TimeCmd initialised with val and err for testing. | |||
func NewTimeCmdResult(val time.Time, err error) *TimeCmd { | |||
var cmd TimeCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewZSliceCmdResult returns a ZSliceCmd initialised with val and err for testing. | |||
func NewZSliceCmdResult(val []Z, err error) *ZSliceCmd { | |||
var cmd ZSliceCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewZWithKeyCmdResult returns a NewZWithKeyCmd initialised with val and err for testing. | |||
func NewZWithKeyCmdResult(val *ZWithKey, err error) *ZWithKeyCmd { | |||
var cmd ZWithKeyCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewScanCmdResult returns a ScanCmd initialised with val and err for testing. | |||
func NewScanCmdResult(keys []string, cursor uint64, err error) *ScanCmd { | |||
var cmd ScanCmd | |||
cmd.page = keys | |||
cmd.cursor = cursor | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewClusterSlotsCmdResult returns a ClusterSlotsCmd initialised with val and err for testing. | |||
func NewClusterSlotsCmdResult(val []ClusterSlot, err error) *ClusterSlotsCmd { | |||
var cmd ClusterSlotsCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewGeoLocationCmdResult returns a GeoLocationCmd initialised with val and err for testing. | |||
func NewGeoLocationCmdResult(val []GeoLocation, err error) *GeoLocationCmd { | |||
var cmd GeoLocationCmd | |||
cmd.locations = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewGeoPosCmdResult returns a GeoPosCmd initialised with val and err for testing. | |||
func NewGeoPosCmdResult(val []*GeoPos, err error) *GeoPosCmd { | |||
var cmd GeoPosCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewCommandsInfoCmdResult returns a CommandsInfoCmd initialised with val and err for testing. | |||
func NewCommandsInfoCmdResult(val map[string]*CommandInfo, err error) *CommandsInfoCmd { | |||
var cmd CommandsInfoCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewXMessageSliceCmdResult returns a XMessageSliceCmd initialised with val and err for testing. | |||
func NewXMessageSliceCmdResult(val []XMessage, err error) *XMessageSliceCmd { | |||
var cmd XMessageSliceCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} | |||
// NewXStreamSliceCmdResult returns a XStreamSliceCmd initialised with val and err for testing. | |||
func NewXStreamSliceCmdResult(val []XStream, err error) *XStreamSliceCmd { | |||
var cmd XStreamSliceCmd | |||
cmd.val = val | |||
cmd.SetErr(err) | |||
return &cmd | |||
} |
@@ -0,0 +1,736 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"crypto/tls" | |||
"errors" | |||
"fmt" | |||
"net" | |||
"strconv" | |||
"sync" | |||
"sync/atomic" | |||
"time" | |||
"github.com/cespare/xxhash/v2" | |||
rendezvous "github.com/dgryski/go-rendezvous" //nolint | |||
"github.com/go-redis/redis/v8/internal" | |||
"github.com/go-redis/redis/v8/internal/hashtag" | |||
"github.com/go-redis/redis/v8/internal/pool" | |||
"github.com/go-redis/redis/v8/internal/rand" | |||
) | |||
var errRingShardsDown = errors.New("redis: all ring shards are down") | |||
//------------------------------------------------------------------------------ | |||
// ConsistentHash maps a key to the name of the shard that owns it.
type ConsistentHash interface {
	Get(string) string
}
// rendezvousWrapper adapts rendezvous.Rendezvous (Lookup) to the
// ConsistentHash interface (Get).
type rendezvousWrapper struct {
	*rendezvous.Rendezvous
}
// Get returns the shard name selected for key by rendezvous hashing.
func (w rendezvousWrapper) Get(key string) string {
	return w.Lookup(key)
}
// newRendezvous is the default NewConsistentHash: rendezvous hashing
// over shard names using xxhash.
func newRendezvous(shards []string) ConsistentHash {
	return rendezvousWrapper{rendezvous.New(shards, xxhash.Sum64String)}
}
//------------------------------------------------------------------------------ | |||
// RingOptions are used to configure a ring client and should be
// passed to NewRing.
type RingOptions struct {
	// Map of name => host:port addresses of ring shards.
	Addrs map[string]string
	// NewClient creates a shard client with provided name and options.
	NewClient func(name string, opt *Options) *Client
	// Frequency of PING commands sent to check shards availability.
	// Shard is considered down after 3 subsequent failed checks.
	HeartbeatFrequency time.Duration
	// NewConsistentHash returns a consistent hash that is used
	// to distribute keys across the shards.
	//
	// See https://medium.com/@dgryski/consistent-hashing-algorithmic-tradeoffs-ef6b8e2fcae8
	// for consistent hashing algorithmic tradeoffs.
	NewConsistentHash func(shards []string) ConsistentHash
	// Following options are copied from Options struct.
	Dialer    func(ctx context.Context, network, addr string) (net.Conn, error)
	OnConnect func(ctx context.Context, cn *Conn) error
	Username string
	Password string
	DB       int
	// For the three retry options, -1 disables and 0 selects the default
	// (see RingOptions.init).
	MaxRetries      int
	MinRetryBackoff time.Duration
	MaxRetryBackoff time.Duration
	DialTimeout  time.Duration
	ReadTimeout  time.Duration
	WriteTimeout time.Duration
	// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
	PoolFIFO           bool
	PoolSize           int
	MinIdleConns       int
	MaxConnAge         time.Duration
	PoolTimeout        time.Duration
	IdleTimeout        time.Duration
	IdleCheckFrequency time.Duration
	TLSConfig *tls.Config
	Limiter   Limiter
}
func (opt *RingOptions) init() { | |||
if opt.NewClient == nil { | |||
opt.NewClient = func(name string, opt *Options) *Client { | |||
return NewClient(opt) | |||
} | |||
} | |||
if opt.HeartbeatFrequency == 0 { | |||
opt.HeartbeatFrequency = 500 * time.Millisecond | |||
} | |||
if opt.NewConsistentHash == nil { | |||
opt.NewConsistentHash = newRendezvous | |||
} | |||
if opt.MaxRetries == -1 { | |||
opt.MaxRetries = 0 | |||
} else if opt.MaxRetries == 0 { | |||
opt.MaxRetries = 3 | |||
} | |||
switch opt.MinRetryBackoff { | |||
case -1: | |||
opt.MinRetryBackoff = 0 | |||
case 0: | |||
opt.MinRetryBackoff = 8 * time.Millisecond | |||
} | |||
switch opt.MaxRetryBackoff { | |||
case -1: | |||
opt.MaxRetryBackoff = 0 | |||
case 0: | |||
opt.MaxRetryBackoff = 512 * time.Millisecond | |||
} | |||
} | |||
// clientOptions converts the ring options into per-shard Options.
// MaxRetries is set to -1 (disabled) because the Ring performs its own
// retries at the ring level.
func (opt *RingOptions) clientOptions() *Options {
	return &Options{
		Dialer:    opt.Dialer,
		OnConnect: opt.OnConnect,
		Username: opt.Username,
		Password: opt.Password,
		DB:       opt.DB,
		MaxRetries: -1,
		DialTimeout:  opt.DialTimeout,
		ReadTimeout:  opt.ReadTimeout,
		WriteTimeout: opt.WriteTimeout,
		PoolFIFO:           opt.PoolFIFO,
		PoolSize:           opt.PoolSize,
		MinIdleConns:       opt.MinIdleConns,
		MaxConnAge:         opt.MaxConnAge,
		PoolTimeout:        opt.PoolTimeout,
		IdleTimeout:        opt.IdleTimeout,
		IdleCheckFrequency: opt.IdleCheckFrequency,
		TLSConfig: opt.TLSConfig,
		Limiter:   opt.Limiter,
	}
}
//------------------------------------------------------------------------------ | |||
// ringShard pairs a shard client with its failed-heartbeat counter
// (accessed atomically; see IsDown/Vote).
type ringShard struct {
	Client *Client
	down   int32
}
func newRingShard(opt *RingOptions, name, addr string) *ringShard { | |||
clopt := opt.clientOptions() | |||
clopt.Addr = addr | |||
return &ringShard{ | |||
Client: opt.NewClient(name, clopt), | |||
} | |||
} | |||
func (shard *ringShard) String() string { | |||
var state string | |||
if shard.IsUp() { | |||
state = "up" | |||
} else { | |||
state = "down" | |||
} | |||
return fmt.Sprintf("%s is %s", shard.Client, state) | |||
} | |||
// IsDown reports whether the shard has failed at least three
// consecutive heartbeat checks.
func (shard *ringShard) IsDown() bool {
	const threshold = 3
	return atomic.LoadInt32(&shard.down) >= threshold
}
// IsUp reports whether the shard is currently considered live.
func (shard *ringShard) IsUp() bool {
	return !shard.IsDown()
}
// Vote votes to set shard state and returns true if state was changed. | |||
func (shard *ringShard) Vote(up bool) bool { | |||
if up { | |||
changed := shard.IsDown() | |||
atomic.StoreInt32(&shard.down, 0) | |||
return changed | |||
} | |||
if shard.IsDown() { | |||
return false | |||
} | |||
atomic.AddInt32(&shard.down, 1) | |||
return shard.IsDown() | |||
} | |||
//------------------------------------------------------------------------------ | |||
// ringShards owns the shard set, the consistent hash over the live
// shards, and the closed flag; mu guards all mutable state.
type ringShards struct {
	opt *RingOptions
	mu       sync.RWMutex
	hash     ConsistentHash
	shards   map[string]*ringShard // read only
	list     []*ringShard          // read only
	numShard int
	closed   bool
}
func newRingShards(opt *RingOptions) *ringShards { | |||
shards := make(map[string]*ringShard, len(opt.Addrs)) | |||
list := make([]*ringShard, 0, len(shards)) | |||
for name, addr := range opt.Addrs { | |||
shard := newRingShard(opt, name, addr) | |||
shards[name] = shard | |||
list = append(list, shard) | |||
} | |||
c := &ringShards{ | |||
opt: opt, | |||
shards: shards, | |||
list: list, | |||
} | |||
c.rebalance() | |||
return c | |||
} | |||
func (c *ringShards) List() []*ringShard { | |||
var list []*ringShard | |||
c.mu.RLock() | |||
if !c.closed { | |||
list = c.list | |||
} | |||
c.mu.RUnlock() | |||
return list | |||
} | |||
func (c *ringShards) Hash(key string) string { | |||
key = hashtag.Key(key) | |||
var hash string | |||
c.mu.RLock() | |||
if c.numShard > 0 { | |||
hash = c.hash.Get(key) | |||
} | |||
c.mu.RUnlock() | |||
return hash | |||
} | |||
func (c *ringShards) GetByKey(key string) (*ringShard, error) { | |||
key = hashtag.Key(key) | |||
c.mu.RLock() | |||
if c.closed { | |||
c.mu.RUnlock() | |||
return nil, pool.ErrClosed | |||
} | |||
if c.numShard == 0 { | |||
c.mu.RUnlock() | |||
return nil, errRingShardsDown | |||
} | |||
hash := c.hash.Get(key) | |||
if hash == "" { | |||
c.mu.RUnlock() | |||
return nil, errRingShardsDown | |||
} | |||
shard := c.shards[hash] | |||
c.mu.RUnlock() | |||
return shard, nil | |||
} | |||
func (c *ringShards) GetByName(shardName string) (*ringShard, error) { | |||
if shardName == "" { | |||
return c.Random() | |||
} | |||
c.mu.RLock() | |||
shard := c.shards[shardName] | |||
c.mu.RUnlock() | |||
return shard, nil | |||
} | |||
// Random returns an arbitrary live shard by hashing a random key.
func (c *ringShards) Random() (*ringShard, error) {
	return c.GetByKey(strconv.Itoa(rand.Int()))
}
// heartbeat monitors state of each shard in the ring. | |||
func (c *ringShards) Heartbeat(frequency time.Duration) { | |||
ticker := time.NewTicker(frequency) | |||
defer ticker.Stop() | |||
ctx := context.Background() | |||
for range ticker.C { | |||
var rebalance bool | |||
for _, shard := range c.List() { | |||
err := shard.Client.Ping(ctx).Err() | |||
isUp := err == nil || err == pool.ErrPoolTimeout | |||
if shard.Vote(isUp) { | |||
internal.Logger.Printf(context.Background(), "ring shard state changed: %s", shard) | |||
rebalance = true | |||
} | |||
} | |||
if rebalance { | |||
c.rebalance() | |||
} | |||
} | |||
} | |||
// rebalance removes dead shards from the Ring. | |||
func (c *ringShards) rebalance() { | |||
c.mu.RLock() | |||
shards := c.shards | |||
c.mu.RUnlock() | |||
liveShards := make([]string, 0, len(shards)) | |||
for name, shard := range shards { | |||
if shard.IsUp() { | |||
liveShards = append(liveShards, name) | |||
} | |||
} | |||
hash := c.opt.NewConsistentHash(liveShards) | |||
c.mu.Lock() | |||
c.hash = hash | |||
c.numShard = len(liveShards) | |||
c.mu.Unlock() | |||
} | |||
func (c *ringShards) Len() int { | |||
c.mu.RLock() | |||
l := c.numShard | |||
c.mu.RUnlock() | |||
return l | |||
} | |||
func (c *ringShards) Close() error { | |||
c.mu.Lock() | |||
defer c.mu.Unlock() | |||
if c.closed { | |||
return nil | |||
} | |||
c.closed = true | |||
var firstErr error | |||
for _, shard := range c.shards { | |||
if err := shard.Client.Close(); err != nil && firstErr == nil { | |||
firstErr = err | |||
} | |||
} | |||
c.hash = nil | |||
c.shards = nil | |||
c.list = nil | |||
return firstErr | |||
} | |||
//------------------------------------------------------------------------------ | |||
// ring holds the shared state embedded by Ring and its WithContext copies.
type ring struct {
	opt           *RingOptions
	shards        *ringShards
	cmdsInfoCache *cmdsInfoCache //nolint:structcheck
}
// Ring is a Redis client that uses consistent hashing to distribute
// keys across multiple Redis servers (shards). It's safe for
// concurrent use by multiple goroutines.
//
// Ring monitors the state of each shard and removes dead shards from
// the ring. When a shard comes online it is added back to the ring. This
// gives you maximum availability and partition tolerance, but no
// consistency between different shards or even clients. Each client
// uses shards that are available to the client and does not do any
// coordination when shard state is changed.
//
// Ring should be used when you need multiple Redis servers for caching
// and can tolerate losing data when one of the servers dies.
// Otherwise you should use Redis Cluster.
type Ring struct {
	*ring
	// cmdable supplies the generated command methods, bound to Process.
	cmdable
	// hooks holds user-installed process/pipeline hooks.
	hooks
	// ctx is the ring's base context, returned by Context().
	ctx context.Context
}
func NewRing(opt *RingOptions) *Ring { | |||
opt.init() | |||
ring := Ring{ | |||
ring: &ring{ | |||
opt: opt, | |||
shards: newRingShards(opt), | |||
}, | |||
ctx: context.Background(), | |||
} | |||
ring.cmdsInfoCache = newCmdsInfoCache(ring.cmdsInfo) | |||
ring.cmdable = ring.Process | |||
go ring.shards.Heartbeat(opt.HeartbeatFrequency) | |||
return &ring | |||
} | |||
// Context returns the ring's base context set via WithContext
// (context.Background() by default).
func (c *Ring) Context() context.Context {
	return c.ctx
}
func (c *Ring) WithContext(ctx context.Context) *Ring { | |||
if ctx == nil { | |||
panic("nil context") | |||
} | |||
clone := *c | |||
clone.cmdable = clone.Process | |||
clone.hooks.lock() | |||
clone.ctx = ctx | |||
return &clone | |||
} | |||
// Do creates a Cmd from the args and processes the cmd. | |||
func (c *Ring) Do(ctx context.Context, args ...interface{}) *Cmd { | |||
cmd := NewCmd(ctx, args...) | |||
_ = c.Process(ctx, cmd) | |||
return cmd | |||
} | |||
// Process runs cmd through the hook chain and then the ring-level
// shard-routing processor.
func (c *Ring) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.process)
}
// Options returns read-only Options that were used to create the client.
func (c *Ring) Options() *RingOptions {
	return c.opt
}
// retryBackoff returns the sleep duration before retry number attempt,
// bounded by the configured min/max retry backoff options.
func (c *Ring) retryBackoff(attempt int) time.Duration {
	return internal.RetryBackoff(attempt, c.opt.MinRetryBackoff, c.opt.MaxRetryBackoff)
}
// PoolStats returns accumulated connection pool stats. | |||
func (c *Ring) PoolStats() *PoolStats { | |||
shards := c.shards.List() | |||
var acc PoolStats | |||
for _, shard := range shards { | |||
s := shard.Client.connPool.Stats() | |||
acc.Hits += s.Hits | |||
acc.Misses += s.Misses | |||
acc.Timeouts += s.Timeouts | |||
acc.TotalConns += s.TotalConns | |||
acc.IdleConns += s.IdleConns | |||
} | |||
return &acc | |||
} | |||
// Len returns the current number of shards in the ring.
func (c *Ring) Len() int {
	return c.shards.Len()
}
// Subscribe subscribes the client to the specified channels.
// The shard is chosen by the first channel name; at least one channel
// is required.
func (c *Ring) Subscribe(ctx context.Context, channels ...string) *PubSub {
	if len(channels) == 0 {
		panic("at least one channel is required")
	}
	shard, err := c.shards.GetByKey(channels[0])
	if err != nil {
		// TODO: return PubSub with sticky error
		panic(err)
	}
	return shard.Client.Subscribe(ctx, channels...)
}
// PSubscribe subscribes the client to the given patterns. | |||
func (c *Ring) PSubscribe(ctx context.Context, channels ...string) *PubSub { | |||
if len(channels) == 0 { | |||
panic("at least one channel is required") | |||
} | |||
shard, err := c.shards.GetByKey(channels[0]) | |||
if err != nil { | |||
// TODO: return PubSub with sticky error | |||
panic(err) | |||
} | |||
return shard.Client.PSubscribe(ctx, channels...) | |||
} | |||
// ForEachShard concurrently calls the fn on each live shard in the ring. | |||
// It returns the first error if any. | |||
func (c *Ring) ForEachShard( | |||
ctx context.Context, | |||
fn func(ctx context.Context, client *Client) error, | |||
) error { | |||
shards := c.shards.List() | |||
var wg sync.WaitGroup | |||
errCh := make(chan error, 1) | |||
for _, shard := range shards { | |||
if shard.IsDown() { | |||
continue | |||
} | |||
wg.Add(1) | |||
go func(shard *ringShard) { | |||
defer wg.Done() | |||
err := fn(ctx, shard.Client) | |||
if err != nil { | |||
select { | |||
case errCh <- err: | |||
default: | |||
} | |||
} | |||
}(shard) | |||
} | |||
wg.Wait() | |||
select { | |||
case err := <-errCh: | |||
return err | |||
default: | |||
return nil | |||
} | |||
} | |||
func (c *Ring) cmdsInfo(ctx context.Context) (map[string]*CommandInfo, error) { | |||
shards := c.shards.List() | |||
var firstErr error | |||
for _, shard := range shards { | |||
cmdsInfo, err := shard.Client.Command(ctx).Result() | |||
if err == nil { | |||
return cmdsInfo, nil | |||
} | |||
if firstErr == nil { | |||
firstErr = err | |||
} | |||
} | |||
if firstErr == nil { | |||
return nil, errRingShardsDown | |||
} | |||
return nil, firstErr | |||
} | |||
// cmdInfo returns cached COMMAND metadata for name, or nil when the cache
// cannot be populated or the command is unknown.
func (c *Ring) cmdInfo(ctx context.Context, name string) *CommandInfo {
	cmdsInfo, err := c.cmdsInfoCache.Get(ctx)
	if err != nil {
		return nil
	}
	info := cmdsInfo[name]
	if info == nil {
		internal.Logger.Printf(c.Context(), "info for cmd=%s not found", name)
	}
	return info
}
// cmdShard picks the shard that should run cmd: by the command's first
// key when it has one, otherwise a random shard.
func (c *Ring) cmdShard(ctx context.Context, cmd Cmder) (*ringShard, error) {
	cmdInfo := c.cmdInfo(ctx, cmd.Name())
	pos := cmdFirstKeyPos(cmd, cmdInfo)
	if pos == 0 {
		// Keyless command (e.g. PING) - any shard will do.
		return c.shards.Random()
	}
	firstKey := cmd.stringArg(pos)
	return c.shards.GetByKey(firstKey)
}
// process executes cmd on its shard, retrying retryable errors up to
// MaxRetries times with backoff between attempts.
func (c *Ring) process(ctx context.Context, cmd Cmder) error {
	var lastErr error
	for attempt := 0; attempt <= c.opt.MaxRetries; attempt++ {
		if attempt > 0 {
			// Sleep honors ctx cancellation while backing off.
			if err := internal.Sleep(ctx, c.retryBackoff(attempt)); err != nil {
				return err
			}
		}
		// Re-resolve the shard each attempt; ring topology may have changed.
		shard, err := c.cmdShard(ctx, cmd)
		if err != nil {
			return err
		}
		lastErr = shard.Client.Process(ctx, cmd)
		if lastErr == nil || !shouldRetry(lastErr, cmd.readTimeout() == nil) {
			return lastErr
		}
	}
	return lastErr
}
// Pipelined executes fn against a one-shot pipeline and returns the
// queued commands together with the first error, if any.
func (c *Ring) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}
// Pipeline returns a pipeline that batches commands and distributes them
// across the ring's shards on Exec.
func (c *Ring) Pipeline() Pipeliner {
	pipe := Pipeline{
		ctx: c.ctx,
		exec: c.processPipeline,
	}
	pipe.init()
	return &pipe
}
// processPipeline runs cmds through the pipeline hooks in non-transactional mode.
func (c *Ring) processPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
		return c.generalProcessPipeline(ctx, cmds, false)
	})
}
// TxPipelined executes fn against a one-shot MULTI/EXEC pipeline.
func (c *Ring) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}
// TxPipeline returns a pipeline that wraps the queued commands in
// MULTI/EXEC per shard on Exec.
func (c *Ring) TxPipeline() Pipeliner {
	pipe := Pipeline{
		ctx: c.ctx,
		exec: c.processTxPipeline,
	}
	pipe.init()
	return &pipe
}
// processTxPipeline runs cmds through the pipeline hooks in transactional mode.
func (c *Ring) processTxPipeline(ctx context.Context, cmds []Cmder) error {
	return c.hooks.processPipeline(ctx, cmds, func(ctx context.Context, cmds []Cmder) error {
		return c.generalProcessPipeline(ctx, cmds, true)
	})
}
// generalProcessPipeline groups cmds by shard hash and executes each
// group concurrently, as a MULTI/EXEC transaction when tx is set.
// Per-command errors are recorded on the commands; the first one is returned.
func (c *Ring) generalProcessPipeline(
	ctx context.Context, cmds []Cmder, tx bool,
) error {
	cmdsMap := make(map[string][]Cmder)
	for _, cmd := range cmds {
		cmdInfo := c.cmdInfo(ctx, cmd.Name())
		hash := cmd.stringArg(cmdFirstKeyPos(cmd, cmdInfo))
		if hash != "" {
			hash = c.shards.Hash(hash)
		}
		// Keyless commands all land in the "" group.
		cmdsMap[hash] = append(cmdsMap[hash], cmd)
	}
	var wg sync.WaitGroup
	for hash, cmds := range cmdsMap {
		wg.Add(1)
		go func(hash string, cmds []Cmder) {
			defer wg.Done()
			// Error is ignored here: it has already been set on the cmds.
			_ = c.processShardPipeline(ctx, hash, cmds, tx)
		}(hash, cmds)
	}
	wg.Wait()
	return cmdsFirstErr(cmds)
}
// processShardPipeline executes cmds on the shard identified by hash.
func (c *Ring) processShardPipeline(
	ctx context.Context, hash string, cmds []Cmder, tx bool,
) error {
	// TODO: retry?
	shard, err := c.shards.GetByName(hash)
	if err != nil {
		setCmdsErr(cmds, err)
		return err
	}
	if tx {
		return shard.Client.processTxPipeline(ctx, cmds)
	}
	return shard.Client.processPipeline(ctx, cmds)
}
// Watch runs fn inside a transaction that WATCHes keys. All non-empty
// keys must hash to the same shard, otherwise an error is returned.
func (c *Ring) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
	if len(keys) == 0 {
		return fmt.Errorf("redis: Watch requires at least one key")
	}
	var shards []*ringShard
	for _, key := range keys {
		if key != "" {
			// hashtag.Key honors {hash tags} so related keys co-locate.
			shard, err := c.shards.GetByKey(hashtag.Key(key))
			if err != nil {
				return err
			}
			shards = append(shards, shard)
		}
	}
	if len(shards) == 0 {
		return fmt.Errorf("redis: Watch requires at least one shard")
	}
	if len(shards) > 1 {
		// WATCH/MULTI/EXEC must run on a single connection, so every key
		// has to resolve to the same shard client.
		for _, shard := range shards[1:] {
			if shard.Client != shards[0].Client {
				err := fmt.Errorf("redis: Watch requires all keys to be in the same shard")
				return err
			}
		}
	}
	return shards[0].Client.Watch(ctx, fn, keys...)
}
// Close closes the ring client, releasing any open resources.
//
// It is rare to Close a Ring, as the Ring is meant to be long-lived
// and shared between many goroutines.
func (c *Ring) Close() error {
	return c.shards.Close()
}
@@ -0,0 +1,65 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"crypto/sha1" | |||
"encoding/hex" | |||
"io" | |||
"strings" | |||
) | |||
// Scripter is the minimal client surface needed to run Lua scripts.
// It is implemented by Client, Ring and ClusterClient.
type Scripter interface {
	Eval(ctx context.Context, script string, keys []string, args ...interface{}) *Cmd
	EvalSha(ctx context.Context, sha1 string, keys []string, args ...interface{}) *Cmd
	ScriptExists(ctx context.Context, hashes ...string) *BoolSliceCmd
	ScriptLoad(ctx context.Context, script string) *StringCmd
}
// Compile-time assertions that all client types satisfy Scripter.
var (
	_ Scripter = (*Client)(nil)
	_ Scripter = (*Ring)(nil)
	_ Scripter = (*ClusterClient)(nil)
)
// Script bundles a Lua script source with its pre-computed SHA1 digest,
// allowing it to be executed via EVALSHA with an EVAL fallback.
type Script struct {
	src, hash string
}
// NewScript wraps src, computing its SHA1 digest eagerly.
func NewScript(src string) *Script {
	digest := sha1.New()
	_, _ = io.WriteString(digest, src) // sha1 writes never fail
	return &Script{src: src, hash: hex.EncodeToString(digest.Sum(nil))}
}
// Hash returns the hex-encoded SHA1 digest of the script source.
func (s *Script) Hash() string {
	return s.hash
}
// Load sends the script source to the server's script cache (SCRIPT LOAD).
func (s *Script) Load(ctx context.Context, c Scripter) *StringCmd {
	return c.ScriptLoad(ctx, s.src)
}
// Exists reports whether the script is present in the server's cache.
func (s *Script) Exists(ctx context.Context, c Scripter) *BoolSliceCmd {
	return c.ScriptExists(ctx, s.hash)
}
// Eval runs the script by sending its full source (EVAL).
func (s *Script) Eval(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd {
	return c.Eval(ctx, s.src, keys, args...)
}
// EvalSha runs the script by its SHA1 digest (EVALSHA).
func (s *Script) EvalSha(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd {
	return c.EvalSha(ctx, s.hash, keys, args...)
}
// Run optimistically uses EVALSHA to run the script. If script does not exist
// it is retried using EVAL.
func (s *Script) Run(ctx context.Context, c Scripter, keys []string, args ...interface{}) *Cmd {
	r := s.EvalSha(ctx, c, keys, args...)
	// The server prefixes its error with "NOSCRIPT " when the digest is unknown.
	if err := r.Err(); err != nil && strings.HasPrefix(err.Error(), "NOSCRIPT ") {
		return s.Eval(ctx, c, keys, args...)
	}
	return r
}
@@ -0,0 +1,796 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"crypto/tls" | |||
"errors" | |||
"net" | |||
"strings" | |||
"sync" | |||
"time" | |||
"github.com/go-redis/redis/v8/internal" | |||
"github.com/go-redis/redis/v8/internal/pool" | |||
"github.com/go-redis/redis/v8/internal/rand" | |||
) | |||
//------------------------------------------------------------------------------ | |||
// FailoverOptions are used to configure a failover client and should
// be passed to NewFailoverClient.
type FailoverOptions struct {
	// The master name.
	MasterName string
	// A seed list of host:port addresses of sentinel nodes.
	SentinelAddrs []string
	// If specified with SentinelPassword, enables ACL-based authentication (via
	// AUTH <user> <pass>).
	SentinelUsername string
	// Sentinel password from "requirepass <password>" (if enabled) in Sentinel
	// configuration, or, if SentinelUsername is also supplied, used for ACL-based
	// authentication.
	SentinelPassword string
	// Allows routing read-only commands to the closest master or slave node.
	// This option only works with NewFailoverClusterClient.
	RouteByLatency bool
	// Allows routing read-only commands to the random master or slave node.
	// This option only works with NewFailoverClusterClient.
	RouteRandomly bool
	// Route all commands to slave read-only nodes.
	SlaveOnly bool
	// Use slaves disconnected with master when cannot get connected slaves
	// Now, this option only works in RandomSlaveAddr function.
	UseDisconnectedSlaves bool
	// Following options are copied from Options struct.
	// They are forwarded unchanged by clientOptions/sentinelOptions/
	// clusterOptions; see the Options struct for their semantics.
	Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
	OnConnect func(ctx context.Context, cn *Conn) error
	Username string
	Password string
	DB int
	MaxRetries int
	MinRetryBackoff time.Duration
	MaxRetryBackoff time.Duration
	DialTimeout time.Duration
	ReadTimeout time.Duration
	WriteTimeout time.Duration
	// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
	PoolFIFO bool
	PoolSize int
	MinIdleConns int
	MaxConnAge time.Duration
	PoolTimeout time.Duration
	IdleTimeout time.Duration
	IdleCheckFrequency time.Duration
	TLSConfig *tls.Config
}
// clientOptions converts FailoverOptions into Options for the client
// that talks to the current master. Addr is a placeholder: the real
// address is resolved per-connection by the failover dialer.
func (opt *FailoverOptions) clientOptions() *Options {
	return &Options{
		Addr: "FailoverClient",
		Dialer: opt.Dialer,
		OnConnect: opt.OnConnect,
		DB: opt.DB,
		Username: opt.Username,
		Password: opt.Password,
		MaxRetries: opt.MaxRetries,
		MinRetryBackoff: opt.MinRetryBackoff,
		MaxRetryBackoff: opt.MaxRetryBackoff,
		DialTimeout: opt.DialTimeout,
		ReadTimeout: opt.ReadTimeout,
		WriteTimeout: opt.WriteTimeout,
		PoolFIFO: opt.PoolFIFO,
		PoolSize: opt.PoolSize,
		PoolTimeout: opt.PoolTimeout,
		IdleTimeout: opt.IdleTimeout,
		IdleCheckFrequency: opt.IdleCheckFrequency,
		MinIdleConns: opt.MinIdleConns,
		MaxConnAge: opt.MaxConnAge,
		TLSConfig: opt.TLSConfig,
	}
}
// sentinelOptions converts FailoverOptions into Options for connecting
// to the sentinel node at addr. Sentinels always use DB 0 and the
// sentinel-specific credentials.
func (opt *FailoverOptions) sentinelOptions(addr string) *Options {
	return &Options{
		Addr: addr,
		Dialer: opt.Dialer,
		OnConnect: opt.OnConnect,
		DB: 0,
		Username: opt.SentinelUsername,
		Password: opt.SentinelPassword,
		MaxRetries: opt.MaxRetries,
		MinRetryBackoff: opt.MinRetryBackoff,
		MaxRetryBackoff: opt.MaxRetryBackoff,
		DialTimeout: opt.DialTimeout,
		ReadTimeout: opt.ReadTimeout,
		WriteTimeout: opt.WriteTimeout,
		PoolFIFO: opt.PoolFIFO,
		PoolSize: opt.PoolSize,
		PoolTimeout: opt.PoolTimeout,
		IdleTimeout: opt.IdleTimeout,
		IdleCheckFrequency: opt.IdleCheckFrequency,
		MinIdleConns: opt.MinIdleConns,
		MaxConnAge: opt.MaxConnAge,
		TLSConfig: opt.TLSConfig,
	}
}
// clusterOptions converts FailoverOptions into ClusterOptions for
// NewFailoverClusterClient. Note MaxRetries is mapped onto MaxRedirects.
func (opt *FailoverOptions) clusterOptions() *ClusterOptions {
	return &ClusterOptions{
		Dialer: opt.Dialer,
		OnConnect: opt.OnConnect,
		Username: opt.Username,
		Password: opt.Password,
		MaxRedirects: opt.MaxRetries,
		RouteByLatency: opt.RouteByLatency,
		RouteRandomly: opt.RouteRandomly,
		MinRetryBackoff: opt.MinRetryBackoff,
		MaxRetryBackoff: opt.MaxRetryBackoff,
		DialTimeout: opt.DialTimeout,
		ReadTimeout: opt.ReadTimeout,
		WriteTimeout: opt.WriteTimeout,
		PoolFIFO: opt.PoolFIFO,
		PoolSize: opt.PoolSize,
		PoolTimeout: opt.PoolTimeout,
		IdleTimeout: opt.IdleTimeout,
		IdleCheckFrequency: opt.IdleCheckFrequency,
		MinIdleConns: opt.MinIdleConns,
		MaxConnAge: opt.MaxConnAge,
		TLSConfig: opt.TLSConfig,
	}
}
// NewFailoverClient returns a Redis client that uses Redis Sentinel
// for automatic failover. It's safe for concurrent use by multiple
// goroutines.
func NewFailoverClient(failoverOpt *FailoverOptions) *Client {
	if failoverOpt.RouteByLatency {
		panic("to route commands by latency, use NewFailoverClusterClient")
	}
	if failoverOpt.RouteRandomly {
		panic("to route commands randomly, use NewFailoverClusterClient")
	}
	// Copy and shuffle the seed addresses so different client instances
	// do not all hit the same sentinel first.
	sentinelAddrs := make([]string, len(failoverOpt.SentinelAddrs))
	copy(sentinelAddrs, failoverOpt.SentinelAddrs)
	rand.Shuffle(len(sentinelAddrs), func(i, j int) {
		sentinelAddrs[i], sentinelAddrs[j] = sentinelAddrs[j], sentinelAddrs[i]
	})
	failover := &sentinelFailover{
		opt: failoverOpt,
		sentinelAddrs: sentinelAddrs,
	}
	opt := failoverOpt.clientOptions()
	// Every new connection resolves the current master/slave via Sentinel.
	opt.Dialer = masterSlaveDialer(failover)
	opt.init()
	connPool := newConnPool(opt)
	failover.mu.Lock()
	// On master change, drop pooled connections to the old master.
	failover.onFailover = func(ctx context.Context, addr string) {
		_ = connPool.Filter(func(cn *pool.Conn) bool {
			return cn.RemoteAddr().String() != addr
		})
	}
	failover.mu.Unlock()
	c := Client{
		baseClient: newBaseClient(opt, connPool),
		ctx: context.Background(),
	}
	c.cmdable = c.Process
	c.onClose = failover.Close
	return &c
}
// masterSlaveDialer returns a dialer that ignores the configured address
// and instead asks Sentinel for the current master (or a random slave
// when SlaveOnly is set) on every new connection.
func masterSlaveDialer(
	failover *sentinelFailover,
) func(ctx context.Context, network, addr string) (net.Conn, error) {
	return func(ctx context.Context, network, _ string) (net.Conn, error) {
		var addr string
		var err error
		if failover.opt.SlaveOnly {
			addr, err = failover.RandomSlaveAddr(ctx)
		} else {
			addr, err = failover.MasterAddr(ctx)
			if err == nil {
				// Detect master changes as a side effect of dialing.
				failover.trySwitchMaster(ctx, addr)
			}
		}
		if err != nil {
			return nil, err
		}
		// Prefer the user-supplied dialer when one is configured.
		if failover.opt.Dialer != nil {
			return failover.opt.Dialer(ctx, network, addr)
		}
		netDialer := &net.Dialer{
			Timeout: failover.opt.DialTimeout,
			KeepAlive: 5 * time.Minute,
		}
		if failover.opt.TLSConfig == nil {
			return netDialer.DialContext(ctx, network, addr)
		}
		return tls.DialWithDialer(netDialer, network, addr, failover.opt.TLSConfig)
	}
}
//------------------------------------------------------------------------------
// SentinelClient is a client for a Redis Sentinel.
type SentinelClient struct {
	*baseClient
	hooks
	ctx context.Context
}
// NewSentinelClient returns a client connected to a single Sentinel node.
func NewSentinelClient(opt *Options) *SentinelClient {
	opt.init()
	c := &SentinelClient{
		baseClient: &baseClient{
			opt: opt,
			connPool: newConnPool(opt),
		},
		ctx: context.Background(),
	}
	return c
}
// Context returns the context the client was created with.
func (c *SentinelClient) Context() context.Context {
	return c.ctx
}
// WithContext returns a shallow copy of the client bound to ctx.
func (c *SentinelClient) WithContext(ctx context.Context) *SentinelClient {
	if ctx == nil {
		panic("nil context")
	}
	clone := *c
	clone.ctx = ctx
	return &clone
}
// Process runs cmd through the hook chain and the underlying connection.
func (c *SentinelClient) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.baseClient.process)
}
// pubSub creates a PubSub backed by a dedicated connection.
func (c *SentinelClient) pubSub() *PubSub {
	pubsub := &PubSub{
		opt: c.opt,
		newConn: func(ctx context.Context, channels []string) (*pool.Conn, error) {
			return c.newConn(ctx)
		},
		closeConn: c.connPool.CloseConn,
	}
	pubsub.init()
	return pubsub
}
// Ping is used to test if a connection is still alive, or to
// measure latency.
func (c *SentinelClient) Ping(ctx context.Context) *StringCmd {
	cmd := NewStringCmd(ctx, "ping")
	_ = c.Process(ctx, cmd)
	return cmd
}
// Subscribe subscribes the client to the specified channels.
// Channels can be omitted to create empty subscription.
func (c *SentinelClient) Subscribe(ctx context.Context, channels ...string) *PubSub {
	pubsub := c.pubSub()
	if len(channels) > 0 {
		_ = pubsub.Subscribe(ctx, channels...)
	}
	return pubsub
}
// PSubscribe subscribes the client to the given patterns.
// Patterns can be omitted to create empty subscription.
func (c *SentinelClient) PSubscribe(ctx context.Context, channels ...string) *PubSub {
	pubsub := c.pubSub()
	if len(channels) > 0 {
		_ = pubsub.PSubscribe(ctx, channels...)
	}
	return pubsub
}
// GetMasterAddrByName issues SENTINEL GET-MASTER-ADDR-BY-NAME, asking for
// the address of the master with the given name.
func (c *SentinelClient) GetMasterAddrByName(ctx context.Context, name string) *StringSliceCmd {
	cmd := NewStringSliceCmd(ctx, "sentinel", "get-master-addr-by-name", name)
	_ = c.Process(ctx, cmd)
	return cmd
}
// Sentinels issues SENTINEL SENTINELS, listing the other Sentinels that
// monitor the named master.
func (c *SentinelClient) Sentinels(ctx context.Context, name string) *SliceCmd {
	cmd := NewSliceCmd(ctx, "sentinel", "sentinels", name)
	_ = c.Process(ctx, cmd)
	return cmd
}
// Failover forces a failover as if the master was not reachable, and without
// asking for agreement to other Sentinels.
func (c *SentinelClient) Failover(ctx context.Context, name string) *StatusCmd {
	cmd := NewStatusCmd(ctx, "sentinel", "failover", name)
	_ = c.Process(ctx, cmd)
	return cmd
}
// Reset resets all the masters with matching name. The pattern argument is a
// glob-style pattern. The reset process clears any previous state in a master
// (including a failover in progress), and removes every slave and sentinel
// already discovered and associated with the master.
func (c *SentinelClient) Reset(ctx context.Context, pattern string) *IntCmd {
	cmd := NewIntCmd(ctx, "sentinel", "reset", pattern)
	_ = c.Process(ctx, cmd)
	return cmd
}
// FlushConfig forces Sentinel to rewrite its configuration on disk, including
// the current Sentinel state.
func (c *SentinelClient) FlushConfig(ctx context.Context) *StatusCmd {
	cmd := NewStatusCmd(ctx, "sentinel", "flushconfig")
	_ = c.Process(ctx, cmd)
	return cmd
}
// Master shows the state and info of the specified master.
func (c *SentinelClient) Master(ctx context.Context, name string) *StringStringMapCmd {
	cmd := NewStringStringMapCmd(ctx, "sentinel", "master", name)
	_ = c.Process(ctx, cmd)
	return cmd
}
// Masters shows a list of monitored masters and their state.
func (c *SentinelClient) Masters(ctx context.Context) *SliceCmd {
	cmd := NewSliceCmd(ctx, "sentinel", "masters")
	_ = c.Process(ctx, cmd)
	return cmd
}
// Slaves shows a list of slaves for the specified master and their state.
func (c *SentinelClient) Slaves(ctx context.Context, name string) *SliceCmd {
	cmd := NewSliceCmd(ctx, "sentinel", "slaves", name)
	_ = c.Process(ctx, cmd)
	return cmd
}
// CkQuorum checks if the current Sentinel configuration is able to reach the
// quorum needed to failover a master, and the majority needed to authorize the
// failover. This command should be used in monitoring systems to check if a
// Sentinel deployment is ok.
func (c *SentinelClient) CkQuorum(ctx context.Context, name string) *StringCmd {
	cmd := NewStringCmd(ctx, "sentinel", "ckquorum", name)
	_ = c.Process(ctx, cmd)
	return cmd
}
// Monitor tells the Sentinel to start monitoring a new master with the specified
// name, ip, port, and quorum.
func (c *SentinelClient) Monitor(ctx context.Context, name, ip, port, quorum string) *StringCmd {
	cmd := NewStringCmd(ctx, "sentinel", "monitor", name, ip, port, quorum)
	_ = c.Process(ctx, cmd)
	return cmd
}
// Set is used in order to change configuration parameters of a specific master.
func (c *SentinelClient) Set(ctx context.Context, name, option, value string) *StringCmd {
	cmd := NewStringCmd(ctx, "sentinel", "set", name, option, value)
	_ = c.Process(ctx, cmd)
	return cmd
}
// Remove is used in order to remove the specified master: the master will no
// longer be monitored, and will totally be removed from the internal state of
// the Sentinel.
func (c *SentinelClient) Remove(ctx context.Context, name string) *StringCmd {
	cmd := NewStringCmd(ctx, "sentinel", "remove", name)
	_ = c.Process(ctx, cmd)
	return cmd
}
//------------------------------------------------------------------------------
// sentinelFailover tracks the current master address for a monitored
// master name. It keeps one working sentinel connection plus a pub/sub
// subscription to failover events.
type sentinelFailover struct {
	opt *FailoverOptions
	sentinelAddrs []string
	// onFailover is invoked when the master address changes.
	onFailover func(ctx context.Context, addr string)
	// onUpdate is invoked on sentinel events (used to reload cluster state).
	onUpdate func(ctx context.Context)
	// mu guards all fields below.
	mu sync.RWMutex
	_masterAddr string
	sentinel *SentinelClient
	pubsub *PubSub
}
// Close releases the cached sentinel connection and its subscription.
func (c *sentinelFailover) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.sentinel != nil {
		return c.closeSentinel()
	}
	return nil
}
// closeSentinel closes pubsub and the sentinel client, returning the
// first error. Callers must hold c.mu.
func (c *sentinelFailover) closeSentinel() error {
	firstErr := c.pubsub.Close()
	c.pubsub = nil
	err := c.sentinel.Close()
	if err != nil && firstErr == nil {
		firstErr = err
	}
	c.sentinel = nil
	return firstErr
}
// RandomSlaveAddr returns the address of a random connected slave,
// optionally falling back to disconnected slaves, and finally to the
// master when no slave is available.
func (c *sentinelFailover) RandomSlaveAddr(ctx context.Context) (string, error) {
	if c.opt == nil {
		return "", errors.New("opt is nil")
	}
	addresses, err := c.slaveAddrs(ctx, false)
	if err != nil {
		return "", err
	}
	if len(addresses) == 0 && c.opt.UseDisconnectedSlaves {
		addresses, err = c.slaveAddrs(ctx, true)
		if err != nil {
			return "", err
		}
	}
	if len(addresses) == 0 {
		return c.MasterAddr(ctx)
	}
	return addresses[rand.Intn(len(addresses))], nil
}
// MasterAddr asks the cached sentinel for the current master address,
// and on failure probes every configured sentinel until one answers.
func (c *sentinelFailover) MasterAddr(ctx context.Context) (string, error) {
	// Fast path: query the cached sentinel under a read lock.
	c.mu.RLock()
	sentinel := c.sentinel
	c.mu.RUnlock()
	if sentinel != nil {
		addr := c.getMasterAddr(ctx, sentinel)
		if addr != "" {
			return addr, nil
		}
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	// Re-check under the write lock: another goroutine may have already
	// replaced the sentinel while we were waiting.
	if c.sentinel != nil {
		addr := c.getMasterAddr(ctx, c.sentinel)
		if addr != "" {
			return addr, nil
		}
		_ = c.closeSentinel()
	}
	for i, sentinelAddr := range c.sentinelAddrs {
		sentinel := NewSentinelClient(c.opt.sentinelOptions(sentinelAddr))
		masterAddr, err := sentinel.GetMasterAddrByName(ctx, c.opt.MasterName).Result()
		if err != nil {
			internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName master=%q failed: %s",
				c.opt.MasterName, err)
			_ = sentinel.Close()
			continue
		}
		// Push working sentinel to the top.
		c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0]
		c.setSentinel(ctx, sentinel)
		addr := net.JoinHostPort(masterAddr[0], masterAddr[1])
		return addr, nil
	}
	return "", errors.New("redis: all sentinels specified in configuration are unreachable")
}
// slaveAddrs returns usable slave addresses, trying the cached sentinel
// first and then probing every configured sentinel. A nil error with an
// empty slice means some sentinel was reachable but reported no usable slaves.
func (c *sentinelFailover) slaveAddrs(ctx context.Context, useDisconnected bool) ([]string, error) {
	c.mu.RLock()
	sentinel := c.sentinel
	c.mu.RUnlock()
	if sentinel != nil {
		addrs := c.getSlaveAddrs(ctx, sentinel)
		if len(addrs) > 0 {
			return addrs, nil
		}
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	// Re-check under the write lock before discarding the cached sentinel.
	if c.sentinel != nil {
		addrs := c.getSlaveAddrs(ctx, c.sentinel)
		if len(addrs) > 0 {
			return addrs, nil
		}
		_ = c.closeSentinel()
	}
	var sentinelReachable bool
	for i, sentinelAddr := range c.sentinelAddrs {
		sentinel := NewSentinelClient(c.opt.sentinelOptions(sentinelAddr))
		slaves, err := sentinel.Slaves(ctx, c.opt.MasterName).Result()
		if err != nil {
			internal.Logger.Printf(ctx, "sentinel: Slaves master=%q failed: %s",
				c.opt.MasterName, err)
			_ = sentinel.Close()
			continue
		}
		sentinelReachable = true
		addrs := parseSlaveAddrs(slaves, useDisconnected)
		if len(addrs) == 0 {
			continue
		}
		// Push working sentinel to the top.
		c.sentinelAddrs[0], c.sentinelAddrs[i] = c.sentinelAddrs[i], c.sentinelAddrs[0]
		c.setSentinel(ctx, sentinel)
		return addrs, nil
	}
	if sentinelReachable {
		return []string{}, nil
	}
	return []string{}, errors.New("redis: all sentinels specified in configuration are unreachable")
}
// getMasterAddr asks a specific sentinel for the current master address,
// returning "" on failure.
func (c *sentinelFailover) getMasterAddr(ctx context.Context, sentinel *SentinelClient) string {
	addr, err := sentinel.GetMasterAddrByName(ctx, c.opt.MasterName).Result()
	if err != nil {
		internal.Logger.Printf(ctx, "sentinel: GetMasterAddrByName name=%q failed: %s",
			c.opt.MasterName, err)
		return ""
	}
	return net.JoinHostPort(addr[0], addr[1])
}
// getSlaveAddrs asks a specific sentinel for connected slave addresses,
// returning an empty slice on failure.
func (c *sentinelFailover) getSlaveAddrs(ctx context.Context, sentinel *SentinelClient) []string {
	addrs, err := sentinel.Slaves(ctx, c.opt.MasterName).Result()
	if err != nil {
		internal.Logger.Printf(ctx, "sentinel: Slaves name=%q failed: %s",
			c.opt.MasterName, err)
		return []string{}
	}
	return parseSlaveAddrs(addrs, false)
}
// parseSlaveAddrs converts a raw SENTINEL SLAVES reply into "host:port"
// strings. Nodes flagged s_down or o_down are always dropped; nodes
// flagged disconnected are dropped unless keepDisconnected is set.
func parseSlaveAddrs(addrs []interface{}, keepDisconnected bool) []string {
	out := make([]string, 0, len(addrs))
	for _, raw := range addrs {
		var (
			ip, port string
			flags    []string
			prev     string
		)
		// Each node is a flat list of alternating field names and values;
		// remember the previous token to know which field a value fills.
		for _, field := range raw.([]interface{}) {
			s := field.(string)
			switch prev {
			case "ip":
				ip = s
			case "port":
				port = s
			case "flags":
				flags = strings.Split(s, ",")
			}
			prev = s
		}
		down := false
		for _, flag := range flags {
			if flag == "s_down" || flag == "o_down" {
				down = true
			} else if flag == "disconnected" && !keepDisconnected {
				down = true
			}
		}
		if !down {
			out = append(out, net.JoinHostPort(ip, port))
		}
	}
	return out
}
// trySwitchMaster records addr as the current master and fires the
// onFailover callback, doing nothing when the address is unchanged.
func (c *sentinelFailover) trySwitchMaster(ctx context.Context, addr string) {
	c.mu.RLock()
	currentAddr := c._masterAddr //nolint:ifshort
	c.mu.RUnlock()
	if addr == currentAddr {
		return
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	// Double-check after upgrading to the write lock.
	if addr == c._masterAddr {
		return
	}
	c._masterAddr = addr
	internal.Logger.Printf(ctx, "sentinel: new master=%q addr=%q",
		c.opt.MasterName, addr)
	if c.onFailover != nil {
		c.onFailover(ctx, addr)
	}
}
// setSentinel caches sentinel as the working sentinel, discovers its
// peers and subscribes to failover events. Callers must hold c.mu.
func (c *sentinelFailover) setSentinel(ctx context.Context, sentinel *SentinelClient) {
	if c.sentinel != nil {
		panic("not reached")
	}
	c.sentinel = sentinel
	c.discoverSentinels(ctx)
	c.pubsub = sentinel.Subscribe(ctx, "+switch-master", "+slave-reconf-done")
	go c.listen(c.pubsub)
}
// discoverSentinels appends previously-unknown sentinels monitoring the
// same master to c.sentinelAddrs.
func (c *sentinelFailover) discoverSentinels(ctx context.Context) {
	sentinels, err := c.sentinel.Sentinels(ctx, c.opt.MasterName).Result()
	if err != nil {
		internal.Logger.Printf(ctx, "sentinel: Sentinels master=%q failed: %s", c.opt.MasterName, err)
		return
	}
	for _, sentinel := range sentinels {
		vals := sentinel.([]interface{})
		var ip, port string
		// The reply is a flat list of alternating field names and values.
		for i := 0; i < len(vals); i += 2 {
			key := vals[i].(string)
			switch key {
			case "ip":
				ip = vals[i+1].(string)
			case "port":
				port = vals[i+1].(string)
			}
		}
		if ip != "" && port != "" {
			sentinelAddr := net.JoinHostPort(ip, port)
			if !contains(c.sentinelAddrs, sentinelAddr) {
				internal.Logger.Printf(ctx, "sentinel: discovered new sentinel=%q for master=%q",
					sentinelAddr, c.opt.MasterName)
				c.sentinelAddrs = append(c.sentinelAddrs, sentinelAddr)
			}
		}
	}
}
func (c *sentinelFailover) listen(pubsub *PubSub) { | |||
ctx := context.TODO() | |||
if c.onUpdate != nil { | |||
c.onUpdate(ctx) | |||
} | |||
ch := pubsub.Channel() | |||
for msg := range ch { | |||
if msg.Channel == "+switch-master" { | |||
parts := strings.Split(msg.Payload, " ") | |||
if parts[0] != c.opt.MasterName { | |||
internal.Logger.Printf(pubsub.getContext(), "sentinel: ignore addr for master=%q", parts[0]) | |||
continue | |||
} | |||
addr := net.JoinHostPort(parts[3], parts[4]) | |||
c.trySwitchMaster(pubsub.getContext(), addr) | |||
} | |||
if c.onUpdate != nil { | |||
c.onUpdate(ctx) | |||
} | |||
} | |||
} | |||
// contains reports whether str is an element of slice.
func contains(slice []string, str string) bool {
	for i := range slice {
		if slice[i] == str {
			return true
		}
	}
	return false
}
//------------------------------------------------------------------------------
// NewFailoverClusterClient returns a client that supports routing read-only commands
// to a slave node.
func NewFailoverClusterClient(failoverOpt *FailoverOptions) *ClusterClient {
	sentinelAddrs := make([]string, len(failoverOpt.SentinelAddrs))
	copy(sentinelAddrs, failoverOpt.SentinelAddrs)
	failover := &sentinelFailover{
		opt: failoverOpt,
		sentinelAddrs: sentinelAddrs,
	}
	opt := failoverOpt.clusterOptions()
	// Present the master and its slaves as a single-slot "cluster" so the
	// cluster client's read-routing options (latency/random) can be reused.
	opt.ClusterSlots = func(ctx context.Context) ([]ClusterSlot, error) {
		masterAddr, err := failover.MasterAddr(ctx)
		if err != nil {
			return nil, err
		}
		// The master is always the first node.
		nodes := []ClusterNode{{
			Addr: masterAddr,
		}}
		slaveAddrs, err := failover.slaveAddrs(ctx, false)
		if err != nil {
			return nil, err
		}
		for _, slaveAddr := range slaveAddrs {
			nodes = append(nodes, ClusterNode{
				Addr: slaveAddr,
			})
		}
		slots := []ClusterSlot{
			{
				Start: 0,
				End: 16383,
				Nodes: nodes,
			},
		}
		return slots, nil
	}
	c := NewClusterClient(opt)
	failover.mu.Lock()
	// Reload cluster state whenever sentinel reports an event.
	failover.onUpdate = func(ctx context.Context) {
		c.ReloadState(ctx)
	}
	failover.mu.Unlock()
	return c
}
@@ -0,0 +1,149 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"github.com/go-redis/redis/v8/internal/pool" | |||
"github.com/go-redis/redis/v8/internal/proto" | |||
) | |||
// TxFailedErr transaction redis failed.
const TxFailedErr = proto.RedisError("redis: transaction failed")
// Tx implements Redis transactions as described in
// http://redis.io/topics/transactions. It's NOT safe for concurrent use
// by multiple goroutines, because Exec resets list of watched keys.
//
// If you don't need WATCH, use Pipeline instead.
type Tx struct {
	baseClient
	cmdable
	statefulCmdable
	hooks
	ctx context.Context
}
// newTx returns a Tx bound to a single (sticky) connection so that
// WATCH, MULTI and EXEC all run on the same connection.
func (c *Client) newTx(ctx context.Context) *Tx {
	tx := Tx{
		baseClient: baseClient{
			opt: c.opt,
			connPool: pool.NewStickyConnPool(c.connPool),
		},
		hooks: c.hooks.clone(),
		ctx: ctx,
	}
	tx.init()
	return &tx
}
// init wires the generated command methods to this Tx's Process.
func (c *Tx) init() {
	c.cmdable = c.Process
	c.statefulCmdable = c.Process
}
// Context returns the context the Tx was created with.
func (c *Tx) Context() context.Context {
	return c.ctx
}
// WithContext returns a shallow copy of the Tx bound to ctx.
func (c *Tx) WithContext(ctx context.Context) *Tx {
	if ctx == nil {
		panic("nil context")
	}
	clone := *c
	clone.init()
	clone.hooks.lock()
	clone.ctx = ctx
	return &clone
}
// Process runs cmd through the hook chain and the underlying connection.
func (c *Tx) Process(ctx context.Context, cmd Cmder) error {
	return c.hooks.process(ctx, cmd, c.baseClient.process)
}
// Watch prepares a transaction and marks the keys to be watched
// for conditional execution if there are any keys.
//
// The transaction is automatically closed when fn exits.
func (c *Client) Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error {
	tx := c.newTx(ctx)
	defer tx.Close(ctx)
	if len(keys) > 0 {
		if err := tx.Watch(ctx, keys...).Err(); err != nil {
			return err
		}
	}
	return fn(tx)
}
// Close closes the transaction, releasing any open resources.
func (c *Tx) Close(ctx context.Context) error {
	// Best-effort UNWATCH before the sticky connection is released.
	_ = c.Unwatch(ctx).Err()
	return c.baseClient.Close()
}
// Watch marks the keys to be watched for conditional execution | |||
// of a transaction. | |||
func (c *Tx) Watch(ctx context.Context, keys ...string) *StatusCmd { | |||
args := make([]interface{}, 1+len(keys)) | |||
args[0] = "watch" | |||
for i, key := range keys { | |||
args[1+i] = key | |||
} | |||
cmd := NewStatusCmd(ctx, args...) | |||
_ = c.Process(ctx, cmd) | |||
return cmd | |||
} | |||
// Unwatch flushes all the previously watched keys for a transaction. | |||
func (c *Tx) Unwatch(ctx context.Context, keys ...string) *StatusCmd { | |||
args := make([]interface{}, 1+len(keys)) | |||
args[0] = "unwatch" | |||
for i, key := range keys { | |||
args[1+i] = key | |||
} | |||
cmd := NewStatusCmd(ctx, args...) | |||
_ = c.Process(ctx, cmd) | |||
return cmd | |||
} | |||
// Pipeline creates a pipeline. Usually it is more convenient to use Pipelined. | |||
func (c *Tx) Pipeline() Pipeliner { | |||
pipe := Pipeline{ | |||
ctx: c.ctx, | |||
exec: func(ctx context.Context, cmds []Cmder) error { | |||
return c.hooks.processPipeline(ctx, cmds, c.baseClient.processPipeline) | |||
}, | |||
} | |||
pipe.init() | |||
return &pipe | |||
} | |||
// Pipelined executes commands queued in the fn outside of the transaction.
// Use TxPipelined if you need transactional behavior.
//
// Thin wrapper over Pipeline().
func (c *Tx) Pipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.Pipeline().Pipelined(ctx, fn)
}
// TxPipelined executes commands queued in the fn in the transaction.
//
// When using WATCH, EXEC will execute commands only if the watched keys
// were not modified, allowing for a check-and-set mechanism.
//
// Exec always returns list of commands. If transaction fails
// TxFailedErr is returned. Otherwise Exec returns an error of the first
// failed command or nil.
//
// Thin wrapper over TxPipeline().
func (c *Tx) TxPipelined(ctx context.Context, fn func(Pipeliner) error) ([]Cmder, error) {
	return c.TxPipeline().Pipelined(ctx, fn)
}
// TxPipeline creates a pipeline. Usually it is more convenient to use TxPipelined. | |||
func (c *Tx) TxPipeline() Pipeliner { | |||
pipe := Pipeline{ | |||
ctx: c.ctx, | |||
exec: func(ctx context.Context, cmds []Cmder) error { | |||
return c.hooks.processTxPipeline(ctx, cmds, c.baseClient.processTxPipeline) | |||
}, | |||
} | |||
pipe.init() | |||
return &pipe | |||
} |
@@ -0,0 +1,213 @@ | |||
package redis | |||
import ( | |||
"context" | |||
"crypto/tls" | |||
"net" | |||
"time" | |||
) | |||
// UniversalOptions information is required by UniversalClient to establish
// connections.
type UniversalOptions struct {
	// Either a single address or a seed list of host:port addresses
	// of cluster/sentinel nodes.
	Addrs []string

	// Database to be selected after connecting to the server.
	// Only single-node and failover clients.
	DB int

	// Common options.

	// Custom dial and post-connect callbacks.
	Dialer func(ctx context.Context, network, addr string) (net.Conn, error)
	OnConnect func(ctx context.Context, cn *Conn) error

	// Authentication.
	Username string
	Password string
	SentinelPassword string

	// Retries and backoff.
	MaxRetries int
	MinRetryBackoff time.Duration
	MaxRetryBackoff time.Duration

	// Timeouts.
	DialTimeout time.Duration
	ReadTimeout time.Duration
	WriteTimeout time.Duration

	// Connection pool.

	// PoolFIFO uses FIFO mode for each node connection pool GET/PUT (default LIFO).
	PoolFIFO bool
	PoolSize int
	MinIdleConns int
	MaxConnAge time.Duration
	PoolTimeout time.Duration
	IdleTimeout time.Duration
	IdleCheckFrequency time.Duration

	TLSConfig *tls.Config

	// Only cluster clients.
	MaxRedirects int
	ReadOnly bool
	RouteByLatency bool
	RouteRandomly bool

	// The sentinel master name.
	// Only failover clients.
	MasterName string
}
// Cluster returns cluster options created from the universal options.
//
// NOTE(review): when Addrs is empty this method mutates the receiver,
// writing the localhost default back into o.Addrs.
func (o *UniversalOptions) Cluster() *ClusterOptions {
	if len(o.Addrs) == 0 {
		o.Addrs = []string{"127.0.0.1:6379"}
	}

	// One-to-one copy of the shared and cluster-specific fields.
	return &ClusterOptions{
		Addrs: o.Addrs,
		Dialer: o.Dialer,
		OnConnect: o.OnConnect,
		Username: o.Username,
		Password: o.Password,
		MaxRedirects: o.MaxRedirects,
		ReadOnly: o.ReadOnly,
		RouteByLatency: o.RouteByLatency,
		RouteRandomly: o.RouteRandomly,
		MaxRetries: o.MaxRetries,
		MinRetryBackoff: o.MinRetryBackoff,
		MaxRetryBackoff: o.MaxRetryBackoff,
		DialTimeout: o.DialTimeout,
		ReadTimeout: o.ReadTimeout,
		WriteTimeout: o.WriteTimeout,
		PoolFIFO: o.PoolFIFO,
		PoolSize: o.PoolSize,
		MinIdleConns: o.MinIdleConns,
		MaxConnAge: o.MaxConnAge,
		PoolTimeout: o.PoolTimeout,
		IdleTimeout: o.IdleTimeout,
		IdleCheckFrequency: o.IdleCheckFrequency,
		TLSConfig: o.TLSConfig,
	}
}
// Failover returns failover options created from the universal options.
// Addrs is interpreted as the sentinel address list here.
//
// NOTE(review): when Addrs is empty this method mutates the receiver,
// writing the default sentinel address back into o.Addrs.
func (o *UniversalOptions) Failover() *FailoverOptions {
	if len(o.Addrs) == 0 {
		o.Addrs = []string{"127.0.0.1:26379"}
	}

	// One-to-one copy of the shared and failover-specific fields.
	return &FailoverOptions{
		SentinelAddrs: o.Addrs,
		MasterName: o.MasterName,
		Dialer: o.Dialer,
		OnConnect: o.OnConnect,
		DB: o.DB,
		Username: o.Username,
		Password: o.Password,
		SentinelPassword: o.SentinelPassword,
		MaxRetries: o.MaxRetries,
		MinRetryBackoff: o.MinRetryBackoff,
		MaxRetryBackoff: o.MaxRetryBackoff,
		DialTimeout: o.DialTimeout,
		ReadTimeout: o.ReadTimeout,
		WriteTimeout: o.WriteTimeout,
		PoolFIFO: o.PoolFIFO,
		PoolSize: o.PoolSize,
		MinIdleConns: o.MinIdleConns,
		MaxConnAge: o.MaxConnAge,
		PoolTimeout: o.PoolTimeout,
		IdleTimeout: o.IdleTimeout,
		IdleCheckFrequency: o.IdleCheckFrequency,
		TLSConfig: o.TLSConfig,
	}
}
// Simple returns basic options created from the universal options.
// Only the first address is used; unlike Cluster/Failover, the
// receiver is not mutated when Addrs is empty.
func (o *UniversalOptions) Simple() *Options {
	addr := "127.0.0.1:6379"
	if len(o.Addrs) > 0 {
		addr = o.Addrs[0]
	}

	// One-to-one copy of the shared single-node fields.
	return &Options{
		Addr: addr,
		Dialer: o.Dialer,
		OnConnect: o.OnConnect,
		DB: o.DB,
		Username: o.Username,
		Password: o.Password,
		MaxRetries: o.MaxRetries,
		MinRetryBackoff: o.MinRetryBackoff,
		MaxRetryBackoff: o.MaxRetryBackoff,
		DialTimeout: o.DialTimeout,
		ReadTimeout: o.ReadTimeout,
		WriteTimeout: o.WriteTimeout,
		PoolFIFO: o.PoolFIFO,
		PoolSize: o.PoolSize,
		MinIdleConns: o.MinIdleConns,
		MaxConnAge: o.MaxConnAge,
		PoolTimeout: o.PoolTimeout,
		IdleTimeout: o.IdleTimeout,
		IdleCheckFrequency: o.IdleCheckFrequency,
		TLSConfig: o.TLSConfig,
	}
}
// --------------------------------------------------------------------

// UniversalClient is an abstract client which - based on the provided options -
// represents either a ClusterClient, a FailoverClient, or a single-node Client.
// This can be useful for testing cluster-specific applications locally or having different
// clients in different environments.
type UniversalClient interface {
	Cmdable
	Context() context.Context
	AddHook(Hook)
	Watch(ctx context.Context, fn func(*Tx) error, keys ...string) error
	Do(ctx context.Context, args ...interface{}) *Cmd
	Process(ctx context.Context, cmd Cmder) error
	Subscribe(ctx context.Context, channels ...string) *PubSub
	PSubscribe(ctx context.Context, channels ...string) *PubSub
	Close() error
	PoolStats() *PoolStats
}
// Compile-time checks that each concrete client satisfies UniversalClient.
var (
	_ UniversalClient = (*Client)(nil)
	_ UniversalClient = (*ClusterClient)(nil)
	_ UniversalClient = (*Ring)(nil)
)
// NewUniversalClient returns a new multi client. The type of the returned client depends | |||
// on the following conditions: | |||
// | |||
// 1. If the MasterName option is specified, a sentinel-backed FailoverClient is returned. | |||
// 2. if the number of Addrs is two or more, a ClusterClient is returned. | |||
// 3. Otherwise, a single-node Client is returned. | |||
func NewUniversalClient(opts *UniversalOptions) UniversalClient { | |||
if opts.MasterName != "" { | |||
return NewFailoverClient(opts.Failover()) | |||
} else if len(opts.Addrs) > 1 { | |||
return NewClusterClient(opts.Cluster()) | |||
} | |||
return NewClient(opts.Simple()) | |||
} |
@@ -0,0 +1,6 @@ | |||
package redis | |||
// Version reports the release version of this package.
func Version() string {
	const version = "8.11.4"
	return version
}
@@ -0,0 +1,20 @@ | |||
# github.com/BurntSushi/toml v1.0.0 | |||
## explicit; go 1.16 | |||
github.com/BurntSushi/toml | |||
github.com/BurntSushi/toml/internal | |||
# github.com/cespare/xxhash/v2 v2.1.2 | |||
## explicit; go 1.11 | |||
github.com/cespare/xxhash/v2 | |||
# github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f | |||
## explicit | |||
github.com/dgryski/go-rendezvous | |||
# github.com/go-redis/redis/v8 v8.11.4 | |||
## explicit; go 1.13 | |||
github.com/go-redis/redis/v8 | |||
github.com/go-redis/redis/v8/internal | |||
github.com/go-redis/redis/v8/internal/hashtag | |||
github.com/go-redis/redis/v8/internal/hscan | |||
github.com/go-redis/redis/v8/internal/pool | |||
github.com/go-redis/redis/v8/internal/proto | |||
github.com/go-redis/redis/v8/internal/rand | |||
github.com/go-redis/redis/v8/internal/util |