123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- package NTP
import (
	"encoding/binary"
	"errors"
	"fmt"
	"log"
	"net"
	"sync"
	"time"
)
// STANDARD_PACKET_SIZE is the size in bytes of a standard NTP packet: the
// fixed header that Serialize produces and ParseUDPPacket requires.
const STANDARD_PACKET_SIZE = 48
- type NTPServer struct {
- srvAddress string
- conn *net.UDPConn
- wait sync.WaitGroup
- ntpPack NtpPacket // NTP协议报文
- requestCount uint64 // 请求计数
- }
// NtpPacket holds the 48-byte NTP header, one struct field per wire field,
// declared in wire order (do not reorder: the layout mirrors the packet).
//
// Header packs three sub-fields, most significant bits first:
//   - LI   (2 bits): leap indicator, 0 here
//   - VN   (3 bits): protocol version, 4 here
//   - Mode (3 bits): association mode, server = 4, client = 3
type NtpPacket struct {
	// Header (LI/VN/Mode), peer clock stratum (1 = primary reference),
	// peer polling interval and peer clock precision.
	// NOTE(review): RFC 5905 defines Poll and Precision as signed log2-second
	// values; they are stored as unsigned bytes here.
	Header, Stratum, Poll, Precision uint8

	// Root delay, root dispersion and reference identifier.
	RootDelay, RootDisp, RefID uint32

	// Reference, originate, receive and transmit timestamps, each in the
	// 64-bit NTP format produced by ToNTPTime.
	RefTS, OrigTS, RecvTS, TransTS uint64
}
- func (srv *NTPServer) NewNtpPacket() *NtpPacket {
- // 初始化Header字段
- header := uint8(0)
- header |= (0 << 6) // LI: 2bit 00
- header |= (4 << 3) // VN: 3bit 100
- header |= (4 << 0) // Mode: 3bit 100
- // 创建新的NtpPacket实例
- packet := &NtpPacket{
- Header: header,
- Stratum: 0x01,
- Poll: 0x00,
- Precision: 0x00,
- RootDelay: 0,
- RootDisp: 0,
- RefID: 0,
- RefTS: 0,
- OrigTS: 0,
- RecvTS: 0,
- TransTS: 0,
- }
- return packet
- }
- func (pack *NtpPacket) SetTimestamp(timestamp time.Time, field string) {
- ntpTime := ToNTPTime(timestamp)
- switch field {
- case "RefTS":
- pack.RefTS = ntpTime
- case "OrigTS":
- pack.OrigTS = ntpTime
- case "RecvTS":
- pack.RecvTS = ntpTime
- case "TransTS":
- pack.TransTS = ntpTime
- }
- }
// ToNTPTime converts t to a 64-bit NTP timestamp. The upper 32 bits hold
// whole seconds since 1900-01-01 UTC, wrapping modulo 2^32 (the NTP era
// rollover). The lower 32 bits hold the fraction of a second in units of
// 2^-32 s, rounded toward zero.
func ToNTPTime(t time.Time) uint64 {
	// Seconds from the NTP epoch (1900) to the Unix epoch (1970).
	const ntpEpochOffset = 2208988800

	secs := uint32(t.Unix() + ntpEpochOffset)
	// Integer scaling of nanoseconds to 2^-32 s units; Nanosecond() < 2^30,
	// so shifting left by 32 cannot overflow 64 bits.
	frac := uint32((uint64(t.Nanosecond()) << 32) / 1e9)
	return uint64(secs)<<32 | uint64(frac)
}
- func NewNTPServer(srvAddr string) *NTPServer {
- return &NTPServer{srvAddress: srvAddr}
- }
- // 启动NTP服务器
- func (srv *NTPServer) Start() error {
- addr, err := net.ResolveUDPAddr("udp", srv.srvAddress)
- if err != nil {
- return err
- }
- log.Println(fmt.Sprintf("<%s:%d>", addr.IP.String(), addr.Port))
- conn, err := net.ListenUDP("udp", addr)
- if err != nil {
- return err
- }
- srv.wait.Add(1)
- srv.conn = conn
- go RecvMsg(srv)
- return nil
- }
- // 关闭NTP服务器
- func (srv *NTPServer) Stop() {
- srv.conn.Close()
- srv.wait.Wait()
- }
- // 接收数据
- func RecvMsg(srv *NTPServer) {
- defer srv.wait.Done()
- buffer := make([]byte, 2*1024)
- for {
- n, remoteAddr, err := srv.conn.ReadFromUDP(buffer[0:])
- if err != nil {
- fmt.Println("ReadFromUDP error:", err)
- return
- }
- log.Println(fmt.Sprintf("[Recv] %d bytes from <%s>", n, remoteAddr.String()))
- if n != STANDARD_PACKET_SIZE {
- continue
- }
- // 接收到NTP客户端消息的时间
- recvMsgTime := time.Now().UTC()
- recvHexString := BytesToHex(buffer[:n])
- log.Println(fmt.Sprintf("[Recv] %s", recvHexString))
- udpPacket, err := ParseUDPPacket(buffer[:n])
- if err != nil {
- log.Printf("Error parsing UDP packet: %v", err)
- continue
- }
- ntpPack := srv.NewNtpPacket()
- ntpPack.SetTimestamp(time.Now().UTC(), "RefTS")
- ntpPack.OrigTS = udpPacket.TransTS
- ntpPack.SetTimestamp(recvMsgTime, "RecvTS")
- ntpPack.SetTimestamp(time.Now().UTC(), "TransTS")
- sendPacket := ntpPack.Serialize()
- sendLen, err := srv.conn.WriteToUDP(sendPacket, remoteAddr)
- if err != nil {
- log.Println(err.Error())
- continue
- }
- if sendLen > 0 {
- log.Println(fmt.Sprintf("[Send] %s", BytesToHex(sendPacket)))
- }
- srv.requestCount++
- }
- }
- func (pack *NtpPacket) Serialize() []byte {
- packet := make([]byte, 48)
- // binary.BigEndian.PutUint32(packet[0:4], pack.Header)
- packet[0] = pack.Header
- packet[1] = pack.Stratum
- packet[2] = pack.Poll
- packet[3] = pack.Precision
- binary.BigEndian.PutUint32(packet[4:8], pack.RootDelay)
- binary.BigEndian.PutUint32(packet[8:12], pack.RootDisp)
- binary.BigEndian.PutUint32(packet[12:16], pack.RefID)
- binary.BigEndian.PutUint64(packet[16:24], pack.RefTS)
- binary.BigEndian.PutUint64(packet[24:32], pack.OrigTS)
- binary.BigEndian.PutUint64(packet[32:40], pack.RecvTS)
- binary.BigEndian.PutUint64(packet[40:48], pack.TransTS)
- return packet
- }
// BytesToHex renders data as upper-case hexadecimal byte pairs separated by
// single spaces, e.g. []byte{0x1b, 0x00} becomes "1B 00". An empty or nil
// slice yields "".
func BytesToHex(data []byte) string {
	// Guard needed: the capacity below (3*len-1) would be negative for
	// empty input, and make panics on a negative size.
	if len(data) == 0 {
		return ""
	}
	const digits = "0123456789ABCDEF"
	out := make([]byte, 0, 3*len(data)-1)
	for i, b := range data {
		if i > 0 {
			out = append(out, ' ')
		}
		out = append(out, digits[b>>4], digits[b&0x0F])
	}
	return string(out)
}
- func ParseUDPPacket(buf []byte) (*NtpPacket, error) {
- if len(buf) < STANDARD_PACKET_SIZE { // 最小有效长度为48字节
- return nil, fmt.Errorf("Invalid UDP packet length: %d", len(buf))
- }
- packet := &NtpPacket{
- // Header: binary.BigEndian.Uint32(buf[0:4]),
- Header: buf[0],
- Stratum: buf[1],
- Poll: buf[2],
- Precision: buf[3],
- RootDelay: binary.BigEndian.Uint32(buf[4:8]),
- RootDisp: binary.BigEndian.Uint32(buf[8:12]),
- RefID: binary.BigEndian.Uint32(buf[12:16]),
- RefTS: binary.BigEndian.Uint64(buf[16:24]),
- OrigTS: binary.BigEndian.Uint64(buf[24:32]),
- RecvTS: binary.BigEndian.Uint64(buf[32:40]),
- TransTS: binary.BigEndian.Uint64(buf[40:48]),
- }
- return packet, nil
- }
|