always save option
This commit is contained in:
@@ -14,6 +14,7 @@ openai.api_key = os.getenv("OPENAI_TOKEN")
|
||||
|
||||
S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET")
|
||||
SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~"))
|
||||
ALWAYS_SAVE_TO_S3 = os.getenv("AI_GETTER_ALWAYS_SAVE_TO_S3") == "1"
|
||||
|
||||
OPENAI_GPT_3_5_TURBO_COST_PER_1K_TOKENS_IN_TENTHS_OF_A_CENT = 2
|
||||
OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT = 20
|
||||
@@ -55,7 +56,7 @@ def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH
|
||||
print("Uploading images to S3...")
|
||||
for idx, fp in enumerate(file_paths):
|
||||
print(idx)
|
||||
upload_to_s3(bucket_name=S3_BUCKET, file_path=fp, key=fp, prompt=prompt, typ="image", vendor="openai")
|
||||
upload_to_s3(bucket_name=S3_BUCKET, file_path=str(fp), key=fp.name, prompt=prompt, typ="image", vendor="openai")
|
||||
print(f"Time taken: {time.time() - start} seconds")
|
||||
print(f"Cost: {OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT * num_images / 10} cents")
|
||||
return res # type: ignore
|
||||
@@ -68,8 +69,9 @@ Usage:
|
||||
aig image <prompt> [--clip] --num-images <num_images> [--save-path <path>] [--s3]
|
||||
aig text <prompt> [--clip] [--save-path <path>] [--s3]
|
||||
|
||||
(--clip will get the prompt from your clipboard's contents, in addition to <prompt> if you supply one)
|
||||
(--s3 will upload the result to your AI_GETTER_S3_BUCKET)
|
||||
--clip will get the prompt from your clipboard's contents, in addition to <prompt> if you supply one
|
||||
--s3 will upload the result to your AI_GETTER_S3_BUCKET
|
||||
set env var AI_GETTER_ALWAYS_SAVE_TO_S3=1 to save to s3 by default
|
||||
""")
|
||||
|
||||
|
||||
@@ -98,7 +100,7 @@ def main():
|
||||
return
|
||||
|
||||
save_to_s3 = False
|
||||
if "--s3" in the_rest:
|
||||
if "--s3" in the_rest or ALWAYS_SAVE_TO_S3:
|
||||
if S3_BUCKET is None:
|
||||
print("Please provide AI_GETTER_S3_BUCKET in .env")
|
||||
return
|
||||
|
||||
@@ -86,12 +86,12 @@ def upload_to_s3(
|
||||
)
|
||||
|
||||
|
||||
def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]:
|
||||
def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[pl.Path]:
|
||||
fns = []
|
||||
for idx, image_dict in enumerate(res["data"]):
|
||||
fn = make_fp_from_prompt(prompt, save_path, index=idx, ext="jpg")
|
||||
url = image_dict["url"]
|
||||
print(f"Downloading image from {url} to {fn}")
|
||||
print(f"Downloading image {idx + 1}")
|
||||
download(url, fn)
|
||||
fns.append(fn)
|
||||
return fns
|
||||
|
||||
Reference in New Issue
Block a user