diff --git a/node/app/node.go b/node/app/node.go
index 6254cc8..e0ed4ae 100644
--- a/node/app/node.go
+++ b/node/app/node.go
@@ -4,6 +4,7 @@ import (
     "encoding/binary"
     "errors"
     "fmt"
+    "sync"
 
     "go.uber.org/zap"
     "golang.org/x/crypto/sha3"
@@ -162,9 +163,17 @@ func (n *Node) Start() {
     }
 
     // TODO: add config mapping to engine name/frame registration
+    wg := sync.WaitGroup{}
     for _, e := range n.execEngines {
-        n.engine.RegisterExecutor(e, 0)
+        wg.Add(1)
+        go func(e execution.ExecutionEngine) {
+            defer wg.Done()
+            if err := <-n.engine.RegisterExecutor(e, 0); err != nil {
+                panic(err)
+            }
+        }(e)
     }
+    wg.Wait()
 }
 
 func (n *Node) Stop() {
diff --git a/node/consensus/data/data_clock_consensus_engine.go b/node/consensus/data/data_clock_consensus_engine.go
index 66a8804..801714d 100644
--- a/node/consensus/data/data_clock_consensus_engine.go
+++ b/node/consensus/data/data_clock_consensus_engine.go
@@ -72,6 +72,8 @@ type DataClockConsensusEngine struct {
     cancel context.CancelFunc
     wg     sync.WaitGroup
 
+    grpcServers []*grpc.Server
+
     lastProven uint64
     difficulty uint32
     config     *config.Config
@@ -349,38 +351,40 @@ func (e *DataClockConsensusEngine) Start() <-chan error {
     e.pubSub.Subscribe(e.frameFragmentFilter, e.handleFrameFragmentMessage)
     e.pubSub.Subscribe(e.txFilter, e.handleTxMessage)
     e.pubSub.Subscribe(e.infoFilter, e.handleInfoMessage)
+
+    syncServer := qgrpc.NewServer(
+        grpc.MaxSendMsgSize(40*1024*1024),
+        grpc.MaxRecvMsgSize(40*1024*1024),
+    )
+    e.grpcServers = append(e.grpcServers[:0:0], syncServer)
+    protobufs.RegisterDataServiceServer(syncServer, e)
     go func() {
-        server := qgrpc.NewServer(
-            grpc.MaxSendMsgSize(40*1024*1024),
-            grpc.MaxRecvMsgSize(40*1024*1024),
-        )
-        protobufs.RegisterDataServiceServer(server, e)
         if err := e.pubSub.StartDirectChannelListener(
             e.pubSub.GetPeerID(),
             "sync",
-            server,
+            syncServer,
         ); err != nil {
-            panic(err)
+            e.logger.Error("error starting sync server", zap.Error(err))
         }
     }()
 
-    go func() {
-        if e.dataTimeReel.GetFrameProverTries()[0].Contains(e.provingKeyAddress) {
-            server := qgrpc.NewServer(
-                grpc.MaxSendMsgSize(1*1024*1024),
-                grpc.MaxRecvMsgSize(1*1024*1024),
-            )
-            protobufs.RegisterDataServiceServer(server, e)
-
+    if e.FrameProverTrieContains(0, e.provingKeyAddress) {
+        workerServer := qgrpc.NewServer(
+            grpc.MaxSendMsgSize(1*1024*1024),
+            grpc.MaxRecvMsgSize(1*1024*1024),
+        )
+        e.grpcServers = append(e.grpcServers, workerServer)
+        protobufs.RegisterDataServiceServer(workerServer, e)
+        go func() {
             if err := e.pubSub.StartDirectChannelListener(
                 e.pubSub.GetPeerID(),
                 "worker",
-                server,
+                workerServer,
             ); err != nil {
-                panic(err)
+                e.logger.Error("error starting worker server", zap.Error(err))
             }
-        }
-    }()
+        }()
+    }
 
     e.stateMx.Lock()
     e.state = consensus.EngineStateCollecting
@@ -661,6 +665,16 @@ func (e *DataClockConsensusEngine) PerformTimeProof(
 }
 
 func (e *DataClockConsensusEngine) Stop(force bool) <-chan error {
+    wg := sync.WaitGroup{}
+    wg.Add(len(e.grpcServers))
+    for _, server := range e.grpcServers {
+        go func(server *grpc.Server) {
+            defer wg.Done()
+            server.GracefulStop()
+        }(server)
+    }
+    wg.Wait()
+
     e.logger.Info("stopping ceremony consensus engine")
     e.cancel()
     e.wg.Wait()
@@ -684,7 +698,6 @@ func (e *DataClockConsensusEngine) Stop(force bool) <-chan error {
         e.logger.Warn("error publishing prover pause", zap.Error(err))
     }
 
-    wg := sync.WaitGroup{}
     wg.Add(len(e.executionEngines))
     executionErrors := make(chan error, len(e.executionEngines))
     for name := range e.executionEngines {
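Note on the node.go hunk above: Start() previously fired RegisterExecutor calls and moved on; it now fans each call out to its own goroutine, drains the returned error channel, and blocks on a WaitGroup until every execution engine has registered (panicking on the first failure). A minimal, self-contained sketch of that fan-out/fan-in pattern follows; the Engine type and executor names are illustrative stand-ins, not the actual node interfaces.

```go
package main

import (
    "fmt"
    "sync"
)

// Engine stands in for the node's consensus engine; RegisterExecutor reports
// completion asynchronously on the returned channel, as in the diff above.
type Engine struct{}

func (e *Engine) RegisterExecutor(name string) <-chan error {
    ch := make(chan error, 1)
    go func() { ch <- nil }() // pretend the registration succeeded
    return ch
}

func main() {
    engine := &Engine{}
    executors := []string{"token", "compute"} // hypothetical engine names

    var wg sync.WaitGroup
    for _, name := range executors {
        wg.Add(1)
        go func(name string) {
            defer wg.Done()
            // Block until this executor's registration resolves.
            if err := <-engine.RegisterExecutor(name); err != nil {
                panic(err) // as in the diff, a failed registration is fatal
            }
        }(name)
    }
    wg.Wait() // Start() only returns once every executor is registered
    fmt.Println("all executors registered")
}
```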
diff --git a/node/main.go b/node/main.go
index 35154ce..32ce019 100644
--- a/node/main.go
+++ b/node/main.go
@@ -467,6 +467,7 @@ func main() {
 
     if !*integrityCheck {
         go spawnDataWorkers(nodeConfig)
+        defer stopDataWorkers()
     }
 
     kzg.Init()
@@ -510,6 +511,9 @@ func main() {
 
     // runtime.GOMAXPROCS(1)
 
+    node.Start()
+    defer node.Stop()
+
     if nodeConfig.ListenGRPCMultiaddr != "" {
         srv, err := rpc.NewRPCServer(
             nodeConfig.ListenGRPCMultiaddr,
@@ -526,20 +530,13 @@ func main() {
         if err != nil {
             panic(err)
         }
-
-        go func() {
-            err := srv.Start()
-            if err != nil {
-                panic(err)
-            }
-        }()
+        if err := srv.Start(); err != nil {
+            panic(err)
+        }
+        defer srv.Stop()
     }
 
-    node.Start()
-
     <-done
-    stopDataWorkers()
-    node.Stop()
 }
 
 var dataWorkers []*exec.Cmd
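The main.go changes replace the explicit teardown sequence after <-done with defers registered at startup. Because deferred calls run last-in-first-out, shutdown now happens in the reverse order of startup: the RPC server stops first, then the node, then the data workers. A tiny sketch of that ordering, using print statements instead of the real components:

```go
package main

import "fmt"

func main() {
    fmt.Println("start data workers")
    defer fmt.Println("stop data workers") // registered first, runs last

    fmt.Println("start node")
    defer fmt.Println("stop node")

    fmt.Println("start rpc server")
    defer fmt.Println("stop rpc server") // registered last, runs first

    // main would block on <-done here; once it returns, the defers unwind
    // in reverse: rpc server, then node, then data workers.
}
```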
diff --git a/node/rpc/node_rpc_server.go b/node/rpc/node_rpc_server.go
index 568b836..be1c5ba 100644
--- a/node/rpc/node_rpc_server.go
+++ b/node/rpc/node_rpc_server.go
@@ -6,6 +6,7 @@ import (
     "math/big"
     "net/http"
     "strings"
+    "sync"
     "time"
 
     "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
@@ -43,6 +44,8 @@ type RPCServer struct {
     pubSub           p2p.PubSub
     masterClock      *master.MasterClockConsensusEngine
     executionEngines []execution.ExecutionEngine
+    grpcServer       *grpc.Server
+    httpServer       *http.Server
 }
 
 // GetFrameInfo implements protobufs.NodeServiceServer.
@@ -384,7 +387,33 @@ func NewRPCServer(
     masterClock *master.MasterClockConsensusEngine,
     executionEngines []execution.ExecutionEngine,
 ) (*RPCServer, error) {
-    return &RPCServer{
+    mg, err := multiaddr.NewMultiaddr(listenAddrGRPC)
+    if err != nil {
+        return nil, errors.Wrap(err, "new rpc server")
+    }
+    mga, err := mn.ToNetAddr(mg)
+    if err != nil {
+        return nil, errors.Wrap(err, "new rpc server")
+    }
+
+    mux := runtime.NewServeMux()
+    opts := qgrpc.ClientOptions(
+        grpc.WithTransportCredentials(insecure.NewCredentials()),
+        grpc.WithDefaultCallOptions(
+            grpc.MaxCallRecvMsgSize(600*1024*1024),
+            grpc.MaxCallSendMsgSize(600*1024*1024),
+        ),
+    )
+    if err := protobufs.RegisterNodeServiceHandlerFromEndpoint(
+        context.Background(),
+        mux,
+        mga.String(),
+        opts,
+    ); err != nil {
+        return nil, err
+    }
+
+    rpcServer := &RPCServer{
         listenAddrGRPC:   listenAddrGRPC,
         listenAddrHTTP:   listenAddrHTTP,
         logger:           logger,
@@ -395,17 +424,22 @@ func NewRPCServer(
         pubSub:           pubSub,
         masterClock:      masterClock,
         executionEngines: executionEngines,
-    }, nil
+        grpcServer: qgrpc.NewServer(
+            grpc.MaxRecvMsgSize(600*1024*1024),
+            grpc.MaxSendMsgSize(600*1024*1024),
+        ),
+        httpServer: &http.Server{
+            Handler: mux,
+        },
+    }
+
+    protobufs.RegisterNodeServiceServer(rpcServer.grpcServer, rpcServer)
+    reflection.Register(rpcServer.grpcServer)
+
+    return rpcServer, nil
 }
 
 func (r *RPCServer) Start() error {
-    s := qgrpc.NewServer(
-        grpc.MaxRecvMsgSize(600*1024*1024),
-        grpc.MaxSendMsgSize(600*1024*1024),
-    )
-    protobufs.RegisterNodeServiceServer(s, r)
-    reflection.Register(s)
-
     mg, err := multiaddr.NewMultiaddr(r.listenAddrGRPC)
     if err != nil {
         return errors.Wrap(err, "start")
@@ -417,51 +451,42 @@ func (r *RPCServer) Start() error {
     }
 
     go func() {
-        if err := s.Serve(mn.NetListener(lis)); err != nil {
-            panic(err)
+        if err := r.grpcServer.Serve(mn.NetListener(lis)); err != nil {
+            r.logger.Error("serve error", zap.Error(err))
         }
     }()
 
     if r.listenAddrHTTP != "" {
-        m, err := multiaddr.NewMultiaddr(r.listenAddrHTTP)
+        mh, err := multiaddr.NewMultiaddr(r.listenAddrHTTP)
         if err != nil {
             return errors.Wrap(err, "start")
         }
 
-        ma, err := mn.ToNetAddr(m)
-        if err != nil {
-            return errors.Wrap(err, "start")
-        }
-
-        mga, err := mn.ToNetAddr(mg)
+        lis, err := mn.Listen(mh)
         if err != nil {
             return errors.Wrap(err, "start")
         }
 
         go func() {
-            mux := runtime.NewServeMux()
-            opts := qgrpc.ClientOptions(
-                grpc.WithTransportCredentials(insecure.NewCredentials()),
-                grpc.WithDefaultCallOptions(
-                    grpc.MaxCallRecvMsgSize(600*1024*1024),
-                    grpc.MaxCallSendMsgSize(600*1024*1024),
-                ),
-            )
-
-            if err := protobufs.RegisterNodeServiceHandlerFromEndpoint(
-                context.Background(),
-                mux,
-                mga.String(),
-                opts,
-            ); err != nil {
-                panic(err)
-            }
-
-            if err := http.ListenAndServe(ma.String(), mux); err != nil {
-                panic(err)
+            if err := r.httpServer.Serve(mn.NetListener(lis)); err != nil && !errors.Is(err, http.ErrServerClosed) {
+                r.logger.Error("serve error", zap.Error(err))
             }
         }()
     }
 
     return nil
 }
+
+func (r *RPCServer) Stop() {
+    var wg sync.WaitGroup
+    wg.Add(2)
+    go func() {
+        defer wg.Done()
+        r.grpcServer.GracefulStop()
+    }()
+    go func() {
+        defer wg.Done()
+        r.httpServer.Shutdown(context.Background())
+    }()
+    wg.Wait()
+}
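With the listener and gateway setup moved into NewRPCServer, the new Stop() method only has to wind the two servers down, and it does so concurrently: GracefulStop drains in-flight gRPC calls while Shutdown stops the HTTP gateway, with a WaitGroup joining both. A rough, self-contained sketch of that shutdown pattern under assumed throwaway addresses and an empty handler (not the node's actual wiring):

```go
package main

import (
    "context"
    "net"
    "net/http"
    "sync"

    "google.golang.org/grpc"
)

func main() {
    grpcServer := grpc.NewServer()
    httpServer := &http.Server{Handler: http.NewServeMux()}

    // Assumed throwaway listeners, purely for the sketch.
    grpcLis, _ := net.Listen("tcp", "127.0.0.1:0")
    httpLis, _ := net.Listen("tcp", "127.0.0.1:0")
    go grpcServer.Serve(grpcLis)
    go httpServer.Serve(httpLis)

    // Concurrent graceful teardown, mirroring RPCServer.Stop above.
    var wg sync.WaitGroup
    wg.Add(2)
    go func() {
        defer wg.Done()
        grpcServer.GracefulStop() // waits for pending RPCs to finish
    }()
    go func() {
        defer wg.Done()
        _ = httpServer.Shutdown(context.Background()) // stop accepting, drain
    }()
    wg.Wait()
}
```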