this repo has no description
1#!/usr/bin/env python3
2# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
3import contextvars
4import unittest
5
6
7class ContextTests(unittest.TestCase):
8 def test_class_getitem_returns_cls(self):
9 self.assertIs(
10 contextvars.ContextVar.__class_getitem__(int), contextvars.ContextVar
11 )
12
13 def test_class_getitem_with_wrong_cls_raises_type_error(self):
14 class C:
15 m = contextvars.ContextVar.__dict__["__class_getitem__"]
16
17 with self.assertRaisesRegex(
18 TypeError,
19 "descriptor '__class_getitem__' requires a subtype of 'ContextVar' but received 'C'",
20 ):
21 C.m(None)
22
23 def test_dunder_contains_with_invalid_self(self):
24 with self.assertRaises(TypeError):
25 contextvars.Context.__contains__(None, None)
26
27 def test_dunder_eq_with_invalid_self(self):
28 with self.assertRaises(TypeError):
29 contextvars.Context.__eq__(None, None)
30
31 def test_dunder_getitem_with_invalid_self(self):
32 with self.assertRaises(TypeError):
33 contextvars.Context.__getitem__(None, None)
34
35 def test_dunder_iter_with_invalid_self(self):
36 with self.assertRaises(TypeError):
37 contextvars.Context.__iter__(None)
38
39 def test_dunder_len_with_invalid_self(self):
40 with self.assertRaises(TypeError):
41 contextvars.Context.__len__(None)
42
43 def test_dunder_new_with_invalid_cls(self):
44 with self.assertRaises(TypeError):
45 contextvars.Context.__new__(None)
46
47 def test_copy_with_invalid_self(self):
48 with self.assertRaises(TypeError):
49 contextvars.Context.copy(None)
50
51 def test_get_with_invalid_self(self):
52 with self.assertRaises(TypeError):
53 contextvars.Context.get(None, None)
54
55 def test_items_with_invalid_self(self):
56 with self.assertRaises(TypeError):
57 contextvars.Context.items(None)
58
59 def test_keys_with_invalid_self(self):
60 with self.assertRaises(TypeError):
61 contextvars.Context.keys(None)
62
63 def test_run_with_invalid_self(self):
64 with self.assertRaises(TypeError):
65 contextvars.Context.keys(None, None)
66
67 def test_values_with_invalid_self(self):
68 with self.assertRaises(TypeError):
69 contextvars.Context.values(None)
70
71 def test_dunder_contains_only_allows_context_var_types(self):
72 with self.assertRaises(TypeError):
73 contextvars.Context().__contains__(None)
74
75 def test_dunder_contains_finds_var_from_context(self):
76 ctx = contextvars.Context()
77
78 def f_with_ctx():
79 var = contextvars.ContextVar("foo")
80 var.set(1)
81 self.assertTrue(ctx.__contains__(var))
82
83 ctx.run(f_with_ctx)
84
85 def test_dunder_contains_does_not_find_var_from_another_context(self):
86 var = contextvars.ContextVar("foo")
87 var.set(1)
88 ctx = contextvars.Context()
89
90 def f_with_ctx():
91 self.assertFalse(ctx.__contains__(var))
92
93 ctx.run(f_with_ctx)
94
95 def test_dunder_eq_matches_with_same_data(self):
96 ctx1 = contextvars.Context()
97 var = None
98
99 def f_with_ctx1():
100 nonlocal var
101 var = contextvars.ContextVar("foo")
102 var.set(1)
103
104 ctx1.run(f_with_ctx1)
105
106 ctx2 = contextvars.Context()
107
108 def f_with_ctx2():
109 nonlocal var
110 var.set(1)
111
112 ctx2.run(f_with_ctx2)
113 self.assertTrue(ctx1.__eq__(ctx2))
114
115 def test_dunder_eq_does_not_match_with_different_data(self):
116 ctx1 = contextvars.Context()
117 var = None
118
119 def f_with_ctx1():
120 nonlocal var
121 var = contextvars.ContextVar("foo")
122 var.set(1)
123
124 ctx1.run(f_with_ctx1)
125 ctx2 = contextvars.Context()
126
127 def f_with_ctx2():
128 nonlocal var
129 var.set(2)
130
131 ctx2.run(f_with_ctx2)
132 self.assertFalse(ctx1.__eq__(ctx2))
133
134 def test_dunder_eq_returns_not_implemented_on_mismatched_other_type(self):
135 self.assertEqual(contextvars.Context().__eq__(None), NotImplemented)
136
137 def test_dunder_getitem_only_allows_context_var_types(self):
138 with self.assertRaises(TypeError):
139 contextvars.Context().__getitem__(None)
140
141 def test_dunder_getitem_returns_var_from_context(self):
142 ctx = contextvars.Context()
143
144 def f_with_ctx():
145 nonlocal ctx
146 var = contextvars.ContextVar("foo")
147 var.set(1)
148 self.assertEqual(ctx.__getitem__(var), 1)
149
150 ctx.run(f_with_ctx)
151
152 def test_dunder_getitem_does_not_return_var_from_another_context(self):
153 var = contextvars.ContextVar("foo")
154 var.set(1)
155 ctx = contextvars.Context()
156
157 def f_with_ctx():
158 nonlocal var
159 with self.assertRaises(KeyError):
160 contextvars.Context().__getitem__(var)
161
162 ctx.run(f_with_ctx)
163
164 def test_dunder_hash_is_not_set(self):
165 self.assertIsNone(contextvars.Context.__hash__)
166
167 def test_dunder_len_with_0_items(self):
168 ctx = contextvars.Context()
169 self.assertEqual(ctx.__len__(), 0)
170
171 def test_dunder_len_with_1_items(self):
172 ctx = contextvars.Context()
173
174 def f_with_ctx():
175 var = contextvars.ContextVar("foo")
176 var.set(1)
177
178 ctx.run(f_with_ctx)
179 self.assertEqual(ctx.__len__(), 1)
180
181 def test_dunder_len_with_2_items(self):
182 ctx = contextvars.Context()
183
184 def f_with_ctx():
185 var1 = contextvars.ContextVar("foo")
186 var1.set(1)
187 var2 = contextvars.ContextVar("bar")
188 var2.set(2)
189
190 ctx.run(f_with_ctx)
191 self.assertEqual(ctx.__len__(), 2)
192
193 def test_dunder_iter_with_0_items(self):
194 ctx = contextvars.Context()
195 self.assertEqual(list(ctx), [])
196
197 def test_dunder_iter_with_1_items(self):
198 ctx = contextvars.Context()
199
200 def f_with_ctx():
201 nonlocal ctx
202 var = contextvars.ContextVar("foo")
203 var.set(1)
204 self.assertEqual(list(ctx), [var])
205
206 ctx.run(f_with_ctx)
207
208 def test_dunder_iter_with_2_items(self):
209 ctx = contextvars.Context()
210
211 def f_with_ctx():
212 nonlocal ctx
213 var1 = contextvars.ContextVar("foo")
214 var1.set(1)
215 var2 = contextvars.ContextVar("bar")
216 var2.set(2)
217 self.assertEqual(set(ctx), {var1, var2})
218
219 ctx.run(f_with_ctx)
220
221 def test_copy_produces_context_with_same_data(self):
222 ctx = contextvars.Context()
223
224 def f_with_ctx():
225 var1 = contextvars.ContextVar("foo")
226 var1.set(1)
227 var2 = contextvars.ContextVar("bar")
228 var2.set(2)
229
230 ctx.run(f_with_ctx)
231 ctx_copy = ctx.copy()
232 self.assertEqual(list(ctx_copy.items()), list(ctx.items()))
233
234 def test_copy_produces_new_context(self):
235 ctx = contextvars.Context()
236 ctx_copy = ctx.copy()
237 self.assertIsNot(ctx, ctx_copy)
238
239 def test_copy_result_has_no_prev_context(self):
240 ctx = contextvars.Context()
241
242 def f():
243 def g():
244 pass
245
246 ctx_copy = ctx.copy()
247 # This would raise a RuntimeError if 'ctx' were used
248 ctx_copy.run(g)
249
250 def test_copy_context_mutation_does_not_affect_source_context(self):
251 ctx = contextvars.Context()
252 ctx_copy = ctx.copy()
253
254 def f_with_ctx_copy():
255 nonlocal ctx
256 var = contextvars.ContextVar("foo")
257 var.set(1)
258 self.assertFalse(ctx.__contains__(var))
259
260 ctx_copy.run(f_with_ctx_copy)
261
262 def test_copy_source_mutation_does_not_affect_copy_context(self):
263 ctx = contextvars.Context()
264 ctx_copy = ctx.copy()
265
266 def f_with_ctx():
267 nonlocal ctx_copy
268 var = contextvars.ContextVar("foo")
269 var.set(1)
270 self.assertFalse(ctx_copy.__contains__(var))
271
272 ctx.run(f_with_ctx)
273
274 def test_get_with_invalid_key_type_raises(self):
275 ctx = contextvars.Context()
276 with self.assertRaises(TypeError):
277 ctx.get(None)
278
279 def test_get_valid_key_returns_value(self):
280 ctx = contextvars.Context()
281
282 def f_with_ctx():
283 var = contextvars.ContextVar("foo")
284 var.set(1)
285 self.assertEqual(ctx.get(var), 1)
286
287 ctx.run(f_with_ctx)
288
289 def test_get_missing_key_returns_specified_default(self):
290 var = contextvars.ContextVar("foo")
291 var.set(1)
292 ctx = contextvars.Context()
293 self.assertEqual(ctx.get(var, 2), 2)
294
295 def test_get_missing_key_is_none_with_no_specified_default(self):
296 var = contextvars.ContextVar("foo")
297 var.set(1)
298 ctx = contextvars.Context()
299 self.assertIsNone(ctx.get(var))
300
301 def test_items(self):
302 ctx = contextvars.Context()
303
304 def f_with_ctx():
305 nonlocal ctx
306 var1 = contextvars.ContextVar("foo")
307 var1.set(1)
308 var2 = contextvars.ContextVar("bar")
309 var2.set(2)
310 self.assertEqual(set(ctx.items()), {(var1, 1), (var2, 2)})
311
312 ctx.run(f_with_ctx)
313
314 def test_keys(self):
315 ctx = contextvars.Context()
316
317 def f_with_ctx():
318 nonlocal ctx
319 var1 = contextvars.ContextVar("foo")
320 var1.set(1)
321 var2 = contextvars.ContextVar("bar")
322 var2.set(2)
323 self.assertEqual(set(ctx.keys()), {var1, var2})
324
325 ctx.run(f_with_ctx)
326
327 def test_values(self):
328 ctx = contextvars.Context()
329
330 def f_with_ctx():
331 nonlocal ctx
332 var1 = contextvars.ContextVar("foo")
333 var1.set(1)
334 var2 = contextvars.ContextVar("bar")
335 var2.set(2)
336 self.assertEqual(set(ctx.values()), {1, 2})
337
338 ctx.run(f_with_ctx)
339
340 def test_run_passes_through_args(self):
341 def f(a, *, b=None):
342 a.assertEqual(b, 1)
343
344 contextvars.Context().run(f, self, b=1)
345
346 def test_run_cannot_have_nested_calls_to_same_context(self):
347 ctx = contextvars.Context()
348
349 def f():
350 def g():
351 pass
352
353 ctx.run(g)
354
355 with self.assertRaises(RuntimeError):
356 ctx.run(f)
357
358 def test_run_can_be_called_multiple_times_on_context(self):
359 ctx = contextvars.Context()
360
361 def f():
362 pass
363
364 ctx.run(f)
365 ctx.run(f)
366
367 def test_run_can_have_nested_calls_to_copied_context(self):
368 ctx = contextvars.Context()
369 ctx_copy = ctx.copy()
370
371 def f():
372 def g():
373 pass
374
375 ctx_copy.run(g)
376
377 ctx.run(f)
378
379 def test_run_protects_previous_var_values(self):
380 def f_with_empty_context():
381 var = contextvars.ContextVar("foo")
382 var.set(1)
383 ctx = contextvars.Context()
384 ctx.run(lambda: var.set(2))
385 self.assertEqual(var.get(), 1)
386
387 contextvars.Context().run(f_with_empty_context)
388
389
390class ContextVarTests(unittest.TestCase):
391 def test_dunder_new_with_invalid_cls(self):
392 with self.assertRaises(TypeError):
393 contextvars.ContextVar.__new__(None, None)
394
395 def test_dunder_repr_with_invalid_self(self):
396 with self.assertRaises(TypeError):
397 contextvars.ContextVar.__repr__(None)
398
399 def test_get_with_invalid_self(self):
400 with self.assertRaises(TypeError):
401 contextvars.ContextVar.get(None)
402
403 def test_set_with_invalid_self(self):
404 with self.assertRaises(TypeError):
405 contextvars.ContextVar.set(None, None)
406
407 def test_reset_with_invalid_self(self):
408 with self.assertRaises(TypeError):
409 contextvars.ContextVar.reset(None, None)
410
411 def test_dunder_repr_with_default(self):
412 c = contextvars.ContextVar("foo", default=None)
413 self.assertRegex(
414 c.__repr__(), r"<ContextVar name='foo' default=None at 0x[0-9a-f]+>"
415 )
416
417 def test_dunder_new_with_non_string_name(self):
418 with self.assertRaises(TypeError):
419 contextvars.ContextVar.__new__(contextvars.ContextVar, None)
420
421 def test_dunder_repr_without_default(self):
422 var = contextvars.ContextVar("foo")
423 self.assertRegex(var.__repr__(), r"<ContextVar name='foo' at 0x[0-9a-f]+>")
424
425 def test_dunder_repr_recursive(self):
426 backref = []
427 var = contextvars.ContextVar("foo", default=backref)
428 backref.append(var)
429 self.assertIn("...", var.__repr__())
430
431 def test_name_property(self):
432 var = contextvars.ContextVar("foo")
433 self.assertEqual(var.name, "foo")
434
435 def test_get_unset_value_with_no_defaults(self):
436 def f_with_empty_context():
437 var = contextvars.ContextVar("foo")
438 with self.assertRaises(LookupError):
439 var.get()
440
441 contextvars.Context().run(f_with_empty_context)
442
443 def test_get_value_set_in_current_context(self):
444 var = contextvars.ContextVar("foo")
445 var.set(1)
446 self.assertEqual(var.get(), 1)
447
448 def test_get_value_set_in_current_context_overrides_defaults(self):
449 var = contextvars.ContextVar("foo", default=1)
450 var.set(2)
451 self.assertEqual(var.get(3), 2)
452
453 def test_get_with_default_takes_priority_over_context_var_default(self):
454 def f_with_empty_context():
455 var = contextvars.ContextVar("foo", default=1)
456 self.assertEqual(var.get(2), 2)
457
458 contextvars.Context().run(f_with_empty_context)
459
460 def test_get_with_unset_and_no_default_returns_context_var_default(self):
461 def f_with_empty_context():
462 var = contextvars.ContextVar("foo", default=1)
463 self.assertEqual(var.get(), 1)
464
465 contextvars.Context().run(f_with_empty_context)
466
467 def test_set_with_no_previous_value_returns_token_with_missing_old_value(self):
468 def f_with_empty_context():
469 var = contextvars.ContextVar("foo")
470 token = var.set(None)
471 self.assertIs(token.old_value, contextvars.Token.MISSING)
472
473 contextvars.Context().run(f_with_empty_context)
474
475 def test_set_with_previous_value_returns_token_with_old_value(self):
476 def f_with_empty_context():
477 var = contextvars.ContextVar("foo")
478 var.set(1)
479 token = var.set(2)
480 self.assertEqual(token.old_value, 1)
481
482 contextvars.Context().run(f_with_empty_context)
483
484 def test_set_is_available_in_subsequent_get(self):
485 var = contextvars.ContextVar("foo")
486 var.set(1)
487 self.assertEqual(var.get(), 1)
488
489 def test_set_does_not_affect_other_contexts(self):
490 ctx1 = contextvars.Context()
491 var = None
492
493 def f1_with_ctx1():
494 nonlocal var
495 var = contextvars.ContextVar("foo")
496 var.set(1)
497
498 ctx1.run(f1_with_ctx1)
499
500 ctx2 = contextvars.Context()
501
502 def f_with_ctx2():
503 nonlocal var
504 var.set(2)
505
506 ctx2.run(f_with_ctx2)
507
508 def f2_with_ctx1():
509 self.assertEqual(var.get(), 1)
510
511 ctx1.run(f2_with_ctx1)
512
513 def test_set_returns_token_referring_to_self(self):
514 var = contextvars.ContextVar("foo")
515 token = var.set(1)
516 self.assertIs(token.var, var)
517
518 def test_reset_with_used_token_raises_runtime_error(self):
519 var = contextvars.ContextVar("foo")
520 token = var.set(1)
521 var.reset(token)
522 with self.assertRaises(RuntimeError):
523 var.reset(token)
524
525 def test_reset_with_token_from_another_context_var_raises_value_error(self):
526 var1 = contextvars.ContextVar("foo")
527 token = var1.set(1)
528 var2 = contextvars.ContextVar("foo")
529 with self.assertRaises(ValueError):
530 var2.reset(token)
531
532 def test_reset_with_token_from_another_context_raises_value_error(self):
533 ctx1 = contextvars.Context()
534 var = None
535 token = None
536
537 def f_with_ctx1():
538 nonlocal var, token
539 var = contextvars.ContextVar("foo")
540 token = var.set(1)
541
542 ctx1.run(f_with_ctx1)
543
544 ctx2 = contextvars.Context()
545
546 def f_with_ctx2():
547 with self.assertRaises(ValueError):
548 var.reset(token)
549
550 ctx2.run(f_with_ctx2)
551
552 def test_reset_with_no_previous_value_removes_value_from_context(self):
553 def f_with_empty_context():
554 var = contextvars.ContextVar("foo")
555 token = var.set(1)
556 var.reset(token)
557 with self.assertRaises(LookupError):
558 var.get()
559
560 contextvars.Context().run(f_with_empty_context)
561
562 def test_reset_restores_previous_value_to_context(self):
563 var = contextvars.ContextVar("foo")
564 var.set(1)
565 token = var.set(2)
566 var.reset(token)
567 self.assertEqual(var.get(), 1)
568
569
class CopyContextTests(unittest.TestCase):
    """Tests for contextvars.copy_context()."""

    def test_mutating_copied_context_does_not_affect_current_context(self):
        def f_with_empty_context():
            ctx_copy = contextvars.copy_context()
            var = contextvars.ContextVar("foo")

            # Mutating the current context is invisible to the earlier copy.
            var.set(1)
            self.assertFalse(ctx_copy.__contains__(var))

            # Mutating the copy (the direction this test is named for) must
            # not leak back into the current context.
            ctx_copy.run(var.set, 2)
            self.assertEqual(var.get(), 1)
            self.assertEqual(ctx_copy.__getitem__(var), 2)

        # Run in a fresh context so the runner's own context is untouched.
        contextvars.Context().run(f_with_empty_context)
579
580
class TokenTests(unittest.TestCase):
    """Tests for contextvars.Token, the object returned by ContextVar.set."""

    def test_dunder_repr_with_invalid_self(self):
        self.assertRaises(TypeError, contextvars.Token.__repr__, None)

    def test_dunder_new_raises_runtime_error(self):
        token_type = contextvars.Token
        self.assertRaises(
            RuntimeError, token_type.__new__, token_type, None, None, None
        )

    def test_missing_value_exists(self):
        has_missing = hasattr(contextvars.Token, "MISSING")
        self.assertTrue(has_missing)

    def test_dunder_repr_unused(self):
        cvar = contextvars.ContextVar("foo")
        token = cvar.set(1)
        pattern = r"<Token var=<ContextVar name='foo' at 0x[0-9a-f]+> at 0x[0-9a-f]+>"
        self.assertRegex(token.__repr__(), pattern)

    def test_dunder_repr_used(self):
        cvar = contextvars.ContextVar("foo")
        token = cvar.set(1)
        cvar.reset(token)
        pattern = (
            r"<Token used var=<ContextVar name='foo' at 0x[0-9a-f]+> at 0x[0-9a-f]+>"
        )
        self.assertRegex(token.__repr__(), pattern)

    def test_dunder_repr_recursive(self):
        # The variable's default ends up holding the token, so the token's
        # repr reaches itself through var -> default -> token.
        cycle = []
        cvar = contextvars.ContextVar("foo", default=cycle)
        token = cvar.set(1)
        cycle.append(token)
        self.assertIn("...", token.__repr__())

    def test_var_property(self):
        cvar = contextvars.ContextVar("foo")
        self.assertIs(cvar.set(1).var, cvar)

    def test_old_value_property_when_there_is_an_old_value(self):
        previous = object()
        cvar = contextvars.ContextVar("foo")
        cvar.set(previous)
        self.assertIs(cvar.set(object()).old_value, previous)

    def test_old_value_property_when_missing(self):
        cvar = contextvars.ContextVar("foo")
        self.assertIs(cvar.set(None).old_value, contextvars.Token.MISSING)
633
634
if __name__ == "__main__":
    # Executed directly as a script: run every TestCase in this module.
    unittest.main()