From fa6bf6c6aba7d6ef9589d1d365e8b02ec9352b94 Mon Sep 17 00:00:00 2001 From: Zev Averbach Date: Sun, 2 Apr 2023 23:09:00 +0200 Subject: [PATCH] always save option --- ai_getter/main.py | 10 ++++++---- ai_getter/save.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/ai_getter/main.py b/ai_getter/main.py index 22dd746..7d60962 100644 --- a/ai_getter/main.py +++ b/ai_getter/main.py @@ -14,6 +14,7 @@ openai.api_key = os.getenv("OPENAI_TOKEN") S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET") SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~")) +ALWAYS_SAVE_TO_S3 = os.getenv("AI_GETTER_ALWAYS_SAVE_TO_S3") == "1" OPENAI_GPT_3_5_TURBO_COST_PER_1K_TOKENS_IN_TENTHS_OF_A_CENT = 2 OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT = 20 @@ -55,7 +56,7 @@ def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH print("Uploading images to S3...") for idx, fp in enumerate(file_paths): print(idx) - upload_to_s3(bucket_name=S3_BUCKET, file_path=fp, key=fp, prompt=prompt, typ="image", vendor="openai") + upload_to_s3(bucket_name=S3_BUCKET, file_path=str(fp), key=fp.name, prompt=prompt, typ="image", vendor="openai") print(f"Time taken: {time.time() - start} seconds") print(f"Cost: {OPENAI_DALE_COST_PER_IMAGE_IN_TENTHS_OF_A_CENT * num_images / 10} cents") return res # type: ignore @@ -68,8 +69,9 @@ Usage: aig image [--clip] --num-images [--save-path ] [--s3] aig text [--clip] [--save-path ] [--s3] - (--clip will get the prompt from your clipboard's contents, in addition to if you supply one) - (--s3 will upload the result to your AI_GETTER_S3_BUCKET) + --clip will get the prompt from your clipboard's contents, in addition to if you supply one + --s3 will upload the result to your AI_GETTER_S3_BUCKET + set env var AI_GETTER_ALWAYS_SAVE_TO_S3=1 to save to s3 by default """) @@ -98,7 +100,7 @@ def main(): return save_to_s3 = False - if "--s3" in the_rest: + if "--s3" in the_rest or ALWAYS_SAVE_TO_S3: if S3_BUCKET is None: print("Please provide AI_GETTER_S3_BUCKET in .env") return diff --git a/ai_getter/save.py b/ai_getter/save.py index 8a69285..0ada496 100644 --- a/ai_getter/save.py +++ b/ai_getter/save.py @@ -86,12 +86,12 @@ def upload_to_s3( ) -def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]: +def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[pl.Path]: fns = [] for idx, image_dict in enumerate(res["data"]): fn = make_fp_from_prompt(prompt, save_path, index=idx, ext="jpg") url = image_dict["url"] - print(f"Downloading image from {url} to {fn}") + print(f"Downloading image {idx + 1}") download(url, fn) fns.append(fn) return fns