perf: reuse db session in multi-account endpoints (#714)

each db_session() creates a new Neon connection (~77ms overhead).
the multi-account endpoints were creating 3 separate connections:
1. require_auth -> get_session()
2. get_session_group()
3. switch_active_account() or artist lookup

now get_session_group() and switch_active_account() accept optional
db parameter to reuse an existing connection. endpoints pass their
injected db through, reducing connection overhead by ~154ms.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>

authored by zzstoatzz.io Claude Opus 4.5 and committed by GitHub e501351c 68e986d9

Changed files
+102 -64
backend
src
backend
_internal
api
+87 -54
backend/src/backend/_internal/auth.py
··· 16 16 from fastapi import Cookie, Header, HTTPException 17 17 from jose import jwk 18 18 from sqlalchemy import select 19 + from sqlalchemy.ext.asyncio import AsyncSession 19 20 20 21 from backend._internal.oauth_stores import PostgresStateStore 21 22 from backend.config import settings ··· 843 844 session_id: str 844 845 845 846 846 - async def get_session_group(session_id: str) -> list[LinkedAccount]: 847 - """get all accounts in the same session group. 847 + async def _get_session_group_impl( 848 + session_id: str, db: AsyncSession 849 + ) -> list[LinkedAccount]: 850 + """implementation of get_session_group using provided db session.""" 851 + result = await db.execute( 852 + select(UserSession.group_id).where(UserSession.session_id == session_id) 853 + ) 854 + group_id = result.scalar_one_or_none() 848 855 849 - returns empty list if session has no group_id (single account). 850 - """ 851 - async with db_session() as db: 852 - result = await db.execute( 853 - select(UserSession.group_id).where(UserSession.session_id == session_id) 856 + if not group_id: 857 + return [] 858 + 859 + result = await db.execute( 860 + select(UserSession).where( 861 + UserSession.group_id == group_id, 862 + UserSession.is_developer_token == False, # noqa: E712 854 863 ) 855 - group_id = result.scalar_one_or_none() 864 + ) 865 + sessions = result.scalars().all() 856 866 857 - if not group_id: 858 - return [] 867 + accounts = [] 868 + for session in sessions: 869 + if session.expires_at and datetime.now(UTC) > session.expires_at: 870 + continue 859 871 860 - result = await db.execute( 861 - select(UserSession).where( 862 - UserSession.group_id == group_id, 863 - UserSession.is_developer_token == False, # noqa: E712 872 + accounts.append( 873 + LinkedAccount( 874 + did=session.did, 875 + handle=session.handle, 876 + session_id=session.session_id, 864 877 ) 865 878 ) 866 - sessions = result.scalars().all() 867 879 868 - accounts = [] 869 - for session in sessions: 870 - if session.expires_at and datetime.now(UTC) > session.expires_at: 871 - continue 880 + return accounts 881 + 882 + 883 + async def get_session_group( 884 + session_id: str, db: AsyncSession | None = None 885 + ) -> list[LinkedAccount]: 886 + """get all accounts in the same session group. 887 + 888 + returns empty list if session has no group_id (single account). 872 889 873 - accounts.append( 874 - LinkedAccount( 875 - did=session.did, 876 - handle=session.handle, 877 - session_id=session.session_id, 878 - ) 879 - ) 890 + args: 891 + session_id: the session to look up 892 + db: optional database session to reuse (avoids new connection) 893 + """ 894 + if db is not None: 895 + return await _get_session_group_impl(session_id, db) 880 896 881 - return accounts 897 + async with db_session() as new_db: 898 + return await _get_session_group_impl(session_id, new_db) 882 899 883 900 884 901 async def get_or_create_group_id(session_id: str) -> str: ··· 906 923 return group_id 907 924 908 925 909 - async def switch_active_account(current_session_id: str, target_session_id: str) -> str: 926 + async def _switch_active_account_impl( 927 + current_session_id: str, target_session_id: str, db: AsyncSession 928 + ) -> str: 929 + """implementation of switch_active_account using provided db session.""" 930 + result = await db.execute( 931 + select(UserSession).where(UserSession.session_id == current_session_id) 932 + ) 933 + current_session = result.scalar_one_or_none() 934 + 935 + if not current_session or not current_session.group_id: 936 + raise HTTPException(status_code=400, detail="no session group found") 937 + 938 + result = await db.execute( 939 + select(UserSession).where(UserSession.session_id == target_session_id) 940 + ) 941 + target_session = result.scalar_one_or_none() 942 + 943 + if not target_session: 944 + raise HTTPException(status_code=404, detail="target session not found") 945 + 946 + if target_session.group_id != current_session.group_id: 947 + raise HTTPException(status_code=403, detail="target session not in same group") 948 + 949 + if target_session.expires_at and datetime.now(UTC) > target_session.expires_at: 950 + raise HTTPException(status_code=401, detail="target session expired") 951 + 952 + return target_session_id 953 + 954 + 955 + async def switch_active_account( 956 + current_session_id: str, target_session_id: str, db: AsyncSession | None = None 957 + ) -> str: 910 958 """switch to a different account within a session group. 911 959 912 960 validates that the target session exists, is in the same group, and isn't expired. 913 961 returns the target session_id (caller updates the cookie). 962 + 963 + args: 964 + current_session_id: the current session 965 + target_session_id: the session to switch to 966 + db: optional database session to reuse (avoids new connection) 914 967 """ 915 - async with db_session() as db: 916 - # get current session to find group_id 917 - result = await db.execute( 918 - select(UserSession).where(UserSession.session_id == current_session_id) 968 + if db is not None: 969 + return await _switch_active_account_impl( 970 + current_session_id, target_session_id, db 919 971 ) 920 - current_session = result.scalar_one_or_none() 921 972 922 - if not current_session or not current_session.group_id: 923 - raise HTTPException(status_code=400, detail="no session group found") 924 - 925 - # verify target session is in the same group 926 - result = await db.execute( 927 - select(UserSession).where(UserSession.session_id == target_session_id) 973 + async with db_session() as new_db: 974 + return await _switch_active_account_impl( 975 + current_session_id, target_session_id, new_db 928 976 ) 929 - target_session = result.scalar_one_or_none() 930 - 931 - if not target_session: 932 - raise HTTPException(status_code=404, detail="target session not found") 933 - 934 - if target_session.group_id != current_session.group_id: 935 - raise HTTPException( 936 - status_code=403, detail="target session not in same group" 937 - ) 938 - 939 - # check if target session is expired 940 - if target_session.expires_at and datetime.now(UTC) > target_session.expires_at: 941 - raise HTTPException(status_code=401, detail="target session expired") 942 - 943 - return target_session_id 944 977 945 978 946 979 async def remove_account_from_group(session_id: str) -> str | None:
+15 -10
backend/src/backend/api/auth.py
··· 289 289 switch_to: Annotated[ 290 290 str | None, Query(description="DID to switch to after logout") 291 291 ] = None, 292 + db=Depends(get_db), 292 293 ) -> JSONResponse: 293 294 """logout current user. 294 295 ··· 296 297 to the specified account. otherwise, fully logs out. 297 298 """ 298 299 if switch_to: 299 - # validate target is in same group 300 - linked = await get_session_group(session.session_id) 300 + # validate target is in same group (reuse db connection) 301 + linked = await get_session_group(session.session_id, db=db) 301 302 target = next((a for a in linked if a.did == switch_to), None) 302 303 303 304 if not target: ··· 354 355 db=Depends(get_db), 355 356 ) -> CurrentUserResponse: 356 357 """get current authenticated user with linked accounts.""" 357 - # get all accounts in the session group 358 - linked = await get_session_group(session.session_id) 358 + # get all accounts in the session group (reuse db connection) 359 + linked = await get_session_group(session.session_id, db=db) 359 360 360 361 # look up artist profiles to get fresh avatars 361 362 dids = [account.did for account in linked] ··· 603 604 body: SwitchAccountRequest, 604 605 response: Response, 605 606 session: Session = Depends(require_auth), 607 + db=Depends(get_db), 606 608 ) -> SwitchAccountResponse: 607 609 """switch to a different account in the session group. 608 610 ··· 611 613 612 614 returns the new active account's info. 613 615 """ 614 - # get all accounts in the group 615 - linked = await get_session_group(session.session_id) 616 + # get all accounts in the group (reuse db connection) 617 + linked = await get_session_group(session.session_id, db=db) 616 618 617 619 if not linked: 618 620 raise HTTPException( ··· 634 636 detail="already logged in as this account", 635 637 ) 636 638 637 - # switch the active account 638 - new_session_id = await switch_active_account(session.session_id, target.session_id) 639 + # switch the active account (reuse db connection) 640 + new_session_id = await switch_active_account( 641 + session.session_id, target.session_id, db=db 642 + ) 639 643 640 644 # update the cookie to point to the new session 641 645 if settings.frontend.url: ··· 660 664 @router.post("/logout-all") 661 665 async def logout_all( 662 666 session: Session = Depends(require_auth), 667 + db=Depends(get_db), 663 668 ) -> JSONResponse: 664 669 """logout all accounts in the session group. 665 670 666 671 removes all sessions in the group and clears the cookie. 667 672 """ 668 - # get all accounts in the group 669 - linked = await get_session_group(session.session_id) 673 + # get all accounts in the group (reuse db connection) 674 + linked = await get_session_group(session.session_id, db=db) 670 675 671 676 # delete all sessions (or just this one if not in a group) 672 677 if linked: