fix: race condition in handshake

pull/1/head
Benyamin Azarkhazin 2023-08-04 13:16:27 +03:30
parent 8d2d26c14c
commit 1ff98068cd
Signed by: benyamin
GPG Key ID: 3AE44F5623C70269
2 changed files with 86 additions and 59 deletions

View File

@ -70,6 +70,7 @@ func (c *RoomController) Offer(ctx *gin.Context) {
c.helper.ResponseUnprocessableEntity(ctx) c.helper.ResponseUnprocessableEntity(ctx)
return return
} }
println("offer from", reqModel.ID)
answer, err := c.repo.SetPeerOffer(reqModel.RoomId, reqModel.ID, reqModel.SDP) answer, err := c.repo.SetPeerOffer(reqModel.RoomId, reqModel.ID, reqModel.SDP)
if c.helper.HandleIfErr(ctx, err, nil) { if c.helper.HandleIfErr(ctx, err, nil) {
println(err.Error()) println(err.Error())
@ -110,6 +111,7 @@ func (c *RoomController) Answer(ctx *gin.Context) {
c.helper.ResponseUnprocessableEntity(ctx) c.helper.ResponseUnprocessableEntity(ctx)
return return
} }
println("answer from", reqModel.ID)
err := c.repo.SetPeerAnswer(reqModel.RoomId, reqModel.ID, reqModel.SDP) err := c.repo.SetPeerAnswer(reqModel.RoomId, reqModel.ID, reqModel.SDP)
if c.helper.HandleIfErr(ctx, err, nil) { if c.helper.HandleIfErr(ctx, err, nil) {
println(err.Error()) println(err.Error())

View File

@ -3,6 +3,7 @@ package repositories
import ( import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"github.com/pion/rtcp" "github.com/pion/rtcp"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
@ -23,6 +24,7 @@ type Peer struct {
ID uint64 ID uint64
Conn *webrtc.PeerConnection Conn *webrtc.PeerConnection
CanPublish bool CanPublish bool
HandshakeLock *sync.Mutex
} }
type Room struct { type Room struct {
@ -86,11 +88,16 @@ func (r *RoomRepository) CreatePeer(roomId string, id uint64, canPublish, isCall
continue continue
} }
go peer.Conn.WriteRTCP([]rtcp.Packet{ go func(recv *webrtc.RTPReceiver) {
err := peer.Conn.WriteRTCP([]rtcp.Packet{
&rtcp.PictureLossIndication{ &rtcp.PictureLossIndication{
MediaSSRC: uint32(receiver.Track().SSRC()), MediaSSRC: uint32(recv.Track().SSRC()),
}, },
}) })
if err != nil {
println(`[E] [rtcp] `, err.Error())
}
}(receiver)
} }
} }
@ -99,8 +106,8 @@ func (r *RoomRepository) CreatePeer(roomId string, id uint64, canPublish, isCall
}() }()
} }
r.Unlock()
room := r.Rooms[roomId] room := r.Rooms[roomId]
r.Unlock()
room.Lock() room.Lock()
defer room.Unlock() defer room.Unlock()
@ -118,10 +125,15 @@ func (r *RoomRepository) CreatePeer(roomId string, id uint64, canPublish, isCall
peerConn.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { peerConn.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
r.onPeerTrack(roomId, id, remote, receiver) r.onPeerTrack(roomId, id, remote, receiver)
}) })
/*peerConn.OnNegotiationNeeded(func() {
println("[PC] negotiating with peer", id)
r.offerPeer(peerConn,roomId,id)
})*/
room.Peers[id] = &Peer{ room.Peers[id] = &Peer{
ID: id, ID: id,
Conn: peerConn, Conn: peerConn,
HandshakeLock: &sync.Mutex{},
CanPublish: canPublish,
} }
go r.updatePCTracks(roomId) go r.updatePCTracks(roomId)
return nil return nil
@ -197,12 +209,12 @@ func (r *RoomRepository) onPeerTrack(roomId string, id uint64, remote *webrtc.Tr
} }
room.trackLock.Unlock() room.trackLock.Unlock()
defer func() { defer func(trackId string) {
room.trackLock.Lock() room.trackLock.Lock()
delete(room.Tracks, remote.ID()) delete(room.Tracks, trackId)
room.trackLock.Unlock() room.trackLock.Unlock()
r.updatePCTracks(roomId) r.updatePCTracks(roomId)
}() }(remote.ID())
go r.updatePCTracks(roomId) go r.updatePCTracks(roomId)
buffer := make([]byte, 1500) buffer := make([]byte, 1500)
for { for {
@ -252,10 +264,10 @@ func (r *RoomRepository) updatePCTracks(roomId string) {
renegotiate := false renegotiate := false
for id, track := range room.Tracks { for id, track := range room.Tracks {
_, alreadySend := alreadySentTracks[id] _, alreadySend := alreadySentTracks[id]
_, alreadyReceiver := receivingPeerTracks[id] _, alreadyReceived := receivingPeerTracks[id]
if track.OwnerId != peer.ID && (!alreadySend && !alreadyReceiver) { if track.OwnerId != peer.ID && (!alreadySend && !alreadyReceived) {
renegotiate = true renegotiate = true
println("add track") println("[PC] add track", track.TrackLocal.ID(), "to", peer.ID)
_, err := peer.Conn.AddTrack(track.TrackLocal) _, err := peer.Conn.AddTrack(track.TrackLocal)
if err != nil { if err != nil {
println(err.Error()) println(err.Error())
@ -263,38 +275,20 @@ func (r *RoomRepository) updatePCTracks(roomId string) {
} }
} }
} }
for trackId, rtpSender := range alreadySentTracks {
if _, exists := room.Tracks[trackId]; !exists {
println("[PC] remove track", trackId, "from", peer.ID)
_ = rtpSender.Stop()
_ = peer.Conn.RemoveTrack(rtpSender)
}
}
room.trackLock.Unlock() room.trackLock.Unlock()
if renegotiate { if renegotiate {
offer, err := peer.Conn.CreateOffer(nil) err := r.offerPeer(peer, roomId)
if err != nil { if err != nil {
println(err.Error()) println(`[E]`, err.Error())
return return
} }
err = peer.Conn.SetLocalDescription(offer)
if err != nil {
println(err.Error())
return
}
reqModel := dto.SetSDPReqModel{
PeerDTO: dto.PeerDTO{
RoomId: roomId,
ID: peer.ID,
},
SDP: offer,
}
bodyJson, err := json.Marshal(reqModel)
if err != nil {
println(err.Error())
return
}
res, err := http.Post(r.conf.LogjamBaseUrl+"/offer", "application/json", bytes.NewReader(bodyJson))
if err != nil {
println(err.Error())
return
}
if res.StatusCode > 204 {
println("/offer ", res.Status)
}
} }
} }
} }
@ -340,6 +334,7 @@ func (r *RoomRepository) SetPeerAnswer(roomId string, id uint64, answer webrtc.S
if err != nil { if err != nil {
return models.NewError(err.Error(), 500, models.MessageResponse{Message: err.Error()}) return models.NewError(err.Error(), 500, models.MessageResponse{Message: err.Error()})
} }
room.Peers[id].HandshakeLock.Unlock()
return nil return nil
} }
func (r *RoomRepository) SetPeerOffer(roomId string, id uint64, offer webrtc.SessionDescription) (sdpAnswer *webrtc.SessionDescription, err error) { func (r *RoomRepository) SetPeerOffer(roomId string, id uint64, offer webrtc.SessionDescription) (sdpAnswer *webrtc.SessionDescription, err error) {
@ -348,24 +343,28 @@ func (r *RoomRepository) SetPeerOffer(roomId string, id uint64, offer webrtc.Ses
r.Unlock() r.Unlock()
return nil, models.NewError("room doesn't exists", 403, map[string]any{"roomId": roomId}) return nil, models.NewError("room doesn't exists", 403, map[string]any{"roomId": roomId})
} }
r.Unlock()
room := r.Rooms[roomId] room := r.Rooms[roomId]
r.Unlock()
room.Lock() room.Lock()
defer room.Unlock()
if !r.doesPeerExists(roomId, id) { if !r.doesPeerExists(roomId, id) {
room.Unlock()
return nil, models.NewError("no such a peer with this id in this room", 403, map[string]any{"roomId": roomId, "peerId": id}) return nil, models.NewError("no such a peer with this id in this room", 403, map[string]any{"roomId": roomId, "peerId": id})
} }
peer := room.Peers[id]
room.Unlock()
err = room.Peers[id].Conn.SetRemoteDescription(offer) peer.HandshakeLock.Lock()
defer peer.HandshakeLock.Unlock()
err = peer.Conn.SetRemoteDescription(offer)
if err != nil { if err != nil {
return nil, models.NewError(err.Error(), 500, models.MessageResponse{Message: err.Error()}) return nil, models.NewError(err.Error(), 500, models.MessageResponse{Message: err.Error()})
} }
answer, err := room.Peers[id].Conn.CreateAnswer(nil) answer, err := peer.Conn.CreateAnswer(nil)
if err != nil { if err != nil {
return nil, models.NewError(err.Error(), 500, models.MessageResponse{Message: err.Error()}) return nil, models.NewError(err.Error(), 500, models.MessageResponse{Message: err.Error()})
} }
err = room.Peers[id].Conn.SetLocalDescription(answer) err = peer.Conn.SetLocalDescription(answer)
if err != nil { if err != nil {
return nil, models.NewError(err.Error(), 500, models.MessageResponse{Message: err.Error()}) return nil, models.NewError(err.Error(), 500, models.MessageResponse{Message: err.Error()})
} }
@ -411,23 +410,49 @@ func (r *RoomRepository) ClosePeer(roomId string, id uint64) error {
func (r *RoomRepository) ResetRoom(roomId string) error { func (r *RoomRepository) ResetRoom(roomId string) error {
r.Lock() r.Lock()
defer r.Unlock()
if !r.doesRoomExists(roomId) { if !r.doesRoomExists(roomId) {
r.Unlock()
return nil return nil
} }
room := r.Rooms[roomId] room := r.Rooms[roomId]
r.Unlock()
room.Lock() room.Lock()
defer room.Unlock()
room.timer.Stop() room.timer.Stop()
for _, peer := range room.Peers { for _, peer := range room.Peers {
peer.Conn.Close() _ = peer.Conn.Close()
} }
room.Unlock()
r.Lock()
delete(r.Rooms, roomId) delete(r.Rooms, roomId)
r.Unlock() return nil
}
func (r *RoomRepository) offerPeer(peer *Peer, roomId string) error {
peer.HandshakeLock.Lock()
println("[PC] negotiating with peer", peer.ID)
offer, err := peer.Conn.CreateOffer(nil)
if err != nil {
return err
}
err = peer.Conn.SetLocalDescription(offer)
if err != nil {
return err
}
reqModel := dto.SetSDPReqModel{
PeerDTO: dto.PeerDTO{
RoomId: roomId,
ID: peer.ID,
},
SDP: offer,
}
bodyJson, err := json.Marshal(reqModel)
if err != nil {
return err
}
res, err := http.Post(r.conf.LogjamBaseUrl+"/offer", "application/json", bytes.NewReader(bodyJson))
if err != nil {
return err
}
if res.StatusCode > 204 {
return errors.New("POST {logjambaseurl}/offer : " + res.Status)
}
return nil return nil
} }