Based on https://github.com/nnevatie/capnwebcpp (main branch).
1// Package gocapnweb provides a Go implementation of the Cap'n Web RPC protocol server. 2// This library allows creating server implementations for the Cap'n Web RPC protocol 3// with support for WebSocket and HTTP batch endpoints. 4package gocapnweb 5 6import ( 7 "encoding/json" 8 "fmt" 9 "log" 10 "sync" 11) 12 13// RpcTarget defines the interface that server implementations must satisfy. 14// It provides method dispatch functionality for incoming RPC calls. 15type RpcTarget interface { 16 // Dispatch handles method calls and returns the result as JSON. 17 // It should return an error if the method is not found or execution fails. 18 Dispatch(method string, args json.RawMessage) (interface{}, error) 19} 20 21// BaseRpcTarget provides a convenient base implementation of RpcTarget 22// with method registration capabilities. 23type BaseRpcTarget struct { 24 methods map[string]func(json.RawMessage) (interface{}, error) 25 mu sync.RWMutex 26} 27 28// NewBaseRpcTarget creates a new BaseRpcTarget instance. 29func NewBaseRpcTarget() *BaseRpcTarget { 30 return &BaseRpcTarget{ 31 methods: make(map[string]func(json.RawMessage) (interface{}, error)), 32 } 33} 34 35// Method registers a method handler with the given name. 36func (t *BaseRpcTarget) Method(name string, handler func(json.RawMessage) (interface{}, error)) { 37 t.mu.Lock() 38 defer t.mu.Unlock() 39 t.methods[name] = handler 40} 41 42// Dispatch implements the RpcTarget interface. 43func (t *BaseRpcTarget) Dispatch(method string, args json.RawMessage) (interface{}, error) { 44 t.mu.RLock() 45 handler, exists := t.methods[method] 46 t.mu.RUnlock() 47 48 if !exists { 49 return nil, fmt.Errorf("method not found: %s", method) 50 } 51 52 return handler(args) 53} 54 55// SessionData holds the state for each RPC session (WebSocket connection or HTTP batch). 
56type SessionData struct { 57 PendingResults map[int]interface{} `json:"pendingResults"` 58 PendingOperations map[int]Operation `json:"pendingOperations"` 59 NextExportID int `json:"nextExportId"` 60 Target RpcTarget `json:"-"` 61 mu sync.RWMutex 62} 63 64// Operation represents a pending RPC operation. 65type Operation struct { 66 Method string `json:"method"` 67 Args json.RawMessage `json:"args"` 68} 69 70// NewSessionData creates a new SessionData instance. 71func NewSessionData(target RpcTarget) *SessionData { 72 return &SessionData{ 73 PendingResults: make(map[int]interface{}), 74 PendingOperations: make(map[int]Operation), 75 NextExportID: 1, 76 Target: target, 77 } 78} 79 80// RpcSession handles the Cap'n Web RPC protocol for connections. 81type RpcSession struct { 82 target RpcTarget 83} 84 85// NewRpcSession creates a new RpcSession with the given target. 86func NewRpcSession(target RpcTarget) *RpcSession { 87 return &RpcSession{target: target} 88} 89 90// HandleMessage processes an incoming RPC message and returns the response. 91// Returns an empty string if no response should be sent. 
92func (s *RpcSession) HandleMessage(sessionData *SessionData, message string) (string, error) { 93 var msg []interface{} 94 if err := json.Unmarshal([]byte(message), &msg); err != nil { 95 return "", fmt.Errorf("invalid message format: %w", err) 96 } 97 98 if len(msg) == 0 { 99 return "", fmt.Errorf("empty message") 100 } 101 102 messageType, ok := msg[0].(string) 103 if !ok { 104 return "", fmt.Errorf("invalid message type") 105 } 106 107 switch messageType { 108 case "push": 109 if len(msg) >= 2 { 110 s.handlePush(sessionData, msg[1]) 111 } 112 return "", nil // No response for push 113 114 case "pull": 115 if len(msg) >= 2 { 116 if exportIDFloat, ok := msg[1].(float64); ok { 117 exportID := int(exportIDFloat) 118 response, err := s.handlePull(sessionData, exportID) 119 if err != nil { 120 return "", err 121 } 122 responseBytes, err := json.Marshal(response) 123 if err != nil { 124 return "", err 125 } 126 return string(responseBytes), nil 127 } 128 } 129 130 case "release": 131 if len(msg) >= 3 { 132 if exportIDFloat, ok := msg[1].(float64); ok { 133 if refcountFloat, ok := msg[2].(float64); ok { 134 s.handleRelease(sessionData, int(exportIDFloat), int(refcountFloat)) 135 } 136 } 137 } 138 return "", nil // No response for release 139 140 case "abort": 141 if len(msg) >= 2 { 142 s.handleAbort(sessionData, msg[1]) 143 } 144 return "", nil // No response for abort 145 } 146 147 return "", nil 148} 149 150// OnOpen initializes a new session. 151func (s *RpcSession) OnOpen(sessionData *SessionData) { 152 log.Println("WebSocket connection opened") 153 sessionData.mu.Lock() 154 defer sessionData.mu.Unlock() 155 sessionData.NextExportID = 1 156 sessionData.PendingResults = make(map[int]interface{}) 157 sessionData.PendingOperations = make(map[int]Operation) 158} 159 160// OnClose cleans up a session. 
161func (s *RpcSession) OnClose(sessionData *SessionData) { 162 log.Println("WebSocket connection closed") 163} 164 165func (s *RpcSession) handlePush(sessionData *SessionData, pushData interface{}) { 166 pushArray, ok := pushData.([]interface{}) 167 if !ok || len(pushArray) == 0 { 168 return 169 } 170 171 sessionData.mu.Lock() 172 defer sessionData.mu.Unlock() 173 174 // Create a new export on the server side 175 exportID := sessionData.NextExportID 176 sessionData.NextExportID++ 177 178 if len(pushArray) >= 3 && pushArray[0] == "pipeline" { 179 if importIDFloat, ok := pushArray[1].(float64); ok { 180 _ = int(importIDFloat) // importID for future use 181 182 if methodArray, ok := pushArray[2].([]interface{}); ok && len(methodArray) > 0 { 183 if method, ok := methodArray[0].(string); ok { 184 var args json.RawMessage 185 if len(pushArray) >= 4 { 186 argsBytes, _ := json.Marshal(pushArray[3]) 187 args = argsBytes 188 } else { 189 args = json.RawMessage("[]") 190 } 191 192 // Store the operation for lazy evaluation when pulled 193 sessionData.PendingOperations[exportID] = Operation{ 194 Method: method, 195 Args: args, 196 } 197 } 198 } 199 } 200 } 201} 202 203func (s *RpcSession) resolvePipelineReferences(sessionData *SessionData, value interface{}) (interface{}, error) { 204 switch v := value.(type) { 205 case []interface{}: 206 // Check if this is a pipeline reference: ["pipeline", exportId, ["path", ...]] 207 if len(v) >= 2 { 208 if pipelineStr, ok := v[0].(string); ok && pipelineStr == "pipeline" { 209 if refExportIDFloat, ok := v[1].(float64); ok { 210 refExportID := int(refExportIDFloat) 211 212 sessionData.mu.RLock() 213 // Check if result is already computed 214 if result, exists := sessionData.PendingResults[refExportID]; exists { 215 sessionData.mu.RUnlock() 216 217 // If there's a path, traverse it 218 if len(v) >= 3 { 219 if pathArray, ok := v[2].([]interface{}); ok { 220 return s.traversePath(result, pathArray) 221 } 222 } 223 return result, nil 224 } 
225 226 // Check if we need to execute a pending operation 227 if operation, exists := sessionData.PendingOperations[refExportID]; exists { 228 sessionData.mu.RUnlock() 229 230 // Recursively resolve arguments 231 var args interface{} 232 if err := json.Unmarshal(operation.Args, &args); err != nil { 233 return nil, err 234 } 235 resolvedArgs, err := s.resolvePipelineReferences(sessionData, args) 236 if err != nil { 237 return nil, err 238 } 239 240 resolvedArgsBytes, err := json.Marshal(resolvedArgs) 241 if err != nil { 242 return nil, err 243 } 244 245 // Execute the operation 246 result, err := sessionData.Target.Dispatch(operation.Method, resolvedArgsBytes) 247 if err != nil { 248 return nil, err 249 } 250 251 // Normalize the result for pipeline traversal 252 normalizedResult, err := s.normalizeResult(result) 253 if err != nil { 254 return nil, err 255 } 256 257 // Cache the normalized result 258 sessionData.mu.Lock() 259 sessionData.PendingResults[refExportID] = normalizedResult 260 delete(sessionData.PendingOperations, refExportID) 261 sessionData.mu.Unlock() 262 263 // If there's a path, traverse it 264 if len(v) >= 3 { 265 if pathArray, ok := v[2].([]interface{}); ok { 266 return s.traversePath(normalizedResult, pathArray) 267 } 268 } 269 return normalizedResult, nil 270 } 271 sessionData.mu.RUnlock() 272 273 return nil, fmt.Errorf("pipeline reference to non-existent export: %d", refExportID) 274 } 275 } 276 } 277 278 // Not a pipeline reference, recursively resolve elements 279 resolved := make([]interface{}, len(v)) 280 for i, elem := range v { 281 resolvedElem, err := s.resolvePipelineReferences(sessionData, elem) 282 if err != nil { 283 return nil, err 284 } 285 resolved[i] = resolvedElem 286 } 287 return resolved, nil 288 289 case map[string]interface{}: 290 // Recursively resolve object values 291 resolved := make(map[string]interface{}) 292 for key, val := range v { 293 resolvedVal, err := s.resolvePipelineReferences(sessionData, val) 294 if err != 
nil { 295 return nil, err 296 } 297 resolved[key] = resolvedVal 298 } 299 return resolved, nil 300 301 default: 302 // Primitive value, return as-is 303 return value, nil 304 } 305} 306 307func (s *RpcSession) traversePath(result interface{}, path []interface{}) (interface{}, error) { 308 current := result 309 for _, key := range path { 310 switch k := key.(type) { 311 case string: 312 if obj, ok := current.(map[string]interface{}); ok { 313 current = obj[k] 314 } else { 315 return nil, fmt.Errorf("cannot traverse string key on non-object") 316 } 317 case float64: 318 if arr, ok := current.([]interface{}); ok { 319 idx := int(k) 320 if idx < 0 || idx >= len(arr) { 321 return nil, fmt.Errorf("array index out of bounds") 322 } 323 current = arr[idx] 324 } else { 325 return nil, fmt.Errorf("cannot traverse numeric key on non-array") 326 } 327 default: 328 return nil, fmt.Errorf("invalid path key type") 329 } 330 } 331 return current, nil 332} 333 334func (s *RpcSession) handlePull(sessionData *SessionData, exportID int) ([]interface{}, error) { 335 sessionData.mu.RLock() 336 // Check if we already have a cached result 337 if result, exists := sessionData.PendingResults[exportID]; exists { 338 sessionData.mu.RUnlock() 339 340 // Clean up 341 sessionData.mu.Lock() 342 delete(sessionData.PendingResults, exportID) 343 sessionData.mu.Unlock() 344 345 // Check if the stored result is an error 346 if errArray, ok := result.([]interface{}); ok && len(errArray) >= 2 { 347 if errType, ok := errArray[0].(string); ok && errType == "error" { 348 // Send as reject 349 return []interface{}{"reject", exportID, result}, nil 350 } 351 } 352 353 // Send as resolve 354 // Arrays need to be wrapped in another array to escape them per Cap'n Web protocol 355 if _, ok := result.([]interface{}); ok { 356 return []interface{}{"resolve", exportID, []interface{}{result}}, nil 357 } 358 return []interface{}{"resolve", exportID, result}, nil 359 } 360 361 // Check if we have a pending operation to 
execute 362 if operation, exists := sessionData.PendingOperations[exportID]; exists { 363 sessionData.mu.RUnlock() 364 365 // Resolve any pipeline references in the arguments 366 var args interface{} 367 if err := json.Unmarshal(operation.Args, &args); err != nil { 368 return s.createErrorResponse(exportID, "ArgumentError", err.Error()), nil 369 } 370 371 resolvedArgs, err := s.resolvePipelineReferences(sessionData, args) 372 if err != nil { 373 return s.createErrorResponse(exportID, "PipelineError", err.Error()), nil 374 } 375 376 resolvedArgsBytes, err := json.Marshal(resolvedArgs) 377 if err != nil { 378 return s.createErrorResponse(exportID, "SerializationError", err.Error()), nil 379 } 380 381 // Dispatch the method call to the target 382 result, err := sessionData.Target.Dispatch(operation.Method, resolvedArgsBytes) 383 384 // Clean up the operation 385 sessionData.mu.Lock() 386 delete(sessionData.PendingOperations, exportID) 387 sessionData.mu.Unlock() 388 389 if err != nil { 390 return s.createErrorResponse(exportID, "MethodError", err.Error()), nil 391 } 392 393 // Normalize the result to ensure it's JSON-compatible for pipeline traversal 394 normalizedResult, err := s.normalizeResult(result) 395 if err != nil { 396 return s.createErrorResponse(exportID, "SerializationError", err.Error()), nil 397 } 398 399 // Store the normalized result for future reference 400 sessionData.mu.Lock() 401 sessionData.PendingResults[exportID] = normalizedResult 402 sessionData.mu.Unlock() 403 404 // Send as resolve 405 if _, ok := normalizedResult.([]interface{}); ok { 406 return []interface{}{"resolve", exportID, []interface{}{normalizedResult}}, nil 407 } 408 return []interface{}{"resolve", exportID, normalizedResult}, nil 409 } 410 sessionData.mu.RUnlock() 411 412 // Export ID not found - send an error 413 return []interface{}{"reject", exportID, []interface{}{ 414 "error", "ExportNotFound", "Export ID not found", 415 }}, nil 416} 417 418func (s *RpcSession) 
createErrorResponse(exportID int, errorType, message string) []interface{} { 419 return []interface{}{"reject", exportID, []interface{}{ 420 "error", errorType, message, 421 }} 422} 423 424func (s *RpcSession) handleRelease(sessionData *SessionData, exportID, refcount int) { 425 log.Printf("Released export %d with refcount %d", exportID, refcount) 426} 427 428func (s *RpcSession) handleAbort(sessionData *SessionData, errorData interface{}) { 429 errorBytes, _ := json.Marshal(errorData) 430 log.Printf("Abort received: %s", string(errorBytes)) 431} 432 433// normalizeResult ensures that Go structs are converted to map[string]interface{} 434// for proper pipeline traversal. This is necessary because Go structs need to be 435// JSON-marshaled and then unmarshaled to become navigable objects. 436func (s *RpcSession) normalizeResult(result interface{}) (interface{}, error) { 437 // If it's already a map[string]interface{} or basic type, return as-is 438 switch result.(type) { 439 case map[string]interface{}, []interface{}, string, float64, bool, nil: 440 return result, nil 441 } 442 443 // For other types (like structs), marshal to JSON and unmarshal to interface{} 444 // This converts structs to map[string]interface{} which can be traversed 445 jsonBytes, err := json.Marshal(result) 446 if err != nil { 447 return nil, fmt.Errorf("failed to marshal result: %w", err) 448 } 449 450 var normalized interface{} 451 if err := json.Unmarshal(jsonBytes, &normalized); err != nil { 452 return nil, fmt.Errorf("failed to unmarshal result: %w", err) 453 } 454 455 return normalized, nil 456}