A privacy-first, self-hosted, fully open-source personal knowledge management application, written in TypeScript and Go. (PERSONAL FORK)
at lambda-fork/main 368 lines 10 kB view raw
// SiYuan - Refactor your thinking
// Copyright (c) 2020-present, b3log.org
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package util

import (
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"math/big"
	"net"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/88250/gulu"
	"github.com/siyuan-note/logging"
)

// File names, relative to ConfDir, of the PEM files that make up the local TLS setup.
const (
	// TLSCACertFilename holds the local CA certificate that signs the server certificate.
	TLSCACertFilename = "ca.crt"
	// TLSCAKeyFilename holds the private key of the local CA.
	TLSCAKeyFilename = "ca.key"
	// TLSCertFilename holds the TLS server certificate.
	TLSCertFilename = "cert.pem"
	// TLSKeyFilename holds the private key of the TLS server certificate.
	TLSKeyFilename = "key.pem"
)

// GetOrCreateTLSCert returns the paths of the TLS server certificate and private key
// stored in the workspace conf directory (ConfDir). It generates a local CA when ca.crt
// or ca.key is missing, and issues a new server certificate signed by that CA when the
// existing one is missing or fails validation.
47func GetOrCreateTLSCert() (certPath, keyPath string, err error) { 48 certPath = filepath.Join(ConfDir, TLSCertFilename) 49 keyPath = filepath.Join(ConfDir, TLSKeyFilename) 50 caCertPath := filepath.Join(ConfDir, TLSCACertFilename) 51 caKeyPath := filepath.Join(ConfDir, TLSCAKeyFilename) 52 53 if !gulu.File.IsExist(caCertPath) || !gulu.File.IsExist(caKeyPath) { 54 logging.LogInfof("generating local CA for TLS...") 55 if err = generateCACert(caCertPath, caKeyPath); err != nil { 56 logging.LogErrorf("failed to generate CA certificates: %s", err) 57 return "", "", err 58 } 59 } 60 61 if gulu.File.IsExist(certPath) && gulu.File.IsExist(keyPath) { 62 if validateCert(certPath) { 63 logging.LogInfof("using existing TLS certificates from [%s]", ConfDir) 64 return certPath, keyPath, nil 65 } 66 logging.LogInfof("existing TLS certificates are invalid or expired, regenerating...") 67 } 68 69 caCert, caKey, err := loadCA(caCertPath, caKeyPath) 70 if err != nil { 71 logging.LogErrorf("failed to load CA certificates: %s", err) 72 return "", "", err 73 } 74 75 logging.LogInfof("generating TLS server certificates signed by local CA...") 76 if err = generateServerCert(certPath, keyPath, caCert, caKey); err != nil { 77 logging.LogErrorf("failed to generate TLS certificates: %s", err) 78 return "", "", err 79 } 80 81 logging.LogInfof("generated TLS certificates at [%s]", ConfDir) 82 return certPath, keyPath, nil 83} 84 85// Checks if the certificate file exists, is not expired, and contains all current IP addresses 86func validateCert(certPath string) bool { 87 certPEM, err := os.ReadFile(certPath) 88 if err != nil { 89 return false 90 } 91 92 block, _ := pem.Decode(certPEM) 93 if block == nil { 94 return false 95 } 96 97 cert, err := x509.ParseCertificate(block.Bytes) 98 if err != nil { 99 return false 100 } 101 102 // Check if certificate is still valid, with 7 day buffer 103 if !time.Now().Add(7 * 24 * time.Hour).Before(cert.NotAfter) { 104 return false 105 } 106 107 // Check if 
certificate contains all current IP addresses 108 currentIPs := extractIPsFromServerAddrs() 109 certIPMap := make(map[string]bool) 110 for _, ip := range cert.IPAddresses { 111 certIPMap[ip.String()] = true 112 } 113 114 for _, ipStr := range currentIPs { 115 ipStr = trimIPv6Brackets(ipStr) 116 ip := net.ParseIP(ipStr) 117 if ip == nil { 118 continue 119 } 120 121 if !certIPMap[ip.String()] { 122 logging.LogInfof("certificate missing current IP address [%s], will regenerate", ip.String()) 123 return false 124 } 125 } 126 127 return true 128} 129 130// Creates a new self-signed CA certificate 131func generateCACert(certPath, keyPath string) error { 132 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 133 if err != nil { 134 return err 135 } 136 137 serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) 138 if err != nil { 139 return err 140 } 141 142 notBefore := time.Now() 143 notAfter := notBefore.Add(10 * 365 * 24 * time.Hour) 144 145 template := x509.Certificate{ 146 SerialNumber: serialNumber, 147 Subject: pkix.Name{ 148 Organization: []string{"SiYuan"}, 149 CommonName: "SiYuan Local CA", 150 }, 151 NotBefore: notBefore, 152 NotAfter: notAfter, 153 KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign, 154 BasicConstraintsValid: true, 155 IsCA: true, 156 } 157 158 certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) 159 if err != nil { 160 return err 161 } 162 163 return writeCertAndKey(certPath, keyPath, certDER, privateKey) 164} 165 166// Creates a new server certificate signed by the CA 167func generateServerCert(certPath, keyPath string, caCert *x509.Certificate, caKey any) error { 168 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 169 if err != nil { 170 return err 171 } 172 173 serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) 174 if err != nil { 175 return err 176 } 177 178 notBefore := time.Now() 179 
notAfter := notBefore.Add(365 * 24 * time.Hour) 180 181 ipAddresses := []net.IP{ 182 net.ParseIP("127.0.0.1"), 183 net.IPv6loopback, 184 } 185 186 localIPs := extractIPsFromServerAddrs() 187 for _, ipStr := range localIPs { 188 ipStr = trimIPv6Brackets(ipStr) 189 if ip := net.ParseIP(ipStr); ip != nil { 190 ipAddresses = append(ipAddresses, ip) 191 } 192 } 193 194 template := x509.Certificate{ 195 SerialNumber: serialNumber, 196 Subject: pkix.Name{ 197 Organization: []string{"SiYuan"}, 198 CommonName: "SiYuan Local Server", 199 }, 200 NotBefore: notBefore, 201 NotAfter: notAfter, 202 KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, 203 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 204 BasicConstraintsValid: true, 205 DNSNames: []string{"localhost"}, 206 IPAddresses: ipAddresses, 207 } 208 209 certDER, err := x509.CreateCertificate(rand.Reader, &template, caCert, &privateKey.PublicKey, caKey) 210 if err != nil { 211 return err 212 } 213 214 return writeCertAndKey(certPath, keyPath, certDER, privateKey) 215} 216 217// Loads the CA certificate and private key from files 218func loadCA(certPath, keyPath string) (*x509.Certificate, any, error) { 219 certPEM, err := os.ReadFile(certPath) 220 if err != nil { 221 return nil, nil, err 222 } 223 224 block, _ := pem.Decode(certPEM) 225 if block == nil { 226 return nil, nil, fmt.Errorf("failed to decode CA certificate PEM") 227 } 228 229 caCert, err := x509.ParseCertificate(block.Bytes) 230 if err != nil { 231 return nil, nil, err 232 } 233 234 keyPEM, err := os.ReadFile(keyPath) 235 if err != nil { 236 return nil, nil, err 237 } 238 239 block, _ = pem.Decode(keyPEM) 240 if block == nil { 241 return nil, nil, fmt.Errorf("failed to decode CA key PEM") 242 } 243 244 caKey, err := x509.ParseECPrivateKey(block.Bytes) 245 if err != nil { 246 return nil, nil, err 247 } 248 249 return caCert, caKey, nil 250} 251 252func writeCertAndKey(certPath, keyPath string, certDER []byte, privateKey 
*ecdsa.PrivateKey) error { 253 certFile, err := os.Create(certPath) 254 if err != nil { 255 return err 256 } 257 defer certFile.Close() 258 259 if err = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}); err != nil { 260 return err 261 } 262 263 keyFile, err := os.Create(keyPath) 264 if err != nil { 265 return err 266 } 267 defer keyFile.Close() 268 269 keyDER, err := x509.MarshalECPrivateKey(privateKey) 270 if err != nil { 271 return err 272 } 273 274 if err = pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}); err != nil { 275 return err 276 } 277 278 return nil 279} 280 281// Imports a CA certificate and private key from PEM-encoded strings. 282func ImportCABundle(caCertPEM, caKeyPEM string) error { 283 certBlock, _ := pem.Decode([]byte(caCertPEM)) 284 if certBlock == nil { 285 return fmt.Errorf("failed to decode CA certificate PEM") 286 } 287 288 caCert, err := x509.ParseCertificate(certBlock.Bytes) 289 if err != nil { 290 return fmt.Errorf("failed to parse CA certificate: %w", err) 291 } 292 293 if !caCert.IsCA { 294 return fmt.Errorf("the provided certificate is not a CA certificate") 295 } 296 297 keyBlock, _ := pem.Decode([]byte(caKeyPEM)) 298 if keyBlock == nil { 299 return fmt.Errorf("failed to decode CA private key PEM") 300 } 301 302 _, err = x509.ParseECPrivateKey(keyBlock.Bytes) 303 if err != nil { 304 return fmt.Errorf("failed to parse CA private key: %w", err) 305 } 306 307 caCertPath := filepath.Join(ConfDir, TLSCACertFilename) 308 caKeyPath := filepath.Join(ConfDir, TLSCAKeyFilename) 309 310 if err := os.WriteFile(caCertPath, []byte(caCertPEM), 0644); err != nil { 311 return fmt.Errorf("failed to write CA certificate: %w", err) 312 } 313 314 if err := os.WriteFile(caKeyPath, []byte(caKeyPEM), 0600); err != nil { 315 return fmt.Errorf("failed to write CA private key: %w", err) 316 } 317 318 certPath := filepath.Join(ConfDir, TLSCertFilename) 319 keyPath := filepath.Join(ConfDir, TLSKeyFilename) 320 321 if 
gulu.File.IsExist(certPath) { 322 os.Remove(certPath) 323 } 324 if gulu.File.IsExist(keyPath) { 325 os.Remove(keyPath) 326 } 327 328 logging.LogInfof("imported CA bundle, server certificate will be regenerated on next TLS initialization") 329 return nil 330} 331 332// trimIPv6Brackets removes brackets from IPv6 address strings like "[::1]" 333func trimIPv6Brackets(ip string) string { 334 if len(ip) > 2 && ip[0] == '[' && ip[len(ip)-1] == ']' { 335 return ip[1 : len(ip)-1] 336 } 337 return ip 338} 339 340// extractIPsFromServerAddrs extracts IP addresses from server URLs returned by GetServerAddrs() 341// GetServerAddrs() returns URLs like "http://192.168.1.1:6806", this function extracts just the IP part 342func extractIPsFromServerAddrs() []string { 343 serverAddrs := GetServerAddrs() 344 var ips []string 345 for _, addr := range serverAddrs { 346 addr = strings.TrimPrefix(addr, "http://") 347 addr = strings.TrimPrefix(addr, "https://") 348 349 if strings.HasPrefix(addr, "[") { 350 // IPv6 address with brackets 351 if idx := strings.Index(addr, "]:"); idx != -1 { 352 addr = addr[1:idx] 353 } else if strings.HasSuffix(addr, "]") { 354 addr = addr[1 : len(addr)-1] 355 } 356 } else { 357 // IPv4 address or IPv6 without brackets 358 if idx := strings.LastIndex(addr, ":"); idx != -1 { 359 if strings.Count(addr, ":") == 1 { 360 // IPv4 with port 361 addr = addr[:idx] 362 } 363 } 364 } 365 ips = append(ips, addr) 366 } 367 return ips 368}