This commit is contained in:
2023-03-31 23:59:21 +02:00
commit bc472274f6
4 changed files with 161 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
.env
__pycache__/

10
README.md Normal file
View File

@@ -0,0 +1,10 @@
# Environment Variables
- `OPENAI_ORG`
- `OPENAI_TOKEN`
- `AI_GETTER_SAVE_PATH`
- `AI_GETTER_S3_BUCKET`
# Credentials
Make sure you have credentials in an `~/.aws` directory if you want to upload any outputs to S3.

91
main.py Normal file
View File

@@ -0,0 +1,91 @@
import datetime as dt
import os
import pathlib as pl
import sys
import openai
from save import save_images_from_openai, upload_to_s3, save_output
openai.organization = os.getenv("OPENAI_ORG")
openai.api_key = os.getenv("OPENAI_TOKEN")
S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET")
SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~"))
class NoBucket(Exception):
pass
def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> str:
response = openai.ChatCompletion.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": prompt}],
)
content = response["choices"][0]["message"]["content"] # type: ignore
fp = save_output(prompt, content, save_path)
if save_to_s3:
if S3_BUCKET is None:
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
upload_to_s3(S3_BUCKET, str(fp), str(fp))
return content # type: ignore
def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> dict:
if num_images > 10:
raise ValueError("num_images must be <= 10")
res = openai.Image.create(prompt=prompt, n=num_images) # type: ignore
file_paths = save_images_from_openai(prompt, res, save_path) # type: ignore
if save_to_s3:
if S3_BUCKET is None:
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
for idx, fp in enumerate(file_paths):
upload_to_s3(S3_BUCKET, fp, f"{prompt}-{dt.date.today()}-{idx}")
return res # type: ignore
def main():
args = sys.argv
if len(args) < 3:
print("Please provide <image/text> and a prompt")
sys.exit()
typ, prompt = args[1:3]
if typ not in "image text".split():
print("Please provide <image/text> and a prompt")
sys.exit()
if "--save-path" in args:
save_path = pl.Path(args[args.index("--save-path") + 1])
else:
save_path = SAVE_PATH
save_to_s3 = False
if "--save-to-s3" in args:
save_to_s3 = True
if save_to_s3 and S3_BUCKET is None:
print("Please provide AI_GETTER_S3_BUCKET in .env")
sys.exit()
match typ:
case "image":
if "--num-images" not in args:
print("Please provide --num-images")
sys.exit()
num_images = int(args[args.index("--num-images") + 1])
try:
generate_images(prompt, num_images, save_path, save_to_s3)
except ValueError as e:
print(str(e))
sys.exit()
case "text":
chat(prompt, save_path, save_to_s3)
case _:
print("Please provide <image/text> and a prompt")
sys.exit()
print("Okay, we're done!")
if __name__ == "__main__":
main()

58
save.py Normal file
View File

@@ -0,0 +1,58 @@
import datetime as dt
import pathlib as pl
import boto3
from requests import get
def save_output(prompt: str, content: str, save_path: pl.Path) -> pl.Path:
fp = make_fp_from_prompt(prompt, save_path, ext="txt")
fp.write_text(content)
return fp
def make_fp_from_prompt(
prompt: str,
save_path: pl.Path,
ext: str,
index: int | None = None,
) -> pl.Path:
prompt_fn = (
prompt.replace(" ", "_")
.replace("'", "")
.replace(".", "")
.replace(",", "")
.lower()
)[: 225 - len(str(save_path))]
if index is not None:
prompt_fn = f"{prompt_fn}{index}"
prompt_fn = f"{prompt_fn}-{dt.datetime.now()}"
prompt_fn = f"{prompt_fn}.{ext}"
return save_path / prompt_fn
def save_images_from_openai(
description: str,
res: dict,
save_path: pl.Path,
):
file_paths = download_images(description, res, save_path) # type: ignore
return file_paths
def upload_to_s3(bucket_name: str, file_path: str, key: str):
s3c = boto3.client("s3")
s3c.upload_file(file_path, bucket_name, key)
def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]:
fns = []
for idx, image_dict in enumerate(res["data"]):
fn = make_fp_from_prompt(prompt, save_path, index=idx, ext="jpg")
download(image_dict["url"], fn)
fns.append(fn)
return fns
def download(url: str, fp: pl.Path):
fp.write_bytes(get(url).content)