From bc472274f6f10c180daa9ee267247cb51ea4ea15 Mon Sep 17 00:00:00 2001 From: Zev Averbach Date: Fri, 31 Mar 2023 23:59:21 +0200 Subject: [PATCH] first --- .gitignore | 2 ++ README.md | 10 ++++++ main.py | 91 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ save.py | 58 ++++++++++++++++++++++++++++++++++ 4 files changed, 161 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 main.py create mode 100644 save.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d50a09f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.env +__pycache__/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..556bf06 --- /dev/null +++ b/README.md @@ -0,0 +1,10 @@ +# Environment Variables + +- `OPENAI_ORG` +- `OPENAI_TOKEN` +- `AI_GETTER_SAVE_PATH` +- `AI_GETTER_S3_BUCKET` + +# Credentials + +Make sure you have credentials in an `~/.aws` directory if you want to upload any outputs to S3. diff --git a/main.py b/main.py new file mode 100644 index 0000000..35d0543 --- /dev/null +++ b/main.py @@ -0,0 +1,91 @@ +import datetime as dt +import os +import pathlib as pl +import sys + +import openai + +from save import save_images_from_openai, upload_to_s3, save_output + +openai.organization = os.getenv("OPENAI_ORG") +openai.api_key = os.getenv("OPENAI_TOKEN") + +S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET") +SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~")) + +class NoBucket(Exception): + pass + + +def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> str: + response = openai.ChatCompletion.create( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": prompt}], + ) + content = response["choices"][0]["message"]["content"] # type: ignore + fp = save_output(prompt, content, save_path) + if save_to_s3: + if S3_BUCKET is None: + raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env") + upload_to_s3(S3_BUCKET, str(fp), str(fp)) + return content # type: ignore + + +def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> dict: + if num_images > 10: + raise ValueError("num_images must be <= 10") + res = openai.Image.create(prompt=prompt, n=num_images) # type: ignore + file_paths = save_images_from_openai(prompt, res, save_path) # type: ignore + if save_to_s3: + if S3_BUCKET is None: + raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env") + for idx, fp in enumerate(file_paths): + upload_to_s3(S3_BUCKET, fp, f"{prompt}-{dt.date.today()}-{idx}") + return res # type: ignore + + + +def main(): + args = sys.argv + if len(args) < 3: + print("Please provide and a prompt") + sys.exit() + + typ, prompt = args[1:3] + if typ not in "image text".split(): + print("Please provide and a prompt") + sys.exit() + + if "--save-path" in args: + save_path = pl.Path(args[args.index("--save-path") + 1]) + else: + save_path = SAVE_PATH + + save_to_s3 = False + if "--save-to-s3" in args: + save_to_s3 = True + + if save_to_s3 and S3_BUCKET is None: + print("Please provide AI_GETTER_S3_BUCKET in .env") + sys.exit() + + match typ: + case "image": + if "--num-images" not in args: + print("Please provide --num-images") + sys.exit() + num_images = int(args[args.index("--num-images") + 1]) + try: + generate_images(prompt, num_images, save_path, save_to_s3) + except ValueError as e: + print(str(e)) + sys.exit() + case "text": + chat(prompt, save_path, save_to_s3) + case _: + print("Please provide and a prompt") + sys.exit() + print("Okay, we're done!") + +if __name__ == "__main__": + main() diff --git a/save.py b/save.py new file mode 100644 index 0000000..7c8e200 --- /dev/null +++ b/save.py @@ -0,0 +1,58 @@ +import datetime as dt +import pathlib as pl + +import boto3 +from requests import get + + +def save_output(prompt: str, content: str, save_path: pl.Path) -> pl.Path: + fp = make_fp_from_prompt(prompt, save_path, ext="txt") + fp.write_text(content) + return fp + + +def make_fp_from_prompt( + prompt: str, + save_path: pl.Path, + ext: str, + index: int | None = None, +) -> pl.Path: + prompt_fn = ( + prompt.replace(" ", "_") + .replace("'", "") + .replace(".", "") + .replace(",", "") + .lower() + )[: 225 - len(str(save_path))] + if index is not None: + prompt_fn = f"{prompt_fn}{index}" + prompt_fn = f"{prompt_fn}-{dt.datetime.now()}" + prompt_fn = f"{prompt_fn}.{ext}" + return save_path / prompt_fn + + +def save_images_from_openai( + description: str, + res: dict, + save_path: pl.Path, +): + file_paths = download_images(description, res, save_path) # type: ignore + return file_paths + + +def upload_to_s3(bucket_name: str, file_path: str, key: str): + s3c = boto3.client("s3") + s3c.upload_file(file_path, bucket_name, key) + + +def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]: + fns = [] + for idx, image_dict in enumerate(res["data"]): + fn = make_fp_from_prompt(prompt, save_path, index=idx, ext="jpg") + download(image_dict["url"], fn) + fns.append(fn) + return fns + + +def download(url: str, fp: pl.Path): + fp.write_bytes(get(url).content)