Based on https://github.com/nnevatie/capnwebcpp
1// Package gocapnweb provides a Go implementation of the Cap'n Web RPC protocol server.
2// This library allows creating server implementations for the Cap'n Web RPC protocol
3// with support for WebSocket and HTTP batch endpoints.
4package gocapnweb
5
6import (
7 "encoding/json"
8 "fmt"
9 "log"
10 "sync"
11)
12
// RpcTarget is the contract a server object fulfils to receive Cap'n Web calls.
//
// Dispatch is invoked with the method name taken from the client's pipeline
// expression and the JSON-encoded arguments of that call. A successful value
// is sent back to the client as the call's resolution; a non-nil error (for
// example an unknown name or a failing handler) is reported to the client as
// a rejection.
type RpcTarget interface {
	Dispatch(name string, params json.RawMessage) (interface{}, error)
}
20
// BaseRpcTarget is a ready-made RpcTarget that dispatches calls to handlers
// registered by name with Method.
//
// The zero value is ready to use; NewBaseRpcTarget remains for callers that
// prefer a constructor. The handler table is guarded by an RWMutex, and
// handlers are invoked outside the lock so they may themselves call Method.
type BaseRpcTarget struct {
	methods map[string]func(json.RawMessage) (interface{}, error)
	mu      sync.RWMutex
}

// NewBaseRpcTarget returns a BaseRpcTarget with no registered methods.
func NewBaseRpcTarget() *BaseRpcTarget {
	return &BaseRpcTarget{
		methods: make(map[string]func(json.RawMessage) (interface{}, error)),
	}
}

// Method registers handler under name, replacing any previous handler for
// that name. Passing a nil handler removes the registration, so a later
// Dispatch reports "method not found" instead of panicking on a nil call.
func (t *BaseRpcTarget) Method(name string, handler func(json.RawMessage) (interface{}, error)) {
	t.mu.Lock()
	defer t.mu.Unlock()

	if handler == nil {
		delete(t.methods, name) // no-op on a nil map
		return
	}
	if t.methods == nil {
		// Allocate lazily so a zero-value BaseRpcTarget does not panic on
		// its first registration (writing to a nil map panics).
		t.methods = make(map[string]func(json.RawMessage) (interface{}, error))
	}
	t.methods[name] = handler
}

// Dispatch implements RpcTarget: it looks up the handler registered for
// method and invokes it with args. When no handler is registered it returns
// an error of the form "method not found: <method>".
func (t *BaseRpcTarget) Dispatch(method string, args json.RawMessage) (interface{}, error) {
	t.mu.RLock()
	handler, exists := t.methods[method]
	t.mu.RUnlock()

	if !exists || handler == nil {
		return nil, fmt.Errorf("method not found: %s", method)
	}
	// Called without holding the lock so handlers may register methods.
	return handler(args)
}
54
55// SessionData holds the state for each RPC session (WebSocket connection or HTTP batch).
56type SessionData struct {
57 PendingResults map[int]interface{} `json:"pendingResults"`
58 PendingOperations map[int]Operation `json:"pendingOperations"`
59 NextExportID int `json:"nextExportId"`
60 Target RpcTarget `json:"-"`
61 mu sync.RWMutex
62}
63
64// Operation represents a pending RPC operation.
65type Operation struct {
66 Method string `json:"method"`
67 Args json.RawMessage `json:"args"`
68}
69
70// NewSessionData creates a new SessionData instance.
71func NewSessionData(target RpcTarget) *SessionData {
72 return &SessionData{
73 PendingResults: make(map[int]interface{}),
74 PendingOperations: make(map[int]Operation),
75 NextExportID: 1,
76 Target: target,
77 }
78}
79
80// RpcSession handles the Cap'n Web RPC protocol for connections.
81type RpcSession struct {
82 target RpcTarget
83}
84
85// NewRpcSession creates a new RpcSession with the given target.
86func NewRpcSession(target RpcTarget) *RpcSession {
87 return &RpcSession{target: target}
88}
89
90// HandleMessage processes an incoming RPC message and returns the response.
91// Returns an empty string if no response should be sent.
92func (s *RpcSession) HandleMessage(sessionData *SessionData, message string) (string, error) {
93 var msg []interface{}
94 if err := json.Unmarshal([]byte(message), &msg); err != nil {
95 return "", fmt.Errorf("invalid message format: %w", err)
96 }
97
98 if len(msg) == 0 {
99 return "", fmt.Errorf("empty message")
100 }
101
102 messageType, ok := msg[0].(string)
103 if !ok {
104 return "", fmt.Errorf("invalid message type")
105 }
106
107 switch messageType {
108 case "push":
109 if len(msg) >= 2 {
110 s.handlePush(sessionData, msg[1])
111 }
112 return "", nil // No response for push
113
114 case "pull":
115 if len(msg) >= 2 {
116 if exportIDFloat, ok := msg[1].(float64); ok {
117 exportID := int(exportIDFloat)
118 response, err := s.handlePull(sessionData, exportID)
119 if err != nil {
120 return "", err
121 }
122 responseBytes, err := json.Marshal(response)
123 if err != nil {
124 return "", err
125 }
126 return string(responseBytes), nil
127 }
128 }
129
130 case "release":
131 if len(msg) >= 3 {
132 if exportIDFloat, ok := msg[1].(float64); ok {
133 if refcountFloat, ok := msg[2].(float64); ok {
134 s.handleRelease(sessionData, int(exportIDFloat), int(refcountFloat))
135 }
136 }
137 }
138 return "", nil // No response for release
139
140 case "abort":
141 if len(msg) >= 2 {
142 s.handleAbort(sessionData, msg[1])
143 }
144 return "", nil // No response for abort
145 }
146
147 return "", nil
148}
149
150// OnOpen initializes a new session.
151func (s *RpcSession) OnOpen(sessionData *SessionData) {
152 log.Println("WebSocket connection opened")
153 sessionData.mu.Lock()
154 defer sessionData.mu.Unlock()
155 sessionData.NextExportID = 1
156 sessionData.PendingResults = make(map[int]interface{})
157 sessionData.PendingOperations = make(map[int]Operation)
158}
159
160// OnClose cleans up a session.
161func (s *RpcSession) OnClose(sessionData *SessionData) {
162 log.Println("WebSocket connection closed")
163}
164
165func (s *RpcSession) handlePush(sessionData *SessionData, pushData interface{}) {
166 pushArray, ok := pushData.([]interface{})
167 if !ok || len(pushArray) == 0 {
168 return
169 }
170
171 sessionData.mu.Lock()
172 defer sessionData.mu.Unlock()
173
174 // Create a new export on the server side
175 exportID := sessionData.NextExportID
176 sessionData.NextExportID++
177
178 if len(pushArray) >= 3 && pushArray[0] == "pipeline" {
179 if importIDFloat, ok := pushArray[1].(float64); ok {
180 _ = int(importIDFloat) // importID for future use
181
182 if methodArray, ok := pushArray[2].([]interface{}); ok && len(methodArray) > 0 {
183 if method, ok := methodArray[0].(string); ok {
184 var args json.RawMessage
185 if len(pushArray) >= 4 {
186 argsBytes, _ := json.Marshal(pushArray[3])
187 args = argsBytes
188 } else {
189 args = json.RawMessage("[]")
190 }
191
192 // Store the operation for lazy evaluation when pulled
193 sessionData.PendingOperations[exportID] = Operation{
194 Method: method,
195 Args: args,
196 }
197 }
198 }
199 }
200 }
201}
202
203func (s *RpcSession) resolvePipelineReferences(sessionData *SessionData, value interface{}) (interface{}, error) {
204 switch v := value.(type) {
205 case []interface{}:
206 // Check if this is a pipeline reference: ["pipeline", exportId, ["path", ...]]
207 if len(v) >= 2 {
208 if pipelineStr, ok := v[0].(string); ok && pipelineStr == "pipeline" {
209 if refExportIDFloat, ok := v[1].(float64); ok {
210 refExportID := int(refExportIDFloat)
211
212 sessionData.mu.RLock()
213 // Check if result is already computed
214 if result, exists := sessionData.PendingResults[refExportID]; exists {
215 sessionData.mu.RUnlock()
216
217 // If there's a path, traverse it
218 if len(v) >= 3 {
219 if pathArray, ok := v[2].([]interface{}); ok {
220 return s.traversePath(result, pathArray)
221 }
222 }
223 return result, nil
224 }
225
226 // Check if we need to execute a pending operation
227 if operation, exists := sessionData.PendingOperations[refExportID]; exists {
228 sessionData.mu.RUnlock()
229
230 // Recursively resolve arguments
231 var args interface{}
232 if err := json.Unmarshal(operation.Args, &args); err != nil {
233 return nil, err
234 }
235 resolvedArgs, err := s.resolvePipelineReferences(sessionData, args)
236 if err != nil {
237 return nil, err
238 }
239
240 resolvedArgsBytes, err := json.Marshal(resolvedArgs)
241 if err != nil {
242 return nil, err
243 }
244
245 // Execute the operation
246 result, err := sessionData.Target.Dispatch(operation.Method, resolvedArgsBytes)
247 if err != nil {
248 return nil, err
249 }
250
251 // Normalize the result for pipeline traversal
252 normalizedResult, err := s.normalizeResult(result)
253 if err != nil {
254 return nil, err
255 }
256
257 // Cache the normalized result
258 sessionData.mu.Lock()
259 sessionData.PendingResults[refExportID] = normalizedResult
260 delete(sessionData.PendingOperations, refExportID)
261 sessionData.mu.Unlock()
262
263 // If there's a path, traverse it
264 if len(v) >= 3 {
265 if pathArray, ok := v[2].([]interface{}); ok {
266 return s.traversePath(normalizedResult, pathArray)
267 }
268 }
269 return normalizedResult, nil
270 }
271 sessionData.mu.RUnlock()
272
273 return nil, fmt.Errorf("pipeline reference to non-existent export: %d", refExportID)
274 }
275 }
276 }
277
278 // Not a pipeline reference, recursively resolve elements
279 resolved := make([]interface{}, len(v))
280 for i, elem := range v {
281 resolvedElem, err := s.resolvePipelineReferences(sessionData, elem)
282 if err != nil {
283 return nil, err
284 }
285 resolved[i] = resolvedElem
286 }
287 return resolved, nil
288
289 case map[string]interface{}:
290 // Recursively resolve object values
291 resolved := make(map[string]interface{})
292 for key, val := range v {
293 resolvedVal, err := s.resolvePipelineReferences(sessionData, val)
294 if err != nil {
295 return nil, err
296 }
297 resolved[key] = resolvedVal
298 }
299 return resolved, nil
300
301 default:
302 // Primitive value, return as-is
303 return value, nil
304 }
305}
306
307func (s *RpcSession) traversePath(result interface{}, path []interface{}) (interface{}, error) {
308 current := result
309 for _, key := range path {
310 switch k := key.(type) {
311 case string:
312 if obj, ok := current.(map[string]interface{}); ok {
313 current = obj[k]
314 } else {
315 return nil, fmt.Errorf("cannot traverse string key on non-object")
316 }
317 case float64:
318 if arr, ok := current.([]interface{}); ok {
319 idx := int(k)
320 if idx < 0 || idx >= len(arr) {
321 return nil, fmt.Errorf("array index out of bounds")
322 }
323 current = arr[idx]
324 } else {
325 return nil, fmt.Errorf("cannot traverse numeric key on non-array")
326 }
327 default:
328 return nil, fmt.Errorf("invalid path key type")
329 }
330 }
331 return current, nil
332}
333
334func (s *RpcSession) handlePull(sessionData *SessionData, exportID int) ([]interface{}, error) {
335 sessionData.mu.RLock()
336 // Check if we already have a cached result
337 if result, exists := sessionData.PendingResults[exportID]; exists {
338 sessionData.mu.RUnlock()
339
340 // Clean up
341 sessionData.mu.Lock()
342 delete(sessionData.PendingResults, exportID)
343 sessionData.mu.Unlock()
344
345 // Check if the stored result is an error
346 if errArray, ok := result.([]interface{}); ok && len(errArray) >= 2 {
347 if errType, ok := errArray[0].(string); ok && errType == "error" {
348 // Send as reject
349 return []interface{}{"reject", exportID, result}, nil
350 }
351 }
352
353 // Send as resolve
354 // Arrays need to be wrapped in another array to escape them per Cap'n Web protocol
355 if _, ok := result.([]interface{}); ok {
356 return []interface{}{"resolve", exportID, []interface{}{result}}, nil
357 }
358 return []interface{}{"resolve", exportID, result}, nil
359 }
360
361 // Check if we have a pending operation to execute
362 if operation, exists := sessionData.PendingOperations[exportID]; exists {
363 sessionData.mu.RUnlock()
364
365 // Resolve any pipeline references in the arguments
366 var args interface{}
367 if err := json.Unmarshal(operation.Args, &args); err != nil {
368 return s.createErrorResponse(exportID, "ArgumentError", err.Error()), nil
369 }
370
371 resolvedArgs, err := s.resolvePipelineReferences(sessionData, args)
372 if err != nil {
373 return s.createErrorResponse(exportID, "PipelineError", err.Error()), nil
374 }
375
376 resolvedArgsBytes, err := json.Marshal(resolvedArgs)
377 if err != nil {
378 return s.createErrorResponse(exportID, "SerializationError", err.Error()), nil
379 }
380
381 // Dispatch the method call to the target
382 result, err := sessionData.Target.Dispatch(operation.Method, resolvedArgsBytes)
383
384 // Clean up the operation
385 sessionData.mu.Lock()
386 delete(sessionData.PendingOperations, exportID)
387 sessionData.mu.Unlock()
388
389 if err != nil {
390 return s.createErrorResponse(exportID, "MethodError", err.Error()), nil
391 }
392
393 // Normalize the result to ensure it's JSON-compatible for pipeline traversal
394 normalizedResult, err := s.normalizeResult(result)
395 if err != nil {
396 return s.createErrorResponse(exportID, "SerializationError", err.Error()), nil
397 }
398
399 // Store the normalized result for future reference
400 sessionData.mu.Lock()
401 sessionData.PendingResults[exportID] = normalizedResult
402 sessionData.mu.Unlock()
403
404 // Send as resolve
405 if _, ok := normalizedResult.([]interface{}); ok {
406 return []interface{}{"resolve", exportID, []interface{}{normalizedResult}}, nil
407 }
408 return []interface{}{"resolve", exportID, normalizedResult}, nil
409 }
410 sessionData.mu.RUnlock()
411
412 // Export ID not found - send an error
413 return []interface{}{"reject", exportID, []interface{}{
414 "error", "ExportNotFound", "Export ID not found",
415 }}, nil
416}
417
418func (s *RpcSession) createErrorResponse(exportID int, errorType, message string) []interface{} {
419 return []interface{}{"reject", exportID, []interface{}{
420 "error", errorType, message,
421 }}
422}
423
424func (s *RpcSession) handleRelease(sessionData *SessionData, exportID, refcount int) {
425 log.Printf("Released export %d with refcount %d", exportID, refcount)
426}
427
428func (s *RpcSession) handleAbort(sessionData *SessionData, errorData interface{}) {
429 errorBytes, _ := json.Marshal(errorData)
430 log.Printf("Abort received: %s", string(errorBytes))
431}
432
433// normalizeResult ensures that Go structs are converted to map[string]interface{}
434// for proper pipeline traversal. This is necessary because Go structs need to be
435// JSON-marshaled and then unmarshaled to become navigable objects.
436func (s *RpcSession) normalizeResult(result interface{}) (interface{}, error) {
437 // If it's already a map[string]interface{} or basic type, return as-is
438 switch result.(type) {
439 case map[string]interface{}, []interface{}, string, float64, bool, nil:
440 return result, nil
441 }
442
443 // For other types (like structs), marshal to JSON and unmarshal to interface{}
444 // This converts structs to map[string]interface{} which can be traversed
445 jsonBytes, err := json.Marshal(result)
446 if err != nil {
447 return nil, fmt.Errorf("failed to marshal result: %w", err)
448 }
449
450 var normalized interface{}
451 if err := json.Unmarshal(jsonBytes, &normalized); err != nil {
452 return nil, fmt.Errorf("failed to unmarshal result: %w", err)
453 }
454
455 return normalized, nil
456}