Files
ai_getter/main.py
2023-03-31 23:59:21 +02:00

92 lines
2.8 KiB
Python

import datetime as dt
import os
import pathlib as pl
import sys
import openai
from save import save_images_from_openai, upload_to_s3, save_output
# Runtime configuration, read from the environment once at import time.
S3_BUCKET = os.environ.get("AI_GETTER_S3_BUCKET")
openai.organization = os.environ.get("OPENAI_ORG")
openai.api_key = os.environ.get("OPENAI_TOKEN")
# An unset *or empty* AI_GETTER_SAVE_PATH falls back to the user's home directory.
SAVE_PATH = pl.Path(os.environ.get("AI_GETTER_SAVE_PATH") or os.path.expanduser("~"))
class NoBucket(Exception):
    """Raised when an S3 upload is requested but AI_GETTER_S3_BUCKET is not set."""
def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> str:
    """Send *prompt* to gpt-3.5-turbo as a single user message and return the reply.

    The prompt and reply are persisted with ``save_output`` and, optionally,
    the resulting file is uploaded to S3.

    Args:
        prompt: Text sent as the only user message of the chat.
        save_path: Location handed to ``save_output`` (how it is used is
            decided by ``save_output``).
        save_to_s3: Also upload the saved file to ``S3_BUCKET``.

    Returns:
        The ``content`` of the first choice in the API response.

    Raises:
        NoBucket: If *save_to_s3* is true but AI_GETTER_S3_BUCKET is unset.
    """
    # Validate before calling the API so a misconfiguration neither costs a
    # request nor leaves a local file behind.
    if save_to_s3 and S3_BUCKET is None:
        raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")

    response = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": prompt}],
    )
    content = response["choices"][0]["message"]["content"]  # type: ignore
    fp = save_output(prompt, content, save_path)
    if save_to_s3:
        # NOTE(review): the local file path is reused verbatim as the S3 key;
        # confirm that is intended (generate_images builds a prompt-based key).
        upload_to_s3(S3_BUCKET, str(fp), str(fp))
    return content  # type: ignore
def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> dict:
    """Generate *num_images* images for *prompt*, save them, and optionally upload to S3.

    Args:
        prompt: Text description passed to ``openai.Image.create``.
        num_images: Number of images to request; must be between 1 and 10.
        save_path: Location handed to ``save_images_from_openai``.
        save_to_s3: Also upload every saved image to ``S3_BUCKET``.

    Returns:
        The raw response object returned by ``openai.Image.create``.

    Raises:
        ValueError: If *num_images* is outside 1..10.
        NoBucket: If *save_to_s3* is true but AI_GETTER_S3_BUCKET is unset.
    """
    # Reject invalid input and missing configuration before the paid API call.
    if num_images < 1:
        raise ValueError("num_images must be >= 1")
    if num_images > 10:
        raise ValueError("num_images must be <= 10")
    if save_to_s3 and S3_BUCKET is None:
        raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")

    res = openai.Image.create(prompt=prompt, n=num_images)  # type: ignore
    # presumably the local paths of the saved images, one per generated image
    file_paths = save_images_from_openai(prompt, res, save_path)  # type: ignore
    if save_to_s3:
        # Compute the date once so every key in this batch shares it.
        today = dt.date.today()
        for idx, fp in enumerate(file_paths):
            # NOTE(review): the raw prompt becomes part of the S3 key; prompts
            # containing "/" will create nested key prefixes.
            upload_to_s3(S3_BUCKET, fp, f"{prompt}-{today}-{idx}")
    return res  # type: ignore
def main():
args = sys.argv
if len(args) < 3:
print("Please provide <image/text> and a prompt")
sys.exit()
typ, prompt = args[1:3]
if typ not in "image text".split():
print("Please provide <image/text> and a prompt")
sys.exit()
if "--save-path" in args:
save_path = pl.Path(args[args.index("--save-path") + 1])
else:
save_path = SAVE_PATH
save_to_s3 = False
if "--save-to-s3" in args:
save_to_s3 = True
if save_to_s3 and S3_BUCKET is None:
print("Please provide AI_GETTER_S3_BUCKET in .env")
sys.exit()
match typ:
case "image":
if "--num-images" not in args:
print("Please provide --num-images")
sys.exit()
num_images = int(args[args.index("--num-images") + 1])
try:
generate_images(prompt, num_images, save_path, save_to_s3)
except ValueError as e:
print(str(e))
sys.exit()
case "text":
chat(prompt, save_path, save_to_s3)
case _:
print("Please provide <image/text> and a prompt")
sys.exit()
print("Okay, we're done!")
if __name__ == "__main__":
main()