A privacy-first, self-hosted, fully open-source personal knowledge management software, written in TypeScript and Go. (PERSONAL FORK)
1// SiYuan - Refactor your thinking
2// Copyright (c) 2020-present, b3log.org
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Affero General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8//
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU Affero General Public License for more details.
13//
14// You should have received a copy of the GNU Affero General Public License
15// along with this program. If not, see <https://www.gnu.org/licenses/>.
16
17package util
18
19import (
20 "crypto/ecdsa"
21 "crypto/elliptic"
22 "crypto/rand"
23 "crypto/x509"
24 "crypto/x509/pkix"
25 "encoding/pem"
26 "fmt"
27 "math/big"
28 "net"
29 "os"
30 "path/filepath"
31 "strings"
32 "time"
33
34 "github.com/88250/gulu"
35 "github.com/siyuan-note/logging"
36)
37
// File names of the local CA material; the functions in this file join them
// with ConfDir to build the full paths.
const (
	// TLSCACertFilename is the file name of the PEM-encoded local CA certificate.
	TLSCACertFilename = "ca.crt"
	// TLSCAKeyFilename is the file name of the PEM-encoded local CA private key.
	TLSCAKeyFilename = "ca.key"
)

// File names of the server certificate material; the functions in this file
// join them with ConfDir to build the full paths.
const (
	// TLSCertFilename is the file name of the PEM-encoded server certificate.
	TLSCertFilename = "cert.pem"
	// TLSKeyFilename is the file name of the PEM-encoded server private key.
	TLSKeyFilename = "key.pem"
)
44
45// Returns paths to existing TLS certificates or generates new ones signed by a local CA.
46// Certificates are stored in the conf directory of the workspace.
47func GetOrCreateTLSCert() (certPath, keyPath string, err error) {
48 certPath = filepath.Join(ConfDir, TLSCertFilename)
49 keyPath = filepath.Join(ConfDir, TLSKeyFilename)
50 caCertPath := filepath.Join(ConfDir, TLSCACertFilename)
51 caKeyPath := filepath.Join(ConfDir, TLSCAKeyFilename)
52
53 if !gulu.File.IsExist(caCertPath) || !gulu.File.IsExist(caKeyPath) {
54 logging.LogInfof("generating local CA for TLS...")
55 if err = generateCACert(caCertPath, caKeyPath); err != nil {
56 logging.LogErrorf("failed to generate CA certificates: %s", err)
57 return "", "", err
58 }
59 }
60
61 if gulu.File.IsExist(certPath) && gulu.File.IsExist(keyPath) {
62 if validateCert(certPath) {
63 logging.LogInfof("using existing TLS certificates from [%s]", ConfDir)
64 return certPath, keyPath, nil
65 }
66 logging.LogInfof("existing TLS certificates are invalid or expired, regenerating...")
67 }
68
69 caCert, caKey, err := loadCA(caCertPath, caKeyPath)
70 if err != nil {
71 logging.LogErrorf("failed to load CA certificates: %s", err)
72 return "", "", err
73 }
74
75 logging.LogInfof("generating TLS server certificates signed by local CA...")
76 if err = generateServerCert(certPath, keyPath, caCert, caKey); err != nil {
77 logging.LogErrorf("failed to generate TLS certificates: %s", err)
78 return "", "", err
79 }
80
81 logging.LogInfof("generated TLS certificates at [%s]", ConfDir)
82 return certPath, keyPath, nil
83}
84
85// Checks if the certificate file exists, is not expired, and contains all current IP addresses
86func validateCert(certPath string) bool {
87 certPEM, err := os.ReadFile(certPath)
88 if err != nil {
89 return false
90 }
91
92 block, _ := pem.Decode(certPEM)
93 if block == nil {
94 return false
95 }
96
97 cert, err := x509.ParseCertificate(block.Bytes)
98 if err != nil {
99 return false
100 }
101
102 // Check if certificate is still valid, with 7 day buffer
103 if !time.Now().Add(7 * 24 * time.Hour).Before(cert.NotAfter) {
104 return false
105 }
106
107 // Check if certificate contains all current IP addresses
108 currentIPs := extractIPsFromServerAddrs()
109 certIPMap := make(map[string]bool)
110 for _, ip := range cert.IPAddresses {
111 certIPMap[ip.String()] = true
112 }
113
114 for _, ipStr := range currentIPs {
115 ipStr = trimIPv6Brackets(ipStr)
116 ip := net.ParseIP(ipStr)
117 if ip == nil {
118 continue
119 }
120
121 if !certIPMap[ip.String()] {
122 logging.LogInfof("certificate missing current IP address [%s], will regenerate", ip.String())
123 return false
124 }
125 }
126
127 return true
128}
129
130// Creates a new self-signed CA certificate
131func generateCACert(certPath, keyPath string) error {
132 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
133 if err != nil {
134 return err
135 }
136
137 serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
138 if err != nil {
139 return err
140 }
141
142 notBefore := time.Now()
143 notAfter := notBefore.Add(10 * 365 * 24 * time.Hour)
144
145 template := x509.Certificate{
146 SerialNumber: serialNumber,
147 Subject: pkix.Name{
148 Organization: []string{"SiYuan"},
149 CommonName: "SiYuan Local CA",
150 },
151 NotBefore: notBefore,
152 NotAfter: notAfter,
153 KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
154 BasicConstraintsValid: true,
155 IsCA: true,
156 }
157
158 certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
159 if err != nil {
160 return err
161 }
162
163 return writeCertAndKey(certPath, keyPath, certDER, privateKey)
164}
165
166// Creates a new server certificate signed by the CA
167func generateServerCert(certPath, keyPath string, caCert *x509.Certificate, caKey any) error {
168 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
169 if err != nil {
170 return err
171 }
172
173 serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
174 if err != nil {
175 return err
176 }
177
178 notBefore := time.Now()
179 notAfter := notBefore.Add(365 * 24 * time.Hour)
180
181 ipAddresses := []net.IP{
182 net.ParseIP("127.0.0.1"),
183 net.IPv6loopback,
184 }
185
186 localIPs := extractIPsFromServerAddrs()
187 for _, ipStr := range localIPs {
188 ipStr = trimIPv6Brackets(ipStr)
189 if ip := net.ParseIP(ipStr); ip != nil {
190 ipAddresses = append(ipAddresses, ip)
191 }
192 }
193
194 template := x509.Certificate{
195 SerialNumber: serialNumber,
196 Subject: pkix.Name{
197 Organization: []string{"SiYuan"},
198 CommonName: "SiYuan Local Server",
199 },
200 NotBefore: notBefore,
201 NotAfter: notAfter,
202 KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
203 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
204 BasicConstraintsValid: true,
205 DNSNames: []string{"localhost"},
206 IPAddresses: ipAddresses,
207 }
208
209 certDER, err := x509.CreateCertificate(rand.Reader, &template, caCert, &privateKey.PublicKey, caKey)
210 if err != nil {
211 return err
212 }
213
214 return writeCertAndKey(certPath, keyPath, certDER, privateKey)
215}
216
// loadCA reads the PEM files at certPath and keyPath and returns the parsed CA
// certificate together with its private key. The key is parsed with
// x509.ParseECPrivateKey, so on success the second result holds an
// *ecdsa.PrivateKey. Read, decode and parse failures are returned as-is, with
// the certificate processed before the key.
func loadCA(certPath, keyPath string) (*x509.Certificate, any, error) {
	// firstBlock reads a file and returns its first PEM block, or a nil block
	// when the content holds no PEM data.
	firstBlock := func(path string) (*pem.Block, error) {
		data, err := os.ReadFile(path)
		if err != nil {
			return nil, err
		}
		blk, _ := pem.Decode(data)
		return blk, nil
	}

	certBlock, err := firstBlock(certPath)
	switch {
	case err != nil:
		return nil, nil, err
	case certBlock == nil:
		return nil, nil, fmt.Errorf("failed to decode CA certificate PEM")
	}

	certificate, err := x509.ParseCertificate(certBlock.Bytes)
	if err != nil {
		return nil, nil, err
	}

	keyBlock, err := firstBlock(keyPath)
	switch {
	case err != nil:
		return nil, nil, err
	case keyBlock == nil:
		return nil, nil, fmt.Errorf("failed to decode CA key PEM")
	}

	signer, err := x509.ParseECPrivateKey(keyBlock.Bytes)
	if err != nil {
		return nil, nil, err
	}

	return certificate, signer, nil
}
251
// writeCertAndKey PEM-encodes certDER and privateKey and writes them to
// certPath and keyPath, replacing any existing content.
//
// The certificate is stored as a "CERTIFICATE" block with mode 0644 and the
// key as an "EC PRIVATE KEY" block with mode 0600 so that it is not readable
// by other users. Both PEM documents are built in memory before any file is
// touched, so a key marshalling failure leaves existing files unchanged.
// Any marshalling or write error is returned unwrapped.
func writeCertAndKey(certPath, keyPath string, certDER []byte, privateKey *ecdsa.PrivateKey) error {
	keyDER, err := x509.MarshalECPrivateKey(privateKey)
	if err != nil {
		return err
	}

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})

	// os.WriteFile reports write and close errors, unlike a deferred Close.
	if err = os.WriteFile(certPath, certPEM, 0644); err != nil {
		return err
	}

	// os.WriteFile applies the mode only when it creates the file, so tighten
	// a key file left over from an earlier run before secret bytes go into it.
	// Best effort: the file may not exist yet, and some file systems do not
	// support permission bits.
	_ = os.Chmod(keyPath, 0600)

	return os.WriteFile(keyPath, keyPEM, 0600)
}
280
281// Imports a CA certificate and private key from PEM-encoded strings.
282func ImportCABundle(caCertPEM, caKeyPEM string) error {
283 certBlock, _ := pem.Decode([]byte(caCertPEM))
284 if certBlock == nil {
285 return fmt.Errorf("failed to decode CA certificate PEM")
286 }
287
288 caCert, err := x509.ParseCertificate(certBlock.Bytes)
289 if err != nil {
290 return fmt.Errorf("failed to parse CA certificate: %w", err)
291 }
292
293 if !caCert.IsCA {
294 return fmt.Errorf("the provided certificate is not a CA certificate")
295 }
296
297 keyBlock, _ := pem.Decode([]byte(caKeyPEM))
298 if keyBlock == nil {
299 return fmt.Errorf("failed to decode CA private key PEM")
300 }
301
302 _, err = x509.ParseECPrivateKey(keyBlock.Bytes)
303 if err != nil {
304 return fmt.Errorf("failed to parse CA private key: %w", err)
305 }
306
307 caCertPath := filepath.Join(ConfDir, TLSCACertFilename)
308 caKeyPath := filepath.Join(ConfDir, TLSCAKeyFilename)
309
310 if err := os.WriteFile(caCertPath, []byte(caCertPEM), 0644); err != nil {
311 return fmt.Errorf("failed to write CA certificate: %w", err)
312 }
313
314 if err := os.WriteFile(caKeyPath, []byte(caKeyPEM), 0600); err != nil {
315 return fmt.Errorf("failed to write CA private key: %w", err)
316 }
317
318 certPath := filepath.Join(ConfDir, TLSCertFilename)
319 keyPath := filepath.Join(ConfDir, TLSKeyFilename)
320
321 if gulu.File.IsExist(certPath) {
322 os.Remove(certPath)
323 }
324 if gulu.File.IsExist(keyPath) {
325 os.Remove(keyPath)
326 }
327
328 logging.LogInfof("imported CA bundle, server certificate will be regenerated on next TLS initialization")
329 return nil
330}
331
// trimIPv6Brackets strips one pair of enclosing square brackets from ip, so
// "[::1]" becomes "::1". Strings that are not bracketed, or that have nothing
// between the brackets, are returned unchanged.
func trimIPv6Brackets(ip string) string {
	last := len(ip) - 1
	if last < 2 || ip[0] != '[' || ip[last] != ']' {
		return ip
	}
	return ip[1:last]
}
339
340// extractIPsFromServerAddrs extracts IP addresses from server URLs returned by GetServerAddrs()
341// GetServerAddrs() returns URLs like "http://192.168.1.1:6806", this function extracts just the IP part
342func extractIPsFromServerAddrs() []string {
343 serverAddrs := GetServerAddrs()
344 var ips []string
345 for _, addr := range serverAddrs {
346 addr = strings.TrimPrefix(addr, "http://")
347 addr = strings.TrimPrefix(addr, "https://")
348
349 if strings.HasPrefix(addr, "[") {
350 // IPv6 address with brackets
351 if idx := strings.Index(addr, "]:"); idx != -1 {
352 addr = addr[1:idx]
353 } else if strings.HasSuffix(addr, "]") {
354 addr = addr[1 : len(addr)-1]
355 }
356 } else {
357 // IPv4 address or IPv6 without brackets
358 if idx := strings.LastIndex(addr, ":"); idx != -1 {
359 if strings.Count(addr, ":") == 1 {
360 // IPv4 with port
361 addr = addr[:idx]
362 }
363 }
364 }
365 ips = append(ips, addr)
366 }
367 return ips
368}