服务端 cmd/server/main.go
package main
import (
"demo/pkg/utils"
"log"
"net"
"sync"
)
var (
	// clients is the set of currently registered connections. Only the keys
	// are meaningful; every stored value is true.
	clients = make(map[net.Conn]bool)
	// lock guards clients: registration and removal take the write lock,
	// broadcasting takes the read lock.
	lock sync.RWMutex
)
// main listens on TCP port 9000 and hands every accepted connection to its own
// handleConnection goroutine. A listen failure is fatal; an Accept failure is
// logged and the server keeps accepting.
func main() {
	listener, err := net.Listen("tcp", ":9000")
	if err != nil {
		log.Fatal(err)
	}
	for {
		if conn, acceptErr := listener.Accept(); acceptErr == nil {
			go handleConnection(conn)
		} else {
			log.Println("连接错误", acceptErr)
		}
	}
}
// handleConnection serves one client connection until reading from it fails.
//
// On entry the connection is added to the broadcast set. Each framed message
// read from it is logged and re-sent to every registered connection (this one
// included). When utils.ReadMessage returns an error (e.g. EOF when the peer
// disconnects, or utils.MessageTooLong for an oversized frame), the error is
// logged, the connection is removed from the set, and the connection is closed.
func handleConnection(conn net.Conn) {
	registry(conn)
	// Deferred calls run last-in-first-out: unregister first so new
	// broadcasts stop targeting this conn, then close it. Without the Close
	// the socket (and its file descriptor) leaks for every disconnected client.
	// The Close error is ignored: the connection is being discarded either way.
	defer conn.Close()
	defer unRegistry(conn)

	for {
		message, err := utils.ReadMessage(conn)
		if err != nil {
			log.Println("收信错误", err)
			return
		}
		log.Println("收到消息", conn.RemoteAddr(), string(message))
		// NOTE(review): each broadcast runs in its own goroutine, so the
		// delivery order of consecutive messages from one client is not
		// guaranteed.
		go broadcast(message)
	}
}
// registry adds conn to the shared clients set while holding the write lock,
// and logs the connection's remote address.
func registry(conn net.Conn) {
	remote := conn.RemoteAddr()

	lock.Lock()
	defer lock.Unlock()

	clients[conn] = true
	log.Println("注册连接", remote)
}
// unRegistry removes conn from the shared clients set while holding the write
// lock, and logs the connection's remote address. Removing a connection that
// is not registered changes nothing except emitting the log line. The
// connection itself is not closed here.
func unRegistry(conn net.Conn) {
	remote := conn.RemoteAddr()

	lock.Lock()
	defer lock.Unlock()

	delete(clients, conn)
	log.Println("注销连接", remote)
}
// broadcast sends message to every registered connection.
//
// The connection set is copied while holding the read lock, and the lock is
// released before any network write. A slow or stuck client therefore cannot
// keep the lock held. With the lock held, such a client would block
// registry/unRegistry and, once a writer is queued, every later RLock as well.
// A connection whose write fails is logged and removed from the set; it is not
// closed here.
func broadcast(message []byte) {
	log.Println("广播消息", string(message))

	lock.RLock()
	targets := make([]net.Conn, 0, len(clients))
	for conn := range clients {
		targets = append(targets, conn)
	}
	lock.RUnlock()

	for _, conn := range targets {
		if err := utils.WriteMessage(conn, message); err != nil {
			log.Println("广播错误", err)
			// Safe to call synchronously because no lock is held here. A conn
			// that was removed after the snapshot may fail again. Deleting it a
			// second time is harmless; it only repeats the log line.
			unRegistry(conn)
		}
	}
}
客户端 cmd/client/main.go
package main
import (
"bufio"
"demo/pkg/utils"
"fmt"
"net"
"os"
)
// prompt is printed at start-up and again after every received message to
// invite the next line of input. It never changes, so it is a constant.
const prompt = "输入消息:"
// main connects to the chat server on port 9000 of the local host. It prints
// every message it receives, and sends each line read from stdin as one
// framed message.
//
// The process exits with status 1 when any of these fail: dialing, reading
// from or writing to the server, or reading stdin. An over-long input line is
// reported and skipped rather than ending the session. main returns normally
// once stdin reaches EOF.
func main() {
	conn, err := net.Dial("tcp", ":9000")
	if err != nil {
		fmt.Printf("err:%v", err)
		os.Exit(1)
	}
	defer conn.Close()

	// Print, not Printf: prompt is plain text, not a format string.
	fmt.Print(prompt)

	// Receiver: prints each incoming message and re-draws the prompt.
	go func() {
		for {
			message, err := utils.ReadMessage(conn)
			if err != nil {
				fmt.Printf("\r收信错误:%v\n", err)
				os.Exit(1)
			}
			fmt.Printf("\r收到消息:%s\n%s", string(message), prompt)
		}
	}()

	scanner := bufio.NewScanner(os.Stdin)
	// Stop as soon as Scan reports false (EOF or a read error). Ignoring its
	// result would send empty messages in a tight loop after stdin closes.
	for scanner.Scan() {
		err := utils.WriteMessage(conn, scanner.Bytes())
		if err == nil {
			continue
		}
		// An over-long line is an input problem, not a broken connection:
		// report it and keep the session alive. WriteMessage returns this
		// sentinel unwrapped, so a direct comparison is sufficient.
		if err == utils.MessageTooLong {
			fmt.Printf("\r发信错误:%v\n%s", err, prompt)
			continue
		}
		fmt.Printf("\r发信错误:%v", err)
		os.Exit(1)
	}
	if err := scanner.Err(); err != nil {
		fmt.Printf("\r读取输入错误:%v\n", err)
		os.Exit(1)
	}
}
流处理 pkg/utils/net.go
package utils
import (
"encoding/binary"
"errors"
"io"
"net"
)
const (
	// headerSize is the length in bytes of the big-endian uint32 length
	// prefix that precedes every message body.
	headerSize = 4
	// messageMaxLen is the largest body, in bytes, that ReadMessage accepts
	// and WriteMessage sends.
	messageMaxLen = 32 * 1024
)

// ErrMessageTooLong is the error ReadMessage and WriteMessage report when a
// message body would exceed messageMaxLen bytes.
var ErrMessageTooLong = errors.New("message too long")

// MessageTooLong is the original name of ErrMessageTooLong. Both names refer
// to the same error value, so comparing against either one matches.
//
// Deprecated: use ErrMessageTooLong, which follows the Go ErrXxx naming
// convention for sentinel errors.
var MessageTooLong = ErrMessageTooLong
// ReadMessage reads one length-prefixed frame from conn and returns its body.
//
// A frame is a headerSize-byte big-endian uint32 length followed by that many
// body bytes. The call blocks until the whole frame has arrived.
//
// Errors:
//   - io.EOF if the stream ends before any header byte is read (clean close).
//   - io.ErrUnexpectedEOF if the stream ends part-way through a frame,
//     including right after a complete header.
//   - MessageTooLong if the header announces more than messageMaxLen bytes;
//     the body is neither allocated nor read in that case.
//   - Any other error from conn.Read, unchanged.
func ReadMessage(conn net.Conn) ([]byte, error) {
	header := make([]byte, headerSize)
	if _, err := io.ReadFull(conn, header); err != nil {
		return nil, err
	}
	size := binary.BigEndian.Uint32(header)
	// Check the announced size before allocating, so a forged header cannot
	// make us allocate up to 4 GiB.
	if size > messageMaxLen {
		return nil, MessageTooLong
	}
	body := make([]byte, size)
	if _, err := io.ReadFull(conn, body); err != nil {
		// ReadFull reports plain io.EOF when zero body bytes arrived. After a
		// complete header that is still a truncated frame, not a clean close.
		if errors.Is(err, io.EOF) {
			err = io.ErrUnexpectedEOF
		}
		return nil, err
	}
	return body, nil
}
// WriteMessage sends data over conn as one length-prefixed frame: a
// headerSize-byte big-endian uint32 length followed by the payload bytes.
//
// A payload longer than messageMaxLen is refused with MessageTooLong before
// anything is written. Otherwise the error from conn.Write is returned.
func WriteMessage(conn net.Conn, data []byte) error {
	if len(data) > messageMaxLen {
		return MessageTooLong
	}
	// Header and payload are assembled into one buffer and sent with a single
	// Write call. NOTE(review): presumably deliberate, because the server
	// writes to the same conn from several broadcast goroutines and two
	// separate Writes could let another frame slip between header and body.
	// Confirm that the conn implementation serializes whole Write calls.
	frame := make([]byte, headerSize, headerSize+len(data))
	binary.BigEndian.PutUint32(frame, uint32(len(data)))
	frame = append(frame, data...)
	_, err := conn.Write(frame)
	return err
}