diff --git a/format/webrtc/adapter.go b/format/webrtc/adapter.go index 7992b3d..96e1a0a 100644 --- a/format/webrtc/adapter.go +++ b/format/webrtc/adapter.go @@ -1,6 +1,7 @@ package webrtc import ( + "bytes" "encoding/base64" "errors" "fmt" @@ -14,9 +15,20 @@ import ( "github.com/pion/webrtc/v2/pkg/media" ) +var ( + ErrorNotFound = errors.New("stream not found") + ErrorCodecNotSupported = errors.New("codec not supported") + ErrorClientOffline = errors.New("client offline") + Label = "track_" +) + type Muxer struct { - streams map[int8]*Stream - Connected bool + streams map[int8]*Stream + status webrtc.ICEConnectionState + stop bool + pc *webrtc.PeerConnection + pt *time.Timer + ps chan bool } type Stream struct { codec av.CodecData @@ -24,12 +36,14 @@ type Stream struct { } func NewMuxer() *Muxer { - return &Muxer{streams: make(map[int8]*Stream)} + tmp := Muxer{ps: make(chan bool, 100), pt: time.NewTimer(time.Second * 20), streams: make(map[int8]*Stream)} + go tmp.WaitCloser() + return &tmp } -func (self *Muxer) WriteHeader(streams []av.CodecData, sdp64 string) (string, error) { +func (element *Muxer) WriteHeader(streams []av.CodecData, sdp64 string) (string, error) { if len(streams) == 0 { - return "", errors.New("No Stream Forund") + return "", ErrorNotFound } mediaEngine := webrtc.MediaEngine{} sdpB, err := base64.StdEncoding.DecodeString(sdp64) @@ -54,21 +68,15 @@ func (self *Muxer) WriteHeader(streams []av.CodecData, sdp64 string) (string, er if err != nil { return "", err } - timer1 := time.NewTimer(time.Second * 2) - peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { - d.OnMessage(func(msg webrtc.DataChannelMessage) { - timer1.Reset(2 * time.Second) - }) - }) for i, i2 := range streams { var track *webrtc.Track if i2.Type().IsVideo() { - track, err = peerConnection.NewTrack(getPayloadType(mediaEngine, webrtc.RTPCodecTypeVideo, i2.Type().String()), rand.Uint32(), "video", "pion") + track, err = peerConnection.NewTrack(getPayloadType(mediaEngine, webrtc.RTPCodecTypeVideo, i2.Type().String()), rand.Uint32(), "video", Label) if err != nil { return "", err } } else if i2.Type().IsAudio() { - track, err = peerConnection.NewTrack(getPayloadType(mediaEngine, webrtc.RTPCodecTypeAudio, i2.Type().String()), rand.Uint32(), "audio", "pion") + track, err = peerConnection.NewTrack(getPayloadType(mediaEngine, webrtc.RTPCodecTypeAudio, i2.Type().String()), rand.Uint32(), "audio", Label) if err != nil { return "", err } @@ -85,14 +93,19 @@ func (self *Muxer) WriteHeader(streams []av.CodecData, sdp64 string) (string, er if err != nil { return "", err } - self.streams[int8(i)] = &Stream{track: track, codec: i2} + element.streams[int8(i)] = &Stream{track: track, codec: i2} } peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { - fmt.Printf("Connection State has changed %s \n", connectionState.String()) - if connectionState == webrtc.ICEConnectionStateConnected { - self.Connected = true + element.status = connectionState + if connectionState == webrtc.ICEConnectionStateDisconnected { + element.ps <- true } }) + peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { + d.OnMessage(func(msg webrtc.DataChannelMessage) { + element.pt.Reset(5 * time.Second) + }) + }) if err = peerConnection.SetRemoteDescription(offer); err != nil { return "", err } @@ -103,26 +116,52 @@ func (self *Muxer) WriteHeader(streams []av.CodecData, sdp64 string) (string, er if err = peerConnection.SetLocalDescription(answer); err != nil { return "", err } + element.pc = peerConnection return base64.StdEncoding.EncodeToString([]byte(answer.SDP)), nil } -func (self *Muxer) WritePacket(pkt av.Packet) (err error) { - if tmp, ok := self.streams[pkt.Idx]; ok { +func (element *Muxer) WritePacket(pkt av.Packet) (err error) { + if element.stop { + return ErrorClientOffline + } + if element.status != webrtc.ICEConnectionStateConnected { + return nil + } + if tmp, ok := element.streams[pkt.Idx]; ok { switch tmp.codec.Type() { case av.H264: codec := tmp.codec.(h264parser.CodecData) if pkt.IsKeyFrame { - pkt.Data = append([]byte("\000\000\001"+string(codec.SPS())+"\000\000\001"+string(codec.PPS())+"\000\000\001"), pkt.Data[4:]...) - + pkt.Data = append([]byte{0, 0, 0, 1}, bytes.Join([][]byte{codec.SPS(), codec.PPS(), pkt.Data[4:]}, []byte{0, 0, 0, 1})...) } else { pkt.Data = pkt.Data[4:] } return tmp.track.WriteSample(media.Sample{Data: pkt.Data, Samples: 90000}) default: - return errors.New("Media Track Not Found") + return ErrorCodecNotSupported } } - return errors.New("Media Track Not Found") + return ErrorNotFound + +} +func (element *Muxer) WaitCloser() { + select { + case <-element.ps: + element.stop = true + element.Close() + case <-element.pt.C: + element.stop = true + element.Close() + } +} +func (element *Muxer) Close() error { + if element.pc != nil { + err := element.pc.Close() + if err != nil { + return err + } + } + return nil } func getPayloadType(m webrtc.MediaEngine, codecType webrtc.RTPCodecType, codecName string) uint8 {