show duration and cost, make s3 upload work
This commit is contained in:
@@ -2,6 +2,7 @@ import datetime as dt
|
||||
import os
|
||||
import pathlib as pl
|
||||
import sys
|
||||
import time
|
||||
|
||||
import openai
|
||||
import pyperclip
|
||||
@@ -15,11 +16,16 @@ openai.api_key = os.getenv("OPENAI_TOKEN")
|
||||
S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET")
|
||||
SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~"))
|
||||
|
||||
OPENAI_GPT_3_5_TURBO_COST_PER_1K_TOKENS_IN_TENTHS_OF_A_CENT = 2
|
||||
OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT = 20
|
||||
|
||||
|
||||
class NoBucket(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> str:
|
||||
start = time.time()
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
@@ -29,22 +35,26 @@ def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False)
|
||||
if save_to_s3:
|
||||
if S3_BUCKET is None:
|
||||
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
|
||||
upload_to_s3(S3_BUCKET, str(fp), str(fp))
|
||||
pprint(response)
|
||||
upload_to_s3(bucket_name=S3_BUCKET, file_path=str(fp), key=str(fp), prompt=prompt, typ="text", vendor="openai")
|
||||
total_tokens = response["usage"]["total_tokens"] # type: ignore
|
||||
print(f"cost: {OPENAI_GPT_3_5_TURBO_COST_PER_1K_TOKENS_IN_TENTHS_OF_A_CENT * total_tokens * 10 / 1000} cents")
|
||||
print(f"Time taken: {time.time() - start} seconds")
|
||||
return content # type: ignore
|
||||
|
||||
|
||||
def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> dict:
|
||||
if num_images > 10:
|
||||
raise ValueError("num_images must be <= 10")
|
||||
start = time.time()
|
||||
res = openai.Image.create(prompt=prompt, n=num_images) # type: ignore
|
||||
file_paths = save_images_from_openai(prompt, res, save_path) # type: ignore
|
||||
if save_to_s3:
|
||||
if S3_BUCKET is None:
|
||||
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
|
||||
for idx, fp in enumerate(file_paths):
|
||||
upload_to_s3(S3_BUCKET, fp, f"{prompt}-{dt.date.today()}-{idx}")
|
||||
pprint(res)
|
||||
for fp in file_paths:
|
||||
upload_to_s3(bucket_name=S3_BUCKET, file_path=fp, key=fp, prompt=prompt, typ="image", vendor="openai")
|
||||
print(f"Time taken: {time.time() - start} seconds")
|
||||
print(f"Cost: {OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT * num_images * 10} cents")
|
||||
return res # type: ignore
|
||||
|
||||
|
||||
@@ -55,7 +65,8 @@ Usage:
|
||||
aig image <prompt> [--clip] --num-images <num_images> [--save-path <path>] [--s3]
|
||||
aig text <prompt> [--clip] [--save-path <path>] [--s3]
|
||||
|
||||
(--clip will get the prompt from your clipboard's contents)
|
||||
(--clip will get the prompt from your clipboard's contents, in addition to <prompt> if you supply one)
|
||||
(--s3 will upload the result to your AI_GETTER_S3_BUCKET)
|
||||
""")
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import datetime as dt
|
||||
import pathlib as pl
|
||||
import typing as t
|
||||
|
||||
import boto3
|
||||
from requests import get
|
||||
import trans
|
||||
|
||||
|
||||
def save_output(prompt: str, content: str, save_path: pl.Path) -> pl.Path:
|
||||
@@ -49,9 +51,39 @@ def save_images_from_openai(
|
||||
return file_paths
|
||||
|
||||
|
||||
def upload_to_s3(bucket_name: str, file_path: str, key: str):
|
||||
def transform_prompt_for_aws_metadata(prompt: str) -> str:
|
||||
MAX_NUM_PYTHON_CHARS_IN_S3_METADATA = (
|
||||
1849 # not sure why, but this is close to 1.9kb
|
||||
# AWS's limit is supposedly 2kb
|
||||
)
|
||||
truncated = prompt[:MAX_NUM_PYTHON_CHARS_IN_S3_METADATA]
|
||||
backslashes_removed = truncated.replace("\n", "").replace("\t", "")
|
||||
return trans.trans(backslashes_removed)
|
||||
|
||||
|
||||
def upload_to_s3(
|
||||
bucket_name: str,
|
||||
file_path: str,
|
||||
key: str,
|
||||
prompt: str,
|
||||
typ: t.Literal["image", "text"],
|
||||
vendor: str,
|
||||
):
|
||||
s3c = boto3.client("s3")
|
||||
s3c.upload_file(file_path, bucket_name, key)
|
||||
transformed_prompt = transform_prompt_for_aws_metadata(prompt)
|
||||
s3c.upload_file(
|
||||
file_path,
|
||||
bucket_name,
|
||||
key,
|
||||
ExtraArgs={
|
||||
"Metadata": {
|
||||
"prompt": transformed_prompt,
|
||||
"date": str(dt.datetime.now()),
|
||||
"type": typ,
|
||||
"vendor": vendor,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
-e git+https://github.com/zevaverbach/ai_getter@bc472274f6f10c180daa9ee267247cb51ea4ea15#egg=ai_getter
|
||||
aiohttp==3.8.4
|
||||
aiosignal==1.3.1
|
||||
async-timeout==4.0.2
|
||||
|
||||
Reference in New Issue
Block a user