this repo has no description
1// Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
2#include "under-contextvars-module.h"
3
4#include "builtins.h"
5#include "dict-builtins.h"
6#include "type-builtins.h"
7
8namespace py {
9
10static const BuiltinAttribute kContextAttributes[] = {
11 {ID(_context__data), RawContext::kDataOffset, AttributeFlags::kHidden},
12 {ID(_context__prev_context), RawContext::kPrevContextOffset,
13 AttributeFlags::kHidden},
14};
15
16static const BuiltinAttribute kContextVarAttributes[] = {
17 {ID(_context_var__default_value), RawContextVar::kDefaultValueOffset,
18 AttributeFlags::kHidden},
19 {ID(name), RawContextVar::kNameOffset, AttributeFlags::kReadOnly},
20};
21
22static const BuiltinAttribute kTokenAttributes[] = {
23 {ID(_token__context), RawToken::kContextOffset, AttributeFlags::kHidden},
24 {ID(old_value), RawToken::kOldValueOffset},
25 {ID(_token__used), RawToken::kUsedOffset, AttributeFlags::kHidden},
26 {ID(var), RawToken::kVarOffset, AttributeFlags::kReadOnly},
27};
28
29void initializeUnderContextvarsTypes(Thread* thread) {
30 addBuiltinType(thread, ID(Context), LayoutId::kContext,
31 /*superclass_id=*/LayoutId::kObject, kContextAttributes,
32 Context::kSize, /*basetype=*/false);
33
34 addBuiltinType(thread, ID(ContextVar), LayoutId::kContextVar,
35 /*superclass_id=*/LayoutId::kObject, kContextVarAttributes,
36 ContextVar::kSize, /*basetype=*/false);
37
38 addBuiltinType(thread, ID(Token), LayoutId::kToken,
39 /*superclass_id=*/LayoutId::kObject, kTokenAttributes,
40 Token::kSize, /*basetype=*/false);
41}
42
43RawObject FUNC(_contextvars, _ContextVar_default_value)(Thread* thread,
44 Arguments args) {
45 HandleScope scope(thread);
46 Object ctxvar_obj(&scope, args.get(0));
47 if (!ctxvar_obj.isContextVar()) {
48 return thread->raiseWithFmt(
49 LayoutId::kTypeError,
50 "'_contextvar__default_value_get' for 'ContextVar' objects doesn't "
51 "apply to a '%T' object",
52 &ctxvar_obj);
53 }
54 ContextVar ctxvar(&scope, *ctxvar_obj);
55 return ctxvar.defaultValue();
56}
57
58RawObject FUNC(_contextvars, _ContextVar_name)(Thread* thread, Arguments args) {
59 HandleScope scope(thread);
60 Object ctxvar_obj(&scope, args.get(0));
61 if (!ctxvar_obj.isContextVar()) {
62 return thread->raiseWithFmt(LayoutId::kTypeError,
63 "'_contextvar__name_get' for 'ContextVar' "
64 "objects doesn't apply to a '%T' object",
65 &ctxvar_obj);
66 }
67 ContextVar ctxvar(&scope, *ctxvar_obj);
68 return ctxvar.name();
69}
70
71RawObject FUNC(_contextvars, _Token_used)(Thread* thread, Arguments args) {
72 HandleScope scope(thread);
73 Object token_obj(&scope, args.get(0));
74 if (!token_obj.isToken()) {
75 return thread->raiseWithFmt(
76 LayoutId::kTypeError,
77 "'_Token_used' for 'Token' objects doesn't apply to a '%T' object",
78 &token_obj);
79 }
80 Token token(&scope, *token_obj);
81 return Bool::fromBool(token.used());
82}
83
84RawObject FUNC(_contextvars, _Token_var)(Thread* thread, Arguments args) {
85 HandleScope scope(thread);
86 Object token_obj(&scope, args.get(0));
87 if (!token_obj.isToken()) {
88 return thread->raiseWithFmt(
89 LayoutId::kTypeError,
90 "'_Token_var' for 'Token' objects doesn't apply to a '%T' object",
91 &token_obj);
92 }
93 Token token(&scope, *token_obj);
94 return token.var();
95}
96
97static RawObject contextForThread(Thread* thread) {
98 HandleScope scope(thread);
99 Object ctx_obj(&scope, thread->contextvarsContext());
100 if (ctx_obj.isNoneType()) {
101 Runtime* runtime = thread->runtime();
102 Dict data(&scope, runtime->newDict());
103 Context ctx(&scope, runtime->newContext(data));
104 thread->setContextvarsContext(*ctx);
105 return *ctx;
106 }
107 return *ctx_obj;
108}
109
110RawObject FUNC(_contextvars, _thread_context)(Thread* thread, Arguments) {
111 return contextForThread(thread);
112}
113
114static RawObject dataDictFromContext(Thread* thread, Arguments args) {
115 HandleScope scope(thread);
116 Object self_obj(&scope, args.get(0));
117 if (!self_obj.isContext()) {
118 return thread->raiseRequiresType(self_obj, ID(Context));
119 }
120 Context self(&scope, *self_obj);
121 return self.data();
122}
123
124static RawObject lookupVarInContext(Thread* thread, Arguments args,
125 bool contains_mode) {
126 HandleScope scope(thread);
127 Object var_obj(&scope, args.get(1));
128 if (!var_obj.isContextVar()) {
129 return thread->raiseRequiresType(var_obj, ID(ContextVar));
130 }
131 ContextVar var(&scope, *var_obj);
132 Object data_obj(&scope, dataDictFromContext(thread, args));
133 if (data_obj.isError()) return *data_obj;
134 Dict data(&scope, *data_obj);
135 Object var_hash_obj(&scope, Interpreter::hash(thread, var));
136 if (var_hash_obj.isError()) return *var_hash_obj;
137 word var_hash = SmallInt::cast(*var_hash_obj).value();
138 return contains_mode ? dictIncludes(thread, data, var, var_hash)
139 : dictAt(thread, data, var, var_hash);
140}
141
142RawObject METH(Context, __contains__)(Thread* thread, Arguments args) {
143 return lookupVarInContext(thread, args, true);
144}
145
146RawObject METH(Context, __eq__)(Thread* thread, Arguments args) {
147 HandleScope scope(thread);
148
149 Object data_obj(&scope, dataDictFromContext(thread, args));
150 if (data_obj.isError()) return *data_obj;
151 Dict data(&scope, *data_obj);
152 Object other_ctx_obj(&scope, args.get(1));
153 if (!other_ctx_obj.isContext()) {
154 return NotImplementedType::object();
155 }
156 Context other_ctx(&scope, *other_ctx_obj);
157 Dict other_data(&scope, other_ctx.data());
158
159 return dictEq(thread, data, other_data);
160}
161
162RawObject METH(Context, __getitem__)(Thread* thread, Arguments args) {
163 HandleScope scope(thread);
164 Object result(&scope, lookupVarInContext(thread, args, false));
165 if (result.isErrorNotFound()) {
166 return thread->raise(LayoutId::kKeyError, NoneType::object());
167 }
168 return *result;
169}
170
171RawObject METH(Context, __iter__)(Thread* thread, Arguments args) {
172 return METH(Context, keys)(thread, args);
173}
174
175RawObject METH(Context, __new__)(Thread* thread, Arguments args) {
176 HandleScope scope(thread);
177 Runtime* runtime = thread->runtime();
178 if (args.get(0) != runtime->typeAt(LayoutId::kContext)) {
179 return thread->raiseWithFmt(LayoutId::kTypeError,
180 "Context.__new__(X): X is not 'Context'");
181 }
182 Dict data(&scope, runtime->newDict());
183 Context ctx(&scope, runtime->newContext(data));
184 return *ctx;
185}
186
187RawObject METH(Context, __len__)(Thread* thread, Arguments args) {
188 HandleScope scope(thread);
189 Object data_obj(&scope, dataDictFromContext(thread, args));
190 if (data_obj.isError()) return *data_obj;
191 Dict data(&scope, *data_obj);
192 return SmallInt::fromWord(data.numItems());
193}
194
195RawObject METH(Context, copy)(Thread* thread, Arguments args) {
196 HandleScope scope(thread);
197 Object data_obj(&scope, dataDictFromContext(thread, args));
198 if (data_obj.isError()) return *data_obj;
199 Dict data(&scope, *data_obj);
200 return thread->runtime()->newContext(data);
201}
202
203RawObject METH(Context, get)(Thread* thread, Arguments args) {
204 HandleScope scope(thread);
205 Object val(&scope, lookupVarInContext(thread, args, false));
206 if (val.isErrorNotFound()) {
207 return args.get(2);
208 }
209 return *val;
210}
211
212RawObject METH(Context, items)(Thread* thread, Arguments args) {
213 HandleScope scope(thread);
214 Object data_obj(&scope, dataDictFromContext(thread, args));
215 if (data_obj.isError()) return *data_obj;
216 Dict data(&scope, *data_obj);
217 return thread->runtime()->newDictItemIterator(thread, data);
218}
219
220RawObject METH(Context, keys)(Thread* thread, Arguments args) {
221 HandleScope scope(thread);
222 Object data_obj(&scope, dataDictFromContext(thread, args));
223 if (data_obj.isError()) return *data_obj;
224 Dict data(&scope, *data_obj);
225 return thread->runtime()->newDictKeyIterator(thread, data);
226}
227
228RawObject METH(Context, run)(Thread* thread, Arguments args) {
229 HandleScope scope(thread);
230 Object self_obj(&scope, args.get(0));
231 if (!self_obj.isContext()) {
232 return thread->raiseRequiresType(self_obj, ID(Context));
233 }
234 Context self(&scope, *self_obj);
235
236 // Set Context.prev_context to current thread-global Context
237 if (!self.prevContext().isNoneType()) {
238 Str self_repr(&scope, thread->invokeMethod1(self, ID(__repr__)));
239 return thread->raiseWithFmt(LayoutId::kRuntimeError,
240 "cannot enter context: %S is already entered",
241 &self_repr);
242 }
243 Context ctx(&scope, contextForThread(thread));
244 self.setPrevContext(*ctx);
245
246 thread->setContextvarsContext(*self);
247
248 // Call callable forwarding all args
249 thread->stackPush(args.get(1)); // callable
250 thread->stackPush(args.get(2)); // *args
251 thread->stackPush(args.get(3)); // **kwargs
252 Object call_result(
253 &scope, Interpreter::callEx(thread, CallFunctionExFlag::VAR_KEYWORDS));
254
255 // Always restore the thread's previous Context even if call above failed
256 thread->setContextvarsContext(self.prevContext());
257 self.setPrevContext(NoneType::object());
258
259 return *call_result;
260}
261
262RawObject METH(Context, values)(Thread* thread, Arguments args) {
263 HandleScope scope(thread);
264 Object data_obj(&scope, dataDictFromContext(thread, args));
265 if (data_obj.isError()) return *data_obj;
266 Dict data(&scope, *data_obj);
267 return thread->runtime()->newDictValueIterator(thread, data);
268}
269
270RawObject METH(ContextVar, __new__)(Thread* thread, Arguments args) {
271 HandleScope scope(thread);
272 Runtime* runtime = thread->runtime();
273 if (args.get(0) != runtime->typeAt(LayoutId::kContextVar)) {
274 return thread->raiseWithFmt(LayoutId::kTypeError,
275 "ContextVar.__new__(X): X is not 'ContextVar'");
276 }
277
278 Object name_obj(&scope, args.get(1));
279 if (!name_obj.isStr()) {
280 return thread->raiseWithFmt(LayoutId::kTypeError,
281 "context variable name must be a str");
282 }
283 Str name(&scope, *name_obj);
284
285 Object default_value(&scope, args.get(2));
286
287 return runtime->newContextVar(name, default_value);
288}
289
290RawObject METH(ContextVar, get)(Thread* thread, Arguments args) {
291 HandleScope scope(thread);
292 Object self_obj(&scope, args.get(0));
293 if (!self_obj.isContextVar()) {
294 return thread->raiseRequiresType(self_obj, ID(ContextVar));
295 }
296 ContextVar self(&scope, *self_obj);
297
298 // Check for value held in thread-global Context
299 Context ctx(&scope, contextForThread(thread));
300 Dict ctx_data(&scope, ctx.data());
301 Object self_hash_obj(&scope, Interpreter::hash(thread, self));
302 if (self_hash_obj.isError()) return *self_hash_obj;
303 word self_hash = SmallInt::cast(*self_hash_obj).value();
304 Object result(&scope, dictAt(thread, ctx_data, self, self_hash));
305 if (!result.isError() || !result.isErrorNotFound()) {
306 return *result;
307 }
308
309 // No data in thread-global Context, check default argument
310 Object arg_default(&scope, args.get(1));
311 if (!arg_default.isUnbound()) {
312 return *arg_default;
313 }
314
315 // No default argument, check ContextVar default
316 Object default_value(&scope, self.defaultValue());
317 if (!default_value.isUnbound()) {
318 return *default_value;
319 }
320
321 return thread->raise(LayoutId::kLookupError, NoneType::object());
322}
323
324RawObject METH(ContextVar, reset)(Thread* thread, Arguments args) {
325 HandleScope scope(thread);
326 Object self_obj(&scope, args.get(0));
327 if (!self_obj.isContextVar()) {
328 return thread->raiseRequiresType(self_obj, ID(ContextVar));
329 }
330 ContextVar self(&scope, *self_obj);
331 Object token_obj(&scope, args.get(1));
332 if (!token_obj.isToken()) {
333 return thread->raiseRequiresType(self_obj, ID(Token));
334 }
335 Token token(&scope, *token_obj);
336
337 if (token.used()) {
338 return thread->raiseWithFmt(LayoutId::kRuntimeError,
339 "Token has already been used once");
340 }
341
342 if (token.var() != self) {
343 return thread->raiseWithFmt(LayoutId::kValueError,
344 "Token was created by a different ContextVar");
345 }
346
347 Context ctx(&scope, contextForThread(thread));
348 if (token.context() != ctx) {
349 return thread->raiseWithFmt(LayoutId::kValueError,
350 "Token was created in a different Context");
351 }
352
353 // Copy thread-global Context data for update
354 Dict ctx_data(&scope, ctx.data());
355 Object self_hash_obj(&scope, Interpreter::hash(thread, self));
356 if (self_hash_obj.isError()) return *self_hash_obj;
357 word self_hash = SmallInt::cast(*self_hash_obj).value();
358 Object ctx_data_copy_obj(&scope, dictCopy(thread, ctx_data));
359 if (ctx_data_copy_obj.isError()) return *ctx_data_copy_obj;
360 Dict ctx_data_copy(&scope, *ctx_data_copy_obj);
361
362 // Update thread-global Context data based on Token.old_value
363 Object dict_op_res(&scope, NoneType::object());
364 Object old_value(&scope, token.oldValue());
365 if (old_value.isUnbound()) {
366 dict_op_res = dictRemove(thread, ctx_data_copy, self, self_hash);
367 } else {
368 dict_op_res = dictAtPut(thread, ctx_data_copy, self, self_hash, old_value);
369 }
370 if (dict_op_res.isError()) return *dict_op_res;
371 ctx.setData(*ctx_data_copy);
372
373 token.setUsed(true);
374
375 return NoneType::object();
376}
377
378RawObject METH(ContextVar, set)(Thread* thread, Arguments args) {
379 HandleScope scope(thread);
380 Object self_obj(&scope, args.get(0));
381 if (!self_obj.isContextVar()) {
382 return thread->raiseRequiresType(self_obj, ID(ContextVar));
383 }
384 ContextVar self(&scope, *self_obj);
385
386 // Get thread-global Context and its data dict.
387 Context ctx(&scope, contextForThread(thread));
388 Dict ctx_data(&scope, ctx.data());
389 Object self_hash_obj(&scope, Interpreter::hash(thread, self));
390 if (self_hash_obj.isError()) return *self_hash_obj;
391 word self_hash = SmallInt::cast(*self_hash_obj).value();
392
393 // Get any oldvalue from the thread-global Context or Token.MISSING
394 Object old_value(&scope, dictAt(thread, ctx_data, self, self_hash));
395 if (old_value.isError()) {
396 if (old_value.isErrorNotFound()) {
397 old_value = Unbound::object();
398 } else {
399 return *old_value;
400 }
401 }
402
403 // Update thread-global Context data by copying the dict and updating it.
404 Object ctx_data_copy_obj(&scope, dictCopy(thread, ctx_data));
405 if (ctx_data_copy_obj.isError()) return *ctx_data_copy_obj;
406 Dict ctx_data_copy(&scope, *ctx_data_copy_obj);
407 Object value(&scope, args.get(1));
408 Object ctx_data_copy_put_result(
409 &scope, dictAtPut(thread, ctx_data_copy, self, self_hash, value));
410 if (ctx_data_copy_put_result.isError()) return *ctx_data_copy_put_result;
411 ctx.setData(*ctx_data_copy);
412
413 return thread->runtime()->newToken(ctx, self, old_value);
414}
415
416} // namespace py