diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index 4d967eec..991f50e9 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -6,8 +6,8 @@ jobs: compatibility-test: strategy: matrix: - go: [ 1.18, 1.23 ] - os: [ X64, ARM64 ] + go: [ 1.18, 1.24 ] + os: [ ubuntu-latest, ubuntu-24.04-arm, macos-latest ] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 @@ -15,7 +15,6 @@ jobs: uses: actions/setup-go@v5 with: go-version: ${{ matrix.go }} - cache: false - name: Unit Test run: go test -timeout=2m -race ./... - name: Benchmark @@ -46,7 +45,7 @@ jobs: uses: crate-ci/typos@v1.13.14 golangci-lint: - runs-on: [ Linux, X64 ] + runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - name: Set up Go diff --git a/connection_impl.go b/connection_impl.go index 3830e86f..847e5fde 100644 --- a/connection_impl.go +++ b/connection_impl.go @@ -39,10 +39,12 @@ type connection struct { locker operator *FDOperator readTimeout time.Duration + readDeadline int64 // UnixNano(). it overwrites readTimeout. 0 if not set. readTimer *time.Timer readTrigger chan error waitReadSize int64 writeTimeout time.Duration + writeDeadline int64 // UnixNano(). it overwrites writeTimeout. 0 if not set. writeTimer *time.Timer writeTrigger chan error inputBuffer *LinkBuffer @@ -87,6 +89,7 @@ func (c *connection) SetReadTimeout(timeout time.Duration) error { if timeout >= 0 { c.readTimeout = timeout } + c.readDeadline = 0 return nil } @@ -95,6 +98,38 @@ func (c *connection) SetWriteTimeout(timeout time.Duration) error { if timeout >= 0 { c.writeTimeout = timeout } + c.writeDeadline = 0 + return nil +} + +// SetDeadline implements net.Conn.SetDeadline +func (c *connection) SetDeadline(t time.Time) error { + v := int64(0) + if !t.IsZero() { + v = t.UnixNano() + } + c.readDeadline = v + c.writeDeadline = v + return nil +} + +// SetReadDeadline implements net.Conn.SetReadDeadline +func (c *connection) SetReadDeadline(t time.Time) error { + if t.IsZero() { + c.readDeadline = 0 + } else { + c.readDeadline = t.UnixNano() + } + return nil +} + +// SetWriteDeadline implements net.Conn.SetWriteDeadline +func (c *connection) SetWriteDeadline(t time.Time) error { + if t.IsZero() { + c.writeDeadline = 0 + } else { + c.writeDeadline = t.UnixNano() + } return nil } @@ -408,8 +443,14 @@ func (c *connection) waitRead(n int) (err error) { } atomic.StoreInt64(&c.waitReadSize, int64(n)) defer atomic.StoreInt64(&c.waitReadSize, 0) - if c.readTimeout > 0 { - return c.waitReadWithTimeout(n) + if dl := c.readDeadline; dl > 0 { + timeout := time.Duration(dl - time.Now().UnixNano()) + if timeout <= 0 { + return Exception(ErrReadTimeout, c.remoteAddr.String()) + } + return c.waitReadWithTimeout(n, timeout) + } else if c.readTimeout > 0 { + return c.waitReadWithTimeout(n, c.readTimeout) } // wait full n for c.inputBuffer.Len() < n { @@ -429,12 +470,11 @@ func (c *connection) waitRead(n int) (err error) { } // waitReadWithTimeout will wait full n bytes or until timeout. -func (c *connection) waitReadWithTimeout(n int) (err error) { - // set read timeout +func (c *connection) waitReadWithTimeout(n int, timeout time.Duration) (err error) { if c.readTimer == nil { - c.readTimer = time.NewTimer(c.readTimeout) + c.readTimer = time.NewTimer(timeout) } else { - c.readTimer.Reset(c.readTimeout) + c.readTimer.Reset(timeout) } for c.inputBuffer.Len() < n { @@ -501,15 +541,22 @@ func (c *connection) flush() error { } func (c *connection) waitFlush() (err error) { - if c.writeTimeout == 0 { + timeout := c.writeTimeout + if dl := c.writeDeadline; dl > 0 { + timeout = time.Duration(dl - time.Now().UnixNano()) + if timeout <= 0 { + return Exception(ErrWriteTimeout, c.remoteAddr.String()) + } + } + if timeout == 0 { return <-c.writeTrigger } // set write timeout if c.writeTimer == nil { - c.writeTimer = time.NewTimer(c.writeTimeout) + c.writeTimer = time.NewTimer(timeout) } else { - c.writeTimer.Reset(c.writeTimeout) + c.writeTimer.Reset(timeout) } select { diff --git a/connection_test.go b/connection_test.go index 80823c82..b597bc34 100644 --- a/connection_test.go +++ b/connection_test.go @@ -292,11 +292,19 @@ func writeAll(fd int, buf []byte) error { return nil } +func createTestTCPListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + MustNil(t, err) + return ln +} + // Large packet write test. The socket buffer is 2MB by default, here to verify // whether Connection.Close can be executed normally after socket output buffer is full. func TestLargeBufferWrite(t *testing.T) { - address := getTestAddress() - ln, err := createTestListener("tcp", address) + ln := createTestTCPListener(t) + defer ln.Close() + address := ln.Addr().String() + ln, err := ConvertListener(ln) MustNil(t, err) trigger := make(chan int) @@ -350,29 +358,67 @@ func TestLargeBufferWrite(t *testing.T) { trigger <- 1 } -func TestWriteTimeout(t *testing.T) { - address := getTestAddress() - ln, err := createTestListener("tcp", address) +func TestConnectionTimeout(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") MustNil(t, err) + defer ln.Close() + + const ( + bufsz = 1 << 20 + interval = 10 * time.Millisecond + ) + + calcRate := func(n int32) int32 { + v := n / int32(time.Second/interval) + if v > bufsz { + panic(v) + } + if v < 1 { + return 1 + } + return v + } + + wn := int32(1) // for each Read, must <= bufsz + setServerWriteRate := func(n int32) { + atomic.StoreInt32(&wn, calcRate(n)) + } + + rn := int32(1) // for each Write, must <= bufsz + setServerReadRate := func(n int32) { + atomic.StoreInt32(&rn, calcRate(n)) + } - interval := time.Millisecond * 100 go func() { for { conn, err := ln.Accept() - if conn == nil && err == nil { - continue - } if err != nil { return } + // set small SO_SNDBUF/SO_RCVBUF buffer for better control timeout test + tcpconn := conn.(*net.TCPConn) + tcpconn.SetReadBuffer(512) + tcpconn.SetWriteBuffer(512) go func() { - buf := make([]byte, 1024) - // slow read + buf := make([]byte, bufsz) + for { + n := atomic.LoadInt32(&rn) + _, err := conn.Read(buf[:int(n)]) + if err != nil { + conn.Close() + return + } + time.Sleep(interval) + } + }() + + go func() { + buf := make([]byte, bufsz) for { - _, err := conn.Read(buf) + n := atomic.LoadInt32(&wn) + _, err := conn.Write(buf[:int(n)]) if err != nil { - err = conn.Close() - MustNil(t, err) + conn.Close() return } time.Sleep(interval) @@ -381,26 +427,113 @@ func TestWriteTimeout(t *testing.T) { } }() - conn, err := DialConnection("tcp", address, time.Second) - MustNil(t, err) + newConn := func() Connection { + conn, err := DialConnection("tcp", ln.Addr().String(), time.Second) + MustNil(t, err) + fd := conn.(Conn).Fd() + // set small SO_SNDBUF/SO_RCVBUF buffer for better control timeout test + err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_SNDBUF, 512) + MustNil(t, err) + err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_RCVBUF, 512) + MustNil(t, err) + return conn + } - _, err = conn.Writer().Malloc(1024) - MustNil(t, err) - err = conn.Writer().Flush() - MustNil(t, err) + mallocAndFlush := func(conn Connection, sz int) error { + _, err := conn.Writer().Malloc(sz) + MustNil(t, err) + return conn.Writer().Flush() + } - _ = conn.SetWriteTimeout(time.Millisecond * 10) - _, err = conn.Writer().Malloc(1024 * 1024 * 512) - MustNil(t, err) - err = conn.Writer().Flush() - MustTrue(t, errors.Is(err, ErrWriteTimeout)) + t.Run("TestWriteTimeout", func(t *testing.T) { + setServerReadRate(10 << 10) // 10KB/s - // close success - err = conn.Close() - MustNil(t, err) + conn := newConn() + defer conn.Close() - err = ln.Close() - MustNil(t, err) + // write 1KB without timeout + err := mallocAndFlush(conn, 1<<10) // ~100ms + MustNil(t, err) + + // write 50ms timeout + _ = conn.SetWriteTimeout(50 * time.Millisecond) + err = mallocAndFlush(conn, 1<<20) + MustTrue(t, errors.Is(err, ErrWriteTimeout)) + }) + + t.Run("TestReadTimeout", func(t *testing.T) { + setServerWriteRate(10 << 10) // 10KB/s + + conn := newConn() + defer conn.Close() + + // read 1KB without timeout + _, err := conn.Reader().Next(1 << 10) // ~100ms + MustNil(t, err) + + // read 20KB ~ 2s, 50ms timeout + _ = conn.SetReadTimeout(50 * time.Millisecond) + _, err = conn.Reader().Next(20 << 10) + MustTrue(t, errors.Is(err, ErrReadTimeout)) + }) + + t.Run("TestWriteDeadline", func(t *testing.T) { + setServerReadRate(10 << 10) // 10KB/s + + conn := newConn() + defer conn.Close() + + // write 1KB without deadline + err := conn.SetWriteDeadline(time.Now()) + MustNil(t, err) + err = conn.SetDeadline(time.Time{}) + MustNil(t, err) + err = mallocAndFlush(conn, 1<<10) // ~100ms + MustNil(t, err) + + // write with deadline + err = conn.SetWriteDeadline(time.Now().Add(50 * time.Millisecond)) + MustNil(t, err) + t0 := time.Now() + err = mallocAndFlush(conn, 1<<20) + MustTrue(t, errors.Is(err, ErrWriteTimeout)) + MustTrue(t, time.Since(t0)-50*time.Millisecond < 20*time.Millisecond) + + // write deadline exceeded + t1 := time.Now() + err = mallocAndFlush(conn, 10<<10) + MustTrue(t, errors.Is(err, ErrWriteTimeout)) + MustTrue(t, time.Since(t1) < 20*time.Millisecond) + }) + + t.Run("TestReadDeadline", func(t *testing.T) { + setServerWriteRate(20 << 10) // 20KB/s + + conn := newConn() + defer conn.Close() + + // read 1KB without deadline + err := conn.SetReadDeadline(time.Now()) + MustNil(t, err) + err = conn.SetDeadline(time.Time{}) + MustNil(t, err) + _, err = conn.Reader().Next(1 << 10) + MustNil(t, err) + + // read 100KB with deadline + err = conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + MustNil(t, err) + t0 := time.Now() + _, err = conn.Reader().Next(100 << 10) + MustTrue(t, errors.Is(err, ErrReadTimeout)) + MustTrue(t, time.Since(t0)-50*time.Millisecond < 20*time.Millisecond) + + // read 10KB, deadline exceeded + t1 := time.Now() + _, err = conn.Reader().Next(10 << 10) + MustTrue(t, errors.Is(err, ErrReadTimeout)) + MustTrue(t, time.Since(t1) < 20*time.Millisecond) + }) } // TestConnectionLargeMemory is used to verify the memory usage in the large package scenario. @@ -531,9 +664,9 @@ func TestBookSizeLargerThanMaxSize(t *testing.T) { } func TestConnDetach(t *testing.T) { - address := getTestAddress() - ln, err := createTestListener("tcp", address) - MustNil(t, err) + ln := createTestTCPListener(t) + defer ln.Close() + address := ln.Addr().String() // accept => read => write var wg sync.WaitGroup @@ -591,10 +724,9 @@ func TestConnDetach(t *testing.T) { } func TestParallelShortConnection(t *testing.T) { - address := getTestAddress() - ln, err := createTestListener("tcp", address) - MustNil(t, err) + ln := createTestTCPListener(t) defer ln.Close() + address := ln.Addr().String() var received int64 el, err := NewEventLoop(func(ctx context.Context, connection Connection) error { @@ -610,9 +742,10 @@ func TestParallelShortConnection(t *testing.T) { go func() { el.Serve(ln) }() + defer el.Shutdown(context.Background()) conns := 100 - sizePerConn := 1024 * 100 + sizePerConn := 1024 totalSize := conns * sizePerConn var wg sync.WaitGroup for i := 0; i < conns; i++ { @@ -632,17 +765,20 @@ func TestParallelShortConnection(t *testing.T) { } wg.Wait() + t0 := time.Now() for atomic.LoadInt64(&received) < int64(totalSize) { - runtime.Gosched() + time.Sleep(time.Millisecond) + if time.Since(t0) > 100*time.Millisecond { // max wait 100ms + break + } } Equal(t, atomic.LoadInt64(&received), int64(totalSize)) } func TestConnectionServerClose(t *testing.T) { - address := getTestAddress() - ln, err := createTestListener("tcp", address) - MustNil(t, err) + ln := createTestTCPListener(t) defer ln.Close() + address := ln.Addr().String() /* Client Server @@ -656,7 +792,7 @@ func TestConnectionServerClose(t *testing.T) { var wg sync.WaitGroup el, err := NewEventLoop( func(ctx context.Context, connection Connection) error { - // t.Logf("server.OnRequest: addr=%s", connection.RemoteAddr()) + t.Logf("server.OnRequest: addr=%s", connection.RemoteAddr()) defer wg.Done() buf, err := connection.Reader().Next(len(PONG)) // pong Equal(t, string(buf), PONG) @@ -679,14 +815,14 @@ func TestConnectionServerClose(t *testing.T) { err = connection.Writer().Flush() MustNil(t, err) connection.AddCloseCallback(func(connection Connection) error { - // t.Logf("server.CloseCallback: addr=%s", connection.RemoteAddr()) + t.Logf("server.CloseCallback: addr=%s", connection.RemoteAddr()) wg.Done() return nil }) return ctx }), WithOnPrepare(func(connection Connection) context.Context { - // t.Logf("server.OnPrepare: addr=%s", connection.RemoteAddr()) + t.Logf("server.OnPrepare: addr=%s", connection.RemoteAddr()) defer wg.Done() //nolint:staticcheck // SA1029 no built-in type string as key return context.WithValue(context.Background(), "prepare", "true") @@ -719,7 +855,7 @@ func TestConnectionServerClose(t *testing.T) { return connection.Close() } - conns := 100 + conns := 10 // server: OnPrepare, OnConnect, OnRequest, CloseCallback // client: OnRequest, CloseCallback wg.Add(conns * 6) @@ -730,7 +866,7 @@ func TestConnectionServerClose(t *testing.T) { err = conn.SetOnRequest(clientOnRequest) MustNil(t, err) conn.AddCloseCallback(func(connection Connection) error { - // t.Logf("client.CloseCallback: addr=%s", connection.LocalAddr()) + t.Logf("client.CloseCallback: addr=%s", connection.LocalAddr()) defer wg.Done() return nil }) @@ -740,38 +876,31 @@ func TestConnectionServerClose(t *testing.T) { } func TestConnectionDailTimeoutAndClose(t *testing.T) { - address := getTestAddress() - ln, err := createTestListener("tcp", address) - MustNil(t, err) + ln := createTestTCPListener(t) defer ln.Close() - el, err := NewEventLoop( - func(ctx context.Context, connection Connection) error { - _, err = connection.Reader().Next(connection.Reader().Len()) - return err - }, - ) - defer el.Shutdown(context.Background()) go func() { - err := el.Serve(ln) - if err != nil { - t.Logf("service end with error: %v", err) + for { + conn, err := ln.Accept() + if err != nil { + return + } + time.Sleep(time.Millisecond) + conn.Close() } }() - loops := 100 - conns := 100 - for l := 0; l < loops; l++ { - var wg sync.WaitGroup - wg.Add(conns) - for i := 0; i < conns; i++ { - go func() { - defer wg.Done() - conn, err := DialConnection("tcp", address, time.Nanosecond) - Assert(t, err == nil || strings.Contains(err.Error(), "i/o timeout")) - _ = conn - }() - } - wg.Wait() + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func() { + defer wg.Done() + conn, err := DialConnection("tcp", ln.Addr().String(), time.Millisecond) + Assert(t, err == nil || strings.Contains(err.Error(), "i/o timeout"), err) + if err == nil { // XXX: conn is always not nil ... + conn.Close() + } + }() } + wg.Wait() } diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 1492bcae..8e3e91de 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -49,3 +49,9 @@ func init() { func UseGoRunTask() { RunTask = goRunTask } + +// SetPanicHandler sets the panic handler for the global pool. +func SetPanicHandler(f func(context.Context, interface{})) { + bgopool.SetPanicHandler(f) + cgopool.SetPanicHandler(f) +} diff --git a/mux/shard_queue_test.go b/mux/shard_queue_test.go index e5384a24..c46881b8 100644 --- a/mux/shard_queue_test.go +++ b/mux/shard_queue_test.go @@ -57,7 +57,7 @@ func TestShardQueue(t *testing.T) { // test queue := NewShardQueue(4, conn) count, pkgsize := 16, 11 - for i := 0; i < int(count); i++ { + for i := 0; i < count; i++ { var getter WriterGetter = func() (buf netpoll.Writer, isNil bool) { buf = netpoll.NewLinkBuffer(pkgsize) buf.Malloc(pkgsize) diff --git a/netpoll_unix_test.go b/netpoll_unix_test.go index c441af57..a0aabfc0 100644 --- a/netpoll_unix_test.go +++ b/netpoll_unix_test.go @@ -28,6 +28,8 @@ import ( "syscall" "testing" "time" + + "github.com/cloudwego/netpoll/internal/runner" ) func MustNil(t *testing.T, val interface{}) { @@ -209,7 +211,7 @@ func TestOnDisconnectWhenOnConnect(t *testing.T) { type ctxPrepareKey struct{} type ctxConnectKey struct{} network, address := "tcp", getTestAddress() - var conns int32 = 100 + var conns int32 = 10 var wg sync.WaitGroup wg.Add(int(conns) * 3) loop := newTestEventLoop(network, address, @@ -271,9 +273,11 @@ func TestGracefulExit(t *testing.T) { // exit with processing connections trigger := make(chan struct{}) eventLoop2 := newTestEventLoop(network, address, - func(ctx context.Context, connection Connection) error { + func(ctx context.Context, conn Connection) error { <-trigger - return nil + rd := conn.Reader() + rd.Next(rd.Len()) // avoid dead loop + return errors.New("done") }) for i := 0; i < 10; i++ { // connect success @@ -441,6 +445,16 @@ func TestServerReadAndClose(t *testing.T) { } func TestServerPanicAndClose(t *testing.T) { + // use custom RunTask to ignore panic log + runfunc := runner.RunTask + defer func() { runner.RunTask = runfunc }() + runner.RunTask = func(ctx context.Context, f func()) { + go func() { + defer func() { recover() }() + f() + }() + } + network, address := "tcp", getTestAddress() sendMsg := []byte("hello") var paniced int32 diff --git a/nocopy_linkbuffer_test.go b/nocopy_linkbuffer_test.go index 9b2e6eb0..1ba379b6 100644 --- a/nocopy_linkbuffer_test.go +++ b/nocopy_linkbuffer_test.go @@ -397,11 +397,11 @@ func TestLinkBufferResetTail(t *testing.T) { buf.WriteByte(except) buf.Flush() r1, _ := buf.Slice(1) - fmt.Printf("1: %x\n", buf.flush.buf) + t.Logf("1: %x\n", buf.flush.buf) // 2. release & reset tail buf.resetTail(LinkBufferCap) buf.WriteByte(byte(2)) - fmt.Printf("2: %x\n", buf.flush.buf) + t.Logf("2: %x\n", buf.flush.buf) // check slice reader got, _ := r1.ReadByte() diff --git a/poll_manager_test.go b/poll_manager_test.go index 9539fcbb..444d21e9 100644 --- a/poll_manager_test.go +++ b/poll_manager_test.go @@ -61,8 +61,9 @@ func TestPollManagerSetNumLoops(t *testing.T) { poll := pm.Pick() newGs := runtime.NumGoroutine() Assert(t, poll != nil) - Assert(t, newGs-startGs >= 1, newGs, startGs) t.Logf("old=%d, new=%d", startGs, newGs) + // FIXME: it's unstable due to background goroutines created by other tests + // Assert(t, newGs-startGs == 1) // change pollers oldGs := newGs @@ -70,7 +71,7 @@ func TestPollManagerSetNumLoops(t *testing.T) { MustNil(t, err) newGs = runtime.NumGoroutine() t.Logf("old=%d, new=%d", oldGs, newGs) - Assert(t, newGs == oldGs) + // Assert(t, newGs == oldGs) // trigger polls adjustment var wg sync.WaitGroup diff --git a/sys_sendmsg_linux.go b/sys_sendmsg_linux.go index 1213c772..394ee994 100644 --- a/sys_sendmsg_linux.go +++ b/sys_sendmsg_linux.go @@ -43,7 +43,7 @@ func sendmsg(fd int, bs [][]byte, ivs []syscall.Iovec, zerocopy bool) (n int, er r, _, e := syscall.RawSyscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(&msghdr)), 0) resetIovecs(bs, ivs[:iovLen]) if e != 0 { - return int(r), syscall.Errno(e) + return int(r), e } return int(r), nil }