1package server
2
3import (
4 "context"
5 "crypto/ecdsa"
6 "crypto/elliptic"
7 "crypto/rand"
8 "crypto/tls"
9 "crypto/x509"
10 "crypto/x509/pkix"
11 "fmt"
12 "log/slog"
13 "math/big"
14 "net"
15 "net/http"
16 "os"
17 "os/signal"
18 "syscall"
19 "time"
20
21 "github.com/gocql/gocql"
22 vyletdatabase "github.com/vylet-app/go/database/proto"
23 "google.golang.org/grpc"
24 "google.golang.org/grpc/credentials"
25 "google.golang.org/grpc/reflection"
26)
27
// grpcTimeout is the connection-establishment timeout handed to
// grpc.ConnectionTimeout when the gRPC server is built.
const grpcTimeout time.Duration = 10 * time.Minute
31
// Server serves the vylet database gRPC services (profile, post, like, blob
// ref and follow; see registerServices) over TLS, backed by a Cassandra
// session. Construct it with New and start it with Run.
type Server struct {
	// Embedded generated base types, one per registered service.
	// NOTE(review): presumably these supply the default "unimplemented"
	// handlers for any RPC not defined elsewhere in this package — confirm
	// against the generated vyletdatabase code.
	vyletdatabase.UnimplementedProfileServiceServer
	vyletdatabase.UnimplementedPostServiceServer
	vyletdatabase.UnimplementedLikeServiceServer
	vyletdatabase.UnimplementedBlobRefServiceServer
	vyletdatabase.UnimplementedFollowServiceServer

	// logger is the base logger; Run derives a child logger from it.
	logger *slog.Logger

	// listenerAddr is the TCP address Run listens on.
	listenerAddr string
	// grpcServer is the TLS-configured gRPC server the services are registered on.
	grpcServer *grpc.Server

	// cqlSession is the Cassandra session; Run closes it during shutdown.
	cqlSession *gocql.Session

	// cassandraAddrs and cassandraKeyspace record the configuration the
	// session was created with (copied from Args in New).
	cassandraAddrs    []string
	cassandraKeyspace string
}
49
// Args holds the configuration passed to New.
type Args struct {
	// Logger receives the server's log output. When nil, New falls back to
	// slog.Default().
	Logger *slog.Logger

	// ListenAddr is the TCP address Run listens on for gRPC traffic.
	ListenAddr string

	// CassandraAddrs are the Cassandra hosts handed to gocql.NewCluster.
	CassandraAddrs []string
	// CassandraKeyspace is the keyspace the Cassandra session is opened with.
	CassandraKeyspace string
}
58
59func New(args *Args) (*Server, error) {
60 if args.Logger == nil {
61 args.Logger = slog.Default()
62 }
63
64 logger := args.Logger
65
66 certificate, err := GenerateTLSCertificate("localhost")
67 if err != nil {
68 return nil, fmt.Errorf("failed to generate TLS certificate: %w", err)
69 }
70
71 tlsConfig := &tls.Config{
72 Certificates: []tls.Certificate{*certificate},
73 MinVersion: tls.VersionTLS13,
74 }
75 creds := credentials.NewTLS(tlsConfig)
76
77 grpcServer := grpc.NewServer(
78 grpc.Creds(creds),
79 grpc.MaxConcurrentStreams(100_000),
80 grpc.ConnectionTimeout(grpcTimeout),
81 )
82
83 cluster := gocql.NewCluster(args.CassandraAddrs...)
84 cluster.Keyspace = args.CassandraKeyspace
85 cluster.Consistency = gocql.Quorum
86 cluster.ProtoVersion = 4
87 cluster.ConnectTimeout = time.Second * 10
88 cluster.Timeout = time.Second * 10
89
90 session, err := cluster.CreateSession()
91 if err != nil {
92 return nil, fmt.Errorf("failed to connect to cassandra: %w", err)
93 }
94
95 server := Server{
96 logger: logger,
97
98 cassandraAddrs: args.CassandraAddrs,
99 cassandraKeyspace: args.CassandraKeyspace,
100
101 listenerAddr: args.ListenAddr,
102
103 cqlSession: session,
104
105 grpcServer: grpcServer,
106 }
107
108 server.registerServices()
109
110 return &server, nil
111}
112
113func (s *Server) Run(ctx context.Context) error {
114 logger := s.logger.With("name", "Run")
115
116 logger.Info("attempting to listen", "addr", s.listenerAddr)
117
118 listener, err := net.Listen("tcp", s.listenerAddr)
119 if err != nil {
120 return fmt.Errorf("failed to listen: %w", err)
121 }
122
123 logger.Info("running gRPC server", "addr", s.listenerAddr)
124
125 grpcServerErr := make(chan error, 1)
126 go func() {
127 logger.Info("starting gRPC server")
128
129 if err := s.grpcServer.Serve(listener); err != nil {
130 if err != http.ErrServerClosed {
131 logger.Error("gRPC server shutdown with error", "err", err)
132 grpcServerErr <- err
133 return
134 }
135 logger.Info("gRPC server shutdown")
136 grpcServerErr <- nil
137 }
138 }()
139
140 signals := make(chan os.Signal, 1)
141 signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
142
143 select {
144 case sig := <-signals:
145 logger.Info("received exit signal", "signal", sig)
146 case <-ctx.Done():
147 logger.Info("context cancelled")
148 case err := <-grpcServerErr:
149 logger.Error("received grpc server error", "err", err)
150 }
151
152 s.grpcServer.GracefulStop()
153 s.cqlSession.Close()
154
155 logger.Info("gRPC server shut down")
156
157 return nil
158}
159
160func (s *Server) registerServices() {
161 vyletdatabase.RegisterProfileServiceServer(s.grpcServer, s)
162 vyletdatabase.RegisterPostServiceServer(s.grpcServer, s)
163 vyletdatabase.RegisterLikeServiceServer(s.grpcServer, s)
164 vyletdatabase.RegisterBlobRefServiceServer(s.grpcServer, s)
165 vyletdatabase.RegisterFollowServiceServer(s.grpcServer, s)
166 reflection.Register(s.grpcServer)
167}
168
// GenerateTLSCertificate returns a freshly generated, self-signed ECDSA P-256
// certificate. Its Subject Common Name is commonName, and it is valid for ten
// years, backdated by one minute to tolerate small clock skew between peers.
//
// commonName is also recorded as a Subject Alternative Name: as an IP address
// when it parses as one, otherwise as a DNS name. This matters because Go's
// crypto/x509 no longer matches hostnames against the Common Name. The serial
// number is 128 random bits, so certificates created in the same instant do
// not share a serial.
//
// The result holds only the DER-encoded certificate and its private key. It is
// not chained to any CA, so peers must trust it explicitly or skip
// verification.
func GenerateTLSCertificate(commonName string) (*tls.Certificate, error) {
	privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, fmt.Errorf("generating private key: %w", err)
	}

	// RFC 5280 caps serial numbers at 20 octets; 128 random bits stay well
	// inside that and are unique in practice.
	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, fmt.Errorf("generating serial number: %w", err)
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			CommonName: commonName,
		},
		NotBefore: now.Add(-time.Minute),
		NotAfter:  now.Add(10 * 365 * 24 * time.Hour),
	}
	if ip := net.ParseIP(commonName); ip != nil {
		template.IPAddresses = []net.IP{ip}
	} else if commonName != "" {
		template.DNSNames = []string{commonName}
	}

	der, err := x509.CreateCertificate(rand.Reader, &template, &template, privateKey.Public(), privateKey)
	if err != nil {
		return nil, fmt.Errorf("creating certificate: %w", err)
	}

	return &tls.Certificate{
		Certificate: [][]byte{der},
		PrivateKey:  privateKey,
	}, nil
}