A tool to retrieve information from Twitch.
1"""Module to handle authentication."""
2
3import asyncio
4import json
5import logging
6import webbrowser
7
8import httpx
9from aiohttp import web
10
11from .api import (
12 refresh_access_token,
13 retrieve_authorize_url,
14 retrieve_token,
15 validate_access_token,
16)
17from .settings import settings
18
19logger = logging.getLogger(__name__)
20
21
class OAuthServer:
    """HTTP server to handle the redirection from authentication flow.

    The server listens on ``settings.host``/``settings.port`` and waits for a
    request carrying a ``code`` query parameter. Once that code has been
    exchanged for an access token, the server stops itself.
    """

    # Access token string extracted from the token response, once obtained.
    access_token: str | None = None
    # Complete token response as returned by ``retrieve_token``.
    access_token_complete: dict | None = None

    def __init__(self):
        """Initialize the server."""
        self.access_token = None
        self.access_token_complete = None

    async def handler(self, request: web.Request):
        """Handle the redirect request and exchange its code for a token.

        Reads the ``code`` query parameter, passes it to ``retrieve_token``
        and, when the response contains an ``access_token``, stores the token
        data on the instance and signals ``run`` to stop the server.

        Returns a plain-text ``web.Response`` describing the outcome. On
        failure the server keeps running, so a later request can still
        complete the flow.
        """
        code = request.rel_url.query.get("code", None)
        if not code:
            logger.error("No code found in the request.")
            return web.Response(text="⚠️ Error, code parameter not found.")

        # change the code for a token
        token_data = retrieve_token(code=code)

        if "access_token" not in token_data:
            logger.error("Access token not found in response.")
            return web.Response(text="⚠️ Error, access token can't be obtained.")

        self.access_token = token_data["access_token"]
        self.access_token_complete = token_data

        # event to close the server
        logger.debug("Sending event to close web server...")
        self.event.set()

        return web.Response(
            text="✅ Authentication complete! You can close this window."
        )

    async def run(self):
        """Run the server until the handler signals that a token was obtained.

        The runner is always cleaned up before returning, even if waiting is
        interrupted, so the listening socket is not left open.
        """
        logger.debug("Starting the server...")
        self.event = asyncio.Event()

        self.server = web.Server(self.handler)
        self.runner = web.ServerRunner(self.server)

        await self.runner.setup()
        try:
            site = web.TCPSite(self.runner, settings.host, settings.port)
            await site.start()
            logger.debug("Server started")

            # Block here until ``handler`` has stored a token and set the event
            await self.event.wait()
        finally:
            # Stop the registered site and release the runner's resources
            await self.runner.cleanup()
            logger.debug("Server stopped")
72
73
74def _save_token_data(token_data: dict | None) -> None:
75 """Save the token data in the auth file."""
76 if token_data:
77 auth_file = settings.auth_file
78 auth_file.write_text(json.dumps(token_data, indent=2))
79
80
def _request_new_access_token() -> str:
    """Get a new valid access token through the browser authorization flow.

    Opens the authorization URL in the browser, runs the local ``OAuthServer``
    until it finishes, saves the received token data in the auth file and
    returns the access token.

    Raises ``AssertionError`` if the server finished without an access token.
    """
    logger.debug("Opening browser for authentication...")
    webbrowser.open(retrieve_authorize_url())

    server = OAuthServer()
    asyncio.run(server.run())

    # Explicit raise instead of ``assert`` so the check survives ``python -O``
    if not server.access_token:
        raise AssertionError("Access token not set")

    # The token is a credential: log that it was obtained, never its value
    logger.debug("Obtained access token")

    # save access token data
    _save_token_data(server.access_token_complete)

    return server.access_token
96
97
def obtain_access_token() -> str:
    """Check if there is a saved access token and use it if it's valid.

    If not, renews or gets a new one. The steps are tried in this order:

    1. Use the access token stored in the auth file if it validates.
    2. Otherwise refresh it with the stored refresh token, if there is one.
    3. Otherwise (or if any step above fails) request a new access token.

    An auth file that is not valid JSON, is not a JSON object, or has no
    access token is treated as if no usable token were saved, instead of
    raising.
    """
    auth_file = settings.auth_file
    if not auth_file.is_file():
        return _request_new_access_token()

    try:
        token_data = json.loads(auth_file.read_text())
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError
        logger.warning("Auth file is not valid JSON, requesting a new token.")
        return _request_new_access_token()

    if not isinstance(token_data, dict):
        logger.warning("Auth file has unexpected content, requesting a new token.")
        return _request_new_access_token()

    access_token = token_data.get("access_token")
    if access_token and validate_access_token(access_token):
        return access_token

    # try to refresh token
    refresh_token = token_data.get("refresh_token")
    if refresh_token:
        try:
            refreshed = refresh_access_token(refresh_token)
        except httpx.HTTPError:
            # in case of error, request a new access token
            refreshed = None

        # Only trust a refresh response that actually carries a token
        if refreshed and "access_token" in refreshed:
            _save_token_data(refreshed)
            return refreshed["access_token"]

    return _request_new_access_token()
124 return _request_new_access_token()