#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = ["asyncpg", "rich", "plotext", "typer", "pydantic-settings"]
# ///
"""audd api cost tracker - because $0.005 adds up

usage:
    uv run scripts/audd_costs.py             # current billing period (prod)
    uv run scripts/audd_costs.py --all       # all time stats
    uv run scripts/audd_costs.py --env stg   # another environment (prd, stg, dev)

set NEON_DATABASE_URL in .env (or NEON_DATABASE_URL_PRD, _STG, _DEV)
"""

import asyncio
import re
from collections.abc import Callable
from datetime import date, datetime, timedelta
from typing import Any

import plotext as plt
import typer
from pydantic_settings import BaseSettings, SettingsConfigDict
from rich.console import Console
from rich.panel import Panel
from rich.table import Table

# audd indie plan pricing
INCLUDED_REQUESTS = 6000  # 1000 + 5000 bonus, per billing period
COST_PER_REQUEST = 0.005  # $5 per 1000
BILLING_DAY = 24  # payment expected on the 24th (must be <= 28 to exist in every month)

# environments that have a dedicated NEON_DATABASE_URL_<ENV> variable
VALID_ENVS = ("prd", "stg", "dev")

# upper bound on x-axis labels so dates stay legible in an 80-column chart
MAX_X_LABELS = 12

# sqlalchemy-style url prefix, e.g. "postgresql+asyncpg://"
_SQLALCHEMY_DIALECT = re.compile(r"^postgresql\+\w+://")


class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    neon_database_url: str | None = None
    neon_database_url_prd: str | None = None
    neon_database_url_stg: str | None = None
    neon_database_url_dev: str | None = None

    def get_url(self, env: str) -> str:
        """get database url for environment, converting to asyncpg format

        uses NEON_DATABASE_URL_<ENV> if set, otherwise falls back to NEON_DATABASE_URL.

        raises:
            ValueError: env is not one of VALID_ENVS, or no url is configured for it.
        """
        if env not in VALID_ENVS:
            # without this, a typo such as "prod" silently fell back to the generic url
            raise ValueError(
                f"unknown environment {env!r} - expected one of: {', '.join(VALID_ENVS)}"
            )
        url = getattr(self, f"neon_database_url_{env}", None) or self.neon_database_url
        if not url:
            raise ValueError(
                f"no database url for {env} - set NEON_DATABASE_URL or NEON_DATABASE_URL_{env.upper()}"
            )
        # convert sqlalchemy dialect to plain postgres (asyncpg rejects the "+driver" part)
        return _SQLALCHEMY_DIALECT.sub("postgresql://", url, count=1)


settings = Settings()

console = Console()
app = typer.Typer(add_completion=False)


def _billing_period_start_for(day: date) -> date:
    """return the first day of the billing period that contains `day`

    a period runs from BILLING_DAY of one month up to (excluding) BILLING_DAY of the next.
    """
    if day.day >= BILLING_DAY:
        return day.replace(day=BILLING_DAY)
    # the day before the 1st lands in the previous month
    prev_month = day.replace(day=1) - timedelta(days=1)
    return prev_month.replace(day=BILLING_DAY)


def get_billing_period_start() -> datetime:
    """get the start of the current billing period as a naive local-midnight datetime"""
    start = _billing_period_start_for(datetime.now().date())
    return datetime(start.year, start.month, start.day)


def _since_filter(since: datetime | None) -> tuple[str, list[Any]]:
    """build the optional WHERE clause and its bind args limiting scans to `since` onward

    the clause is one of two constant strings; the timestamp only travels as a bind arg.
    """
    if since is None:
        return "", []
    return "WHERE scanned_at >= $1", [since]


async def query_scans(
    db_url: str, since: datetime | None = None
) -> list[dict[str, Any]]:
    """fetch per-day scan counts from postgres, ordered by date

    returns one dict per day with keys "date" (the DATE(scanned_at) value),
    "scans" and "flagged". when `since` is given only scans at/after it count.
    """
    import asyncpg

    where, args = _since_filter(since)
    conn = await asyncpg.connect(db_url)
    try:
        rows = await conn.fetch(
            f"""
            SELECT DATE(scanned_at) as date,
                   COUNT(*) as scans,
                   COUNT(CASE WHEN is_flagged THEN 1 END) as flagged
            FROM copyright_scans
            {where}
            GROUP BY DATE(scanned_at)
            ORDER BY date
            """,
            *args,
        )
        return [dict(r) for r in rows]
    finally:
        await conn.close()


async def get_totals(db_url: str, since: datetime | None = None) -> dict[str, int]:
    """get total and flagged scan counts, optionally only at/after `since`"""
    import asyncpg

    where, args = _since_filter(since)
    conn = await asyncpg.connect(db_url)
    try:
        row = await conn.fetchrow(
            f"""
            SELECT COUNT(*) as total,
                   COUNT(CASE WHEN is_flagged THEN 1 END) as flagged
            FROM copyright_scans
            {where}
            """,
            *args,
        )
        return {"total": row["total"], "flagged": row["flagged"]}
    finally:
        await conn.close()


def calculate_cost(total_requests: int) -> tuple[int, float]:
    """calculate billable requests and cost for a single billing period's usage"""
    billable = max(0, total_requests - INCLUDED_REQUESTS)
    cost = billable * COST_PER_REQUEST
    return billable, cost


def _scans_by_period(daily_data: list[dict[str, Any]]) -> dict[date, int]:
    """sum daily scan counts per billing period, keyed by the period's start date"""
    per_period: dict[date, int] = {}
    for day in daily_data:
        start = _billing_period_start_for(day["date"])
        per_period[start] = per_period.get(start, 0) + day["scans"]
    return per_period


def _cumulative_costs(daily_data: list[dict[str, Any]], per_period: bool) -> list[float]:
    """running estimated cost after each day of `daily_data` (expected sorted by date)

    with per_period, the free allowance resets at every billing period boundary;
    otherwise every day counts against one shared allowance.
    """
    costs: list[float] = []
    settled = 0.0  # cost of billing periods that have already closed
    period: date | None = None
    period_scans = 0
    for day in daily_data:
        if per_period:
            start = _billing_period_start_for(day["date"])
            if start != period:
                settled += calculate_cost(period_scans)[1]
                period, period_scans = start, 0
        period_scans += day["scans"]
        costs.append(settled + calculate_cost(period_scans)[1])
    return costs


def _thinned_ticks(x: list[int], labels: list[str]) -> tuple[list[int], list[str]]:
    """keep at most MAX_X_LABELS evenly spaced ticks so axis labels don't collide"""
    step = max(1, -(-len(x) // MAX_X_LABELS))  # ceiling division
    return x[::step], labels[::step]


def _render_chart(
    title: str,
    height: int,
    ticks: tuple[list[int], list[str]],
    draw: Callable[[], None],
) -> None:
    """draw one dark-themed 80-column plotext chart; `draw` adds the series"""
    plt.clear_figure()
    plt.theme("dark")
    plt.title(title)
    draw()
    plt.xticks(*ticks)
    plt.plotsize(80, height)
    plt.show()
    print()


def display_dashboard(
    daily_data: list[dict[str, Any]],
    totals: dict[str, int],
    period_label: str,
    env: str,
    per_period: bool = False,
) -> None:
    """render the cost dashboard: a usage/cost panel followed by terminal charts

    args:
        daily_data: per-day rows as returned by query_scans (sorted by date).
        totals: {"total": ..., "flagged": ...} as returned by get_totals.
        period_label: human-readable range shown in the header.
        env: environment name shown in the header.
        per_period: set when the data spans several billing periods (e.g. --all) so
            the free allowance is applied once per period rather than once overall.
    """
    # "\[" keeps rich from parsing "[prd]" as a style tag and swallowing it
    console.print(f"\n[bold cyan]audd api costs[/] - {period_label} \\[{env}]\n")

    total = totals["total"]
    flagged = totals["flagged"]
    if per_period:
        by_period = _scans_by_period(daily_data)
        billable = sum(calculate_cost(n)[0] for n in by_period.values())
        cost = billable * COST_PER_REQUEST
        current_usage = by_period.get(get_billing_period_start().date(), 0)
        remaining_free = max(0, INCLUDED_REQUESTS - current_usage)
        remaining_label = "remaining free (this period)"
    else:
        billable, cost = calculate_cost(total)
        remaining_free = max(0, INCLUDED_REQUESTS - total)
        remaining_label = "remaining free"

    # stats panel
    stats_table = Table(show_header=False, box=None, padding=(0, 2))
    stats_table.add_column(style="dim")
    stats_table.add_column(style="bold green", justify="right")

    stats_table.add_row("total scans", f"{total:,}")
    stats_table.add_row("flagged (matches)", f"{flagged:,}")
    stats_table.add_row("flag rate", f"{flagged / total * 100:.1f}%" if total else "0%")
    stats_table.add_row("", "")
    stats_table.add_row("included requests", f"{INCLUDED_REQUESTS:,}")
    stats_table.add_row(remaining_label, f"{remaining_free:,}")
    stats_table.add_row("billable requests", f"{billable:,}")
    stats_table.add_row(
        "estimated cost",
        f"[{'red' if cost > 0 else 'green'}]${cost:.2f}[/]",
    )

    console.print(
        Panel(stats_table, title="[bold]usage & costs[/]", border_style="blue")
    )

    if not daily_data:
        console.print("[dim]no scan data available[/]")
        return

    # extract data - use indices for x-axis to avoid plotext date parsing
    dates = [d["date"].strftime("%m/%d") for d in daily_data]
    scans = [d["scans"] for d in daily_data]
    flagged_counts = [d["flagged"] for d in daily_data]
    x = list(range(len(dates)))
    ticks = _thinned_ticks(x, dates)

    # daily scans chart
    _render_chart(
        "daily scans", 12, ticks, lambda: plt.bar(x, scans, color="cyan", label="scans")
    )

    # cumulative cost projection (only worth drawing once the allowance is exceeded)
    cumulative = _cumulative_costs(daily_data, per_period)
    if any(c > 0 for c in cumulative):
        _render_chart(
            "cumulative cost ($)",
            10,
            ticks,
            lambda: plt.plot(x, cumulative, color="red", marker="braille"),
        )

    # flag rate over time
    rates = [f / s * 100 if s > 0 else 0 for f, s in zip(flagged_counts, scans)]
    _render_chart(
        "flag rate (%)",
        10,
        ticks,
        lambda: plt.plot(x, rates, color="yellow", marker="braille"),
    )


@app.command()
def main(
    all_time: bool = typer.Option(False, "--all", "-a", help="show all time stats"),
    env: str = typer.Option("prd", "--env", "-e", help="environment: prd, stg, dev"),
) -> None:
    """audd api cost tracker for plyr.fm"""
    try:
        db_url = settings.get_url(env)
    except ValueError as e:
        console.print(f"[red]error:[/] {e}")
        raise typer.Exit(1) from None

    if all_time:
        since = None
        label = "all time"
    else:
        since = get_billing_period_start()
        label = f"billing period (since {since.strftime('%b %d')})"

    async def fetch():
        # the two queries are independent, so run them concurrently
        return await asyncio.gather(query_scans(db_url, since), get_totals(db_url, since))

    daily_data, totals = asyncio.run(fetch())
    display_dashboard(daily_data, totals, label, env, per_period=all_time)


if __name__ == "__main__":
    app()