show duration and cost, make s3 upload work
This commit is contained in:
@@ -2,6 +2,7 @@ import datetime as dt
|
|||||||
import os
|
import os
|
||||||
import pathlib as pl
|
import pathlib as pl
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
import pyperclip
|
import pyperclip
|
||||||
@@ -15,11 +16,16 @@ openai.api_key = os.getenv("OPENAI_TOKEN")
|
|||||||
S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET")
|
S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET")
|
||||||
SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~"))
|
SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~"))
|
||||||
|
|
||||||
|
OPENAI_GPT_3_5_TURBO_COST_PER_1K_TOKENS_IN_TENTHS_OF_A_CENT = 2
|
||||||
|
OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT = 20
|
||||||
|
|
||||||
|
|
||||||
class NoBucket(Exception):
|
class NoBucket(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> str:
|
def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> str:
|
||||||
|
start = time.time()
|
||||||
response = openai.ChatCompletion.create(
|
response = openai.ChatCompletion.create(
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
messages=[{"role": "user", "content": prompt}],
|
messages=[{"role": "user", "content": prompt}],
|
||||||
@@ -29,22 +35,26 @@ def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False)
|
|||||||
if save_to_s3:
|
if save_to_s3:
|
||||||
if S3_BUCKET is None:
|
if S3_BUCKET is None:
|
||||||
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
|
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
|
||||||
upload_to_s3(S3_BUCKET, str(fp), str(fp))
|
upload_to_s3(bucket_name=S3_BUCKET, file_path=str(fp), key=str(fp), prompt=prompt, typ="text", vendor="openai")
|
||||||
pprint(response)
|
total_tokens = response["usage"]["total_tokens"] # type: ignore
|
||||||
|
print(f"cost: {OPENAI_GPT_3_5_TURBO_COST_PER_1K_TOKENS_IN_TENTHS_OF_A_CENT * total_tokens * 10 / 1000} cents")
|
||||||
|
print(f"Time taken: {time.time() - start} seconds")
|
||||||
return content # type: ignore
|
return content # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> dict:
|
def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> dict:
|
||||||
if num_images > 10:
|
if num_images > 10:
|
||||||
raise ValueError("num_images must be <= 10")
|
raise ValueError("num_images must be <= 10")
|
||||||
|
start = time.time()
|
||||||
res = openai.Image.create(prompt=prompt, n=num_images) # type: ignore
|
res = openai.Image.create(prompt=prompt, n=num_images) # type: ignore
|
||||||
file_paths = save_images_from_openai(prompt, res, save_path) # type: ignore
|
file_paths = save_images_from_openai(prompt, res, save_path) # type: ignore
|
||||||
if save_to_s3:
|
if save_to_s3:
|
||||||
if S3_BUCKET is None:
|
if S3_BUCKET is None:
|
||||||
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
|
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
|
||||||
for idx, fp in enumerate(file_paths):
|
for fp in file_paths:
|
||||||
upload_to_s3(S3_BUCKET, fp, f"{prompt}-{dt.date.today()}-{idx}")
|
upload_to_s3(bucket_name=S3_BUCKET, file_path=fp, key=fp, prompt=prompt, typ="image", vendor="openai")
|
||||||
pprint(res)
|
print(f"Time taken: {time.time() - start} seconds")
|
||||||
|
print(f"Cost: {OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT * num_images * 10} cents")
|
||||||
return res # type: ignore
|
return res # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@@ -55,7 +65,8 @@ Usage:
|
|||||||
aig image <prompt> [--clip] --num-images <num_images> [--save-path <path>] [--s3]
|
aig image <prompt> [--clip] --num-images <num_images> [--save-path <path>] [--s3]
|
||||||
aig text <prompt> [--clip] [--save-path <path>] [--s3]
|
aig text <prompt> [--clip] [--save-path <path>] [--s3]
|
||||||
|
|
||||||
(--clip will get the prompt from your clipboard's contents)
|
(--clip will get the prompt from your clipboard's contents, in addition to <prompt> if you supply one)
|
||||||
|
(--s3 will upload the result to your AI_GETTER_S3_BUCKET)
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
import datetime as dt
|
import datetime as dt
|
||||||
import pathlib as pl
|
import pathlib as pl
|
||||||
|
import typing as t
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from requests import get
|
from requests import get
|
||||||
|
import trans
|
||||||
|
|
||||||
|
|
||||||
def save_output(prompt: str, content: str, save_path: pl.Path) -> pl.Path:
|
def save_output(prompt: str, content: str, save_path: pl.Path) -> pl.Path:
|
||||||
@@ -49,9 +51,39 @@ def save_images_from_openai(
|
|||||||
return file_paths
|
return file_paths
|
||||||
|
|
||||||
|
|
||||||
def upload_to_s3(bucket_name: str, file_path: str, key: str):
|
def transform_prompt_for_aws_metadata(prompt: str) -> str:
|
||||||
|
MAX_NUM_PYTHON_CHARS_IN_S3_METADATA = (
|
||||||
|
1849 # not sure why, but this is close to 1.9kb
|
||||||
|
# AWS's limit is supposedly 2kb
|
||||||
|
)
|
||||||
|
truncated = prompt[:MAX_NUM_PYTHON_CHARS_IN_S3_METADATA]
|
||||||
|
backslashes_removed = truncated.replace("\n", "").replace("\t", "")
|
||||||
|
return trans.trans(backslashes_removed)
|
||||||
|
|
||||||
|
|
||||||
|
def upload_to_s3(
|
||||||
|
bucket_name: str,
|
||||||
|
file_path: str,
|
||||||
|
key: str,
|
||||||
|
prompt: str,
|
||||||
|
typ: t.Literal["image", "text"],
|
||||||
|
vendor: str,
|
||||||
|
):
|
||||||
s3c = boto3.client("s3")
|
s3c = boto3.client("s3")
|
||||||
s3c.upload_file(file_path, bucket_name, key)
|
transformed_prompt = transform_prompt_for_aws_metadata(prompt)
|
||||||
|
s3c.upload_file(
|
||||||
|
file_path,
|
||||||
|
bucket_name,
|
||||||
|
key,
|
||||||
|
ExtraArgs={
|
||||||
|
"Metadata": {
|
||||||
|
"prompt": transformed_prompt,
|
||||||
|
"date": str(dt.datetime.now()),
|
||||||
|
"type": typ,
|
||||||
|
"vendor": vendor,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]:
|
def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]:
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
-e git+https://github.com/zevaverbach/ai_getter@bc472274f6f10c180daa9ee267247cb51ea4ea15#egg=ai_getter
|
|
||||||
aiohttp==3.8.4
|
aiohttp==3.8.4
|
||||||
aiosignal==1.3.1
|
aiosignal==1.3.1
|
||||||
async-timeout==4.0.2
|
async-timeout==4.0.2
|
||||||
|
|||||||
Reference in New Issue
Block a user