show duration and cost, make s3 upload work

This commit is contained in:
2023-04-02 20:14:33 +02:00
parent 9d2a68d698
commit 43be074db9
3 changed files with 51 additions and 9 deletions

View File

@@ -2,6 +2,7 @@ import datetime as dt
import os import os
import pathlib as pl import pathlib as pl
import sys import sys
import time
import openai import openai
import pyperclip import pyperclip
@@ -15,11 +16,16 @@ openai.api_key = os.getenv("OPENAI_TOKEN")
S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET") S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET")
SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~")) SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~"))
OPENAI_GPT_3_5_TURBO_COST_PER_1K_TOKENS_IN_TENTHS_OF_A_CENT = 2
OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT = 20
class NoBucket(Exception): class NoBucket(Exception):
pass pass
def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> str: def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> str:
start = time.time()
response = openai.ChatCompletion.create( response = openai.ChatCompletion.create(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": prompt}], messages=[{"role": "user", "content": prompt}],
@@ -29,22 +35,26 @@ def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False)
if save_to_s3: if save_to_s3:
if S3_BUCKET is None: if S3_BUCKET is None:
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env") raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
upload_to_s3(S3_BUCKET, str(fp), str(fp)) upload_to_s3(bucket_name=S3_BUCKET, file_path=str(fp), key=str(fp), prompt=prompt, typ="text", vendor="openai")
pprint(response) total_tokens = response["usage"]["total_tokens"] # type: ignore
print(f"cost: {OPENAI_GPT_3_5_TURBO_COST_PER_1K_TOKENS_IN_TENTHS_OF_A_CENT * total_tokens * 10 / 1000} cents")
print(f"Time taken: {time.time() - start} seconds")
return content # type: ignore return content # type: ignore
def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> dict: def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> dict:
if num_images > 10: if num_images > 10:
raise ValueError("num_images must be <= 10") raise ValueError("num_images must be <= 10")
start = time.time()
res = openai.Image.create(prompt=prompt, n=num_images) # type: ignore res = openai.Image.create(prompt=prompt, n=num_images) # type: ignore
file_paths = save_images_from_openai(prompt, res, save_path) # type: ignore file_paths = save_images_from_openai(prompt, res, save_path) # type: ignore
if save_to_s3: if save_to_s3:
if S3_BUCKET is None: if S3_BUCKET is None:
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env") raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
for idx, fp in enumerate(file_paths): for fp in file_paths:
upload_to_s3(S3_BUCKET, fp, f"{prompt}-{dt.date.today()}-{idx}") upload_to_s3(bucket_name=S3_BUCKET, file_path=fp, key=fp, prompt=prompt, typ="image", vendor="openai")
pprint(res) print(f"Time taken: {time.time() - start} seconds")
print(f"Cost: {OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT * num_images * 10} cents")
return res # type: ignore return res # type: ignore
@@ -55,7 +65,8 @@ Usage:
aig image <prompt> [--clip] --num-images <num_images> [--save-path <path>] [--s3] aig image <prompt> [--clip] --num-images <num_images> [--save-path <path>] [--s3]
aig text <prompt> [--clip] [--save-path <path>] [--s3] aig text <prompt> [--clip] [--save-path <path>] [--s3]
(--clip will get the prompt from your clipboard's contents) (--clip will get the prompt from your clipboard's contents, in addition to <prompt> if you supply one)
(--s3 will upload the result to your AI_GETTER_S3_BUCKET)
""") """)

View File

@@ -1,8 +1,10 @@
import datetime as dt import datetime as dt
import pathlib as pl import pathlib as pl
import typing as t
import boto3 import boto3
from requests import get from requests import get
import trans
def save_output(prompt: str, content: str, save_path: pl.Path) -> pl.Path: def save_output(prompt: str, content: str, save_path: pl.Path) -> pl.Path:
@@ -49,9 +51,39 @@ def save_images_from_openai(
return file_paths return file_paths
def upload_to_s3(bucket_name: str, file_path: str, key: str): def transform_prompt_for_aws_metadata(prompt: str) -> str:
MAX_NUM_PYTHON_CHARS_IN_S3_METADATA = (
1849 # not sure why, but this is close to 1.9kb
# AWS's limit is supposedly 2kb
)
truncated = prompt[:MAX_NUM_PYTHON_CHARS_IN_S3_METADATA]
backslashes_removed = truncated.replace("\n", "").replace("\t", "")
return trans.trans(backslashes_removed)
def upload_to_s3(
bucket_name: str,
file_path: str,
key: str,
prompt: str,
typ: t.Literal["image", "text"],
vendor: str,
):
s3c = boto3.client("s3") s3c = boto3.client("s3")
s3c.upload_file(file_path, bucket_name, key) transformed_prompt = transform_prompt_for_aws_metadata(prompt)
s3c.upload_file(
file_path,
bucket_name,
key,
ExtraArgs={
"Metadata": {
"prompt": transformed_prompt,
"date": str(dt.datetime.now()),
"type": typ,
"vendor": vendor,
}
},
)
def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]: def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]:

View File

@@ -1,4 +1,3 @@
-e git+https://github.com/zevaverbach/ai_getter@bc472274f6f10c180daa9ee267247cb51ea4ea15#egg=ai_getter
aiohttp==3.8.4 aiohttp==3.8.4
aiosignal==1.3.1 aiosignal==1.3.1
async-timeout==4.0.2 async-timeout==4.0.2