Go TCP socket 示例：长度前缀分帧的广播聊天（服务端 + 客户端）

服务端 cmd/server/main.go

package main

import (
    "demo/pkg/utils"
    "log"
    "net"
    "sync"
)

// Registry of live client connections. Every connection present in clients
// receives broadcasts; all reads and writes of the map go through lock.
var (
	clients = map[net.Conn]bool{}
	lock    sync.RWMutex
)

// main listens for TCP connections on port 9000 and hands each accepted
// connection to its own handleConnection goroutine. A failed Accept is logged
// and the loop continues; a failed Listen terminates the process.
func main() {
	listener, listenErr := net.Listen("tcp", ":9000")
	if listenErr != nil {
		log.Fatal(listenErr)
	}

	for {
		c, acceptErr := listener.Accept()
		if acceptErr == nil {
			go handleConnection(c)
			continue
		}
		log.Println("连接错误", acceptErr)
	}
}

// handleConnection serves one client for the lifetime of its connection.
// It registers conn and then reads length-prefixed messages from it in a loop,
// broadcasting each message to every registered client. When a read fails
// (peer closed, oversized frame, network error), conn is unregistered and
// then closed.
func handleConnection(conn net.Conn) {
	// Deferred calls run in reverse order: unRegistry first, so no broadcast
	// can still pick conn out of the registry, then Close to release the socket
	// (the original never closed it, leaking one descriptor per client).
	defer conn.Close()
	registry(conn)
	defer unRegistry(conn)

	for {
		message, err := utils.ReadMessage(conn)
		if err != nil {
			log.Println("收信错误", err)
			return
		}
		log.Println("收到消息", conn.RemoteAddr(), string(message))

		// Broadcast synchronously. A goroutine per message could reorder one
		// sender's messages, and it grew without bound when receivers were slow.
		// Blocking here applies back-pressure to the sender instead.
		broadcast(message)
	}
}

// registry adds conn to the set of clients that receive broadcasts and logs
// its remote address. The exclusive lock is held for both steps.
func registry(conn net.Conn) {
	lock.Lock()
	clients[conn] = true
	log.Println("注册连接", conn.RemoteAddr())
	lock.Unlock()
}

// unRegistry removes conn from the broadcast set (a no-op if it is absent)
// and logs its remote address. The exclusive lock is held for both steps.
func unRegistry(conn net.Conn) {
	lock.Lock()
	delete(clients, conn)
	log.Println("注销连接", conn.RemoteAddr())
	lock.Unlock()
}

// broadcast writes message as one frame to every registered client while
// holding the registry's read lock.
//
// If writing to a client fails, that client's socket is closed. Its own
// handleConnection goroutine then gets a read error and unregisters it.
// unRegistry cannot be called directly here, because it needs the write lock
// while this function holds the read lock.
func broadcast(message []byte) {
	lock.RLock()
	defer lock.RUnlock()
	log.Println("广播消息", string(message))

	for conn := range clients {
		if err := utils.WriteMessage(conn, message); err != nil {
			log.Println("广播错误", conn.RemoteAddr(), err)
			// Best effort: the connection is already broken, so a Close error
			// carries no useful information.
			_ = conn.Close()
		}
	}
}

客户端 cmd/client/main.go

package main

import (
    "bufio"
    "demo/pkg/utils"
    "fmt"
    "net"
    "os"
)

// prompt is printed whenever the client is ready for the next line of input.
// It is a constant because it is never reassigned.
const prompt = "输入消息:"

// main connects to the chat server on local port 9000. It prints every message
// the server pushes and sends each line read from stdin as one message.
// It exits with status 1 on a network or stdin read error, and returns
// normally (closing the connection) when stdin reaches EOF.
func main() {
	conn, err := net.Dial("tcp", ":9000")
	if err != nil {
		fmt.Printf("err:%v\n", err)
		os.Exit(1)
	}
	defer conn.Close()

	// Print, not Printf, so a '%' in the prompt is never treated as a verb.
	fmt.Print(prompt)

	// Receiver: runs until the connection fails, then terminates the process.
	go func() {
		for {
			message, err := utils.ReadMessage(conn)
			if err != nil {
				fmt.Printf("\r收信错误:%v\n", err)
				os.Exit(1)
			}
			// \r overwrites the pending prompt; the prompt is printed again after
			// the message.
			fmt.Printf("\r收到消息:%s\n%s", string(message), prompt)
		}
	}()

	// Sender. Scan returns false on EOF or error, which ends the loop. The
	// original ignored that result and spun forever sending empty messages.
	scanner := bufio.NewScanner(os.Stdin)
	for scanner.Scan() {
		// scanner.Bytes is only valid until the next Scan; WriteMessage has
		// finished with it (copied and written it) before that happens.
		err := utils.WriteMessage(conn, scanner.Bytes())
		if err == utils.MessageTooLong {
			// An oversized line is a user error, not a broken connection:
			// report it and keep the session alive.
			fmt.Printf("\r发信错误:%v\n%s", err, prompt)
			continue
		}
		if err != nil {
			fmt.Printf("\r发信错误:%v\n", err)
			os.Exit(1)
		}
	}
	if err := scanner.Err(); err != nil {
		fmt.Printf("\r读取输入错误:%v\n", err)
		os.Exit(1)
	}
}

消息分帧（4 字节大端长度前缀 + 消息体） pkg/utils/net.go

package utils

import (
    "encoding/binary"
    "errors"
    "io"
    "net"
)

const (
	// headerSize is the length in bytes of the big-endian uint32 length prefix
	// that precedes every message body on the wire.
	headerSize = 4

	// messageMaxLen is the largest body, in bytes, that ReadMessage accepts
	// and WriteMessage sends.
	messageMaxLen = 32 * 1024
)

// ErrMessageTooLong is returned by ReadMessage and WriteMessage when a message
// body would exceed messageMaxLen bytes.
var ErrMessageTooLong = errors.New("message too long")

// MessageTooLong is the original name of ErrMessageTooLong. It is kept so
// existing callers still compile, and both names hold the same error value.
//
// Deprecated: use ErrMessageTooLong.
var MessageTooLong = ErrMessageTooLong

func ReadMessage(conn net.Conn) ([]byte, error) {
    header := make([]byte, headerSize)
    if _, err := io.ReadFull(conn, header); err != nil {
        return nil, err
    }

    size := binary.BigEndian.Uint32(header)
    if size > messageMaxLen {
        return nil, MessageTooLong
    }

    body := make([]byte, size)
    if _, err := io.ReadFull(conn, body); err != nil {
        return nil, err
    }

    return body, nil
}

// WriteMessage sends data to conn as one frame. The frame is a headerSize-byte
// big-endian length prefix followed by data, assembled into a single buffer
// and passed to one conn.Write call.
//
// If len(data) exceeds messageMaxLen, it returns MessageTooLong without writing
// anything. Otherwise it returns the error from conn.Write.
func WriteMessage(conn net.Conn, data []byte) error {
	if len(data) > messageMaxLen {
		return MessageTooLong
	}

	frame := make([]byte, headerSize, headerSize+len(data))
	binary.BigEndian.PutUint32(frame, uint32(len(data)))
	frame = append(frame, data...)

	_, err := conn.Write(frame)
	return err
}