diff --git a/ai_getter/main.py b/ai_getter/main.py index ed5506b..5b2257f 100644 --- a/ai_getter/main.py +++ b/ai_getter/main.py @@ -2,6 +2,7 @@ import datetime as dt import os import pathlib as pl import sys +import time import openai import pyperclip @@ -15,11 +16,16 @@ openai.api_key = os.getenv("OPENAI_TOKEN") S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET") SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~")) +OPENAI_GPT_3_5_TURBO_COST_PER_1K_TOKENS_IN_TENTHS_OF_A_CENT = 2 +OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT = 20 + + class NoBucket(Exception): pass def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> str: + start = time.time() response = openai.ChatCompletion.create( model="gpt-3.5-turbo", messages=[{"role": "user", "content": prompt}], @@ -29,22 +35,26 @@ def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) if save_to_s3: if S3_BUCKET is None: raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env") - upload_to_s3(S3_BUCKET, str(fp), str(fp)) - pprint(response) + upload_to_s3(bucket_name=S3_BUCKET, file_path=str(fp), key=str(fp), prompt=prompt, typ="text", vendor="openai") + total_tokens = response["usage"]["total_tokens"] # type: ignore + print(f"cost: {OPENAI_GPT_3_5_TURBO_COST_PER_1K_TOKENS_IN_TENTHS_OF_A_CENT * total_tokens * 10 / 1000} cents") + print(f"Time taken: {time.time() - start} seconds") return content # type: ignore def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> dict: if num_images > 10: raise ValueError("num_images must be <= 10") + start = time.time() res = openai.Image.create(prompt=prompt, n=num_images) # type: ignore file_paths = save_images_from_openai(prompt, res, save_path) # type: ignore if save_to_s3: if S3_BUCKET is None: raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env") - for idx, fp in enumerate(file_paths): - upload_to_s3(S3_BUCKET, fp, f"{prompt}-{dt.date.today()}-{idx}") - pprint(res) + for fp in file_paths: + upload_to_s3(bucket_name=S3_BUCKET, file_path=fp, key=fp, prompt=prompt, typ="image", vendor="openai") + print(f"Time taken: {time.time() - start} seconds") + print(f"Cost: {OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT * num_images * 10} cents") return res # type: ignore @@ -55,7 +65,8 @@ Usage: aig image [--clip] --num-images [--save-path ] [--s3] aig text [--clip] [--save-path ] [--s3] - (--clip will get the prompt from your clipboard's contents) + (--clip will get the prompt from your clipboard's contents, in addition to if you supply one) + (--s3 will upload the result to your AI_GETTER_S3_BUCKET) """) diff --git a/ai_getter/save.py b/ai_getter/save.py index ce800ef..4d8549d 100644 --- a/ai_getter/save.py +++ b/ai_getter/save.py @@ -1,8 +1,10 @@ import datetime as dt import pathlib as pl +import typing as t import boto3 from requests import get +import trans def save_output(prompt: str, content: str, save_path: pl.Path) -> pl.Path: @@ -49,9 +51,39 @@ def save_images_from_openai( return file_paths -def upload_to_s3(bucket_name: str, file_path: str, key: str): +def transform_prompt_for_aws_metadata(prompt: str) -> str: + MAX_NUM_PYTHON_CHARS_IN_S3_METADATA = ( + 1849 # not sure why, but this is close to 1.9kb + # AWS's limit is supposedly 2kb + ) + truncated = prompt[:MAX_NUM_PYTHON_CHARS_IN_S3_METADATA] + backslashes_removed = truncated.replace("\n", "").replace("\t", "") + return trans.trans(backslashes_removed) + + +def upload_to_s3( + bucket_name: str, + file_path: str, + key: str, + prompt: str, + typ: t.Literal["image", "text"], + vendor: str, +): s3c = boto3.client("s3") - s3c.upload_file(file_path, bucket_name, key) + transformed_prompt = transform_prompt_for_aws_metadata(prompt) + s3c.upload_file( + file_path, + bucket_name, + key, + ExtraArgs={ + "Metadata": { + "prompt": transformed_prompt, + "date": str(dt.datetime.now()), + "type": typ, + "vendor": vendor, + } + }, + ) def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]: diff --git a/requirements.txt b/requirements.txt index f236601..b6127f9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ --e git+https://github.com/zevaverbach/ai_getter@bc472274f6f10c180daa9ee267247cb51ea4ea15#egg=ai_getter aiohttp==3.8.4 aiosignal==1.3.1 async-timeout==4.0.2