first
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
.env
|
||||
__pycache__/
|
||||
10
README.md
Normal file
10
README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# Environment Variables
|
||||
|
||||
- `OPENAI_ORG`
|
||||
- `OPENAI_TOKEN`
|
||||
- `AI_GETTER_SAVE_PATH`
|
||||
- `AI_GETTER_S3_BUCKET`
|
||||
|
||||
# Credentials
|
||||
|
||||
Make sure you have credentials in an `~/.aws` directory if you want to upload any outputs to S3.
|
||||
91
main.py
Normal file
91
main.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import datetime as dt
|
||||
import os
|
||||
import pathlib as pl
|
||||
import sys
|
||||
|
||||
import openai
|
||||
|
||||
from save import save_images_from_openai, upload_to_s3, save_output
|
||||
|
||||
openai.organization = os.getenv("OPENAI_ORG")
|
||||
openai.api_key = os.getenv("OPENAI_TOKEN")
|
||||
|
||||
S3_BUCKET = os.getenv("AI_GETTER_S3_BUCKET")
|
||||
SAVE_PATH = pl.Path(os.getenv("AI_GETTER_SAVE_PATH") or os.path.expanduser("~"))
|
||||
|
||||
class NoBucket(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def chat(prompt: str, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> str:
|
||||
response = openai.ChatCompletion.create(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
)
|
||||
content = response["choices"][0]["message"]["content"] # type: ignore
|
||||
fp = save_output(prompt, content, save_path)
|
||||
if save_to_s3:
|
||||
if S3_BUCKET is None:
|
||||
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
|
||||
upload_to_s3(S3_BUCKET, str(fp), str(fp))
|
||||
return content # type: ignore
|
||||
|
||||
|
||||
def generate_images(prompt: str, num_images: int, save_path: pl.Path = SAVE_PATH, save_to_s3: bool = False) -> dict:
|
||||
if num_images > 10:
|
||||
raise ValueError("num_images must be <= 10")
|
||||
res = openai.Image.create(prompt=prompt, n=num_images) # type: ignore
|
||||
file_paths = save_images_from_openai(prompt, res, save_path) # type: ignore
|
||||
if save_to_s3:
|
||||
if S3_BUCKET is None:
|
||||
raise NoBucket("Please provide AI_GETTER_S3_BUCKET in .env")
|
||||
for idx, fp in enumerate(file_paths):
|
||||
upload_to_s3(S3_BUCKET, fp, f"{prompt}-{dt.date.today()}-{idx}")
|
||||
return res # type: ignore
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
args = sys.argv
|
||||
if len(args) < 3:
|
||||
print("Please provide <image/text> and a prompt")
|
||||
sys.exit()
|
||||
|
||||
typ, prompt = args[1:3]
|
||||
if typ not in "image text".split():
|
||||
print("Please provide <image/text> and a prompt")
|
||||
sys.exit()
|
||||
|
||||
if "--save-path" in args:
|
||||
save_path = pl.Path(args[args.index("--save-path") + 1])
|
||||
else:
|
||||
save_path = SAVE_PATH
|
||||
|
||||
save_to_s3 = False
|
||||
if "--save-to-s3" in args:
|
||||
save_to_s3 = True
|
||||
|
||||
if save_to_s3 and S3_BUCKET is None:
|
||||
print("Please provide AI_GETTER_S3_BUCKET in .env")
|
||||
sys.exit()
|
||||
|
||||
match typ:
|
||||
case "image":
|
||||
if "--num-images" not in args:
|
||||
print("Please provide --num-images")
|
||||
sys.exit()
|
||||
num_images = int(args[args.index("--num-images") + 1])
|
||||
try:
|
||||
generate_images(prompt, num_images, save_path, save_to_s3)
|
||||
except ValueError as e:
|
||||
print(str(e))
|
||||
sys.exit()
|
||||
case "text":
|
||||
chat(prompt, save_path, save_to_s3)
|
||||
case _:
|
||||
print("Please provide <image/text> and a prompt")
|
||||
sys.exit()
|
||||
print("Okay, we're done!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
58
save.py
Normal file
58
save.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import datetime as dt
|
||||
import pathlib as pl
|
||||
|
||||
import boto3
|
||||
from requests import get
|
||||
|
||||
|
||||
def save_output(prompt: str, content: str, save_path: pl.Path) -> pl.Path:
|
||||
fp = make_fp_from_prompt(prompt, save_path, ext="txt")
|
||||
fp.write_text(content)
|
||||
return fp
|
||||
|
||||
|
||||
def make_fp_from_prompt(
|
||||
prompt: str,
|
||||
save_path: pl.Path,
|
||||
ext: str,
|
||||
index: int | None = None,
|
||||
) -> pl.Path:
|
||||
prompt_fn = (
|
||||
prompt.replace(" ", "_")
|
||||
.replace("'", "")
|
||||
.replace(".", "")
|
||||
.replace(",", "")
|
||||
.lower()
|
||||
)[: 225 - len(str(save_path))]
|
||||
if index is not None:
|
||||
prompt_fn = f"{prompt_fn}{index}"
|
||||
prompt_fn = f"{prompt_fn}-{dt.datetime.now()}"
|
||||
prompt_fn = f"{prompt_fn}.{ext}"
|
||||
return save_path / prompt_fn
|
||||
|
||||
|
||||
def save_images_from_openai(
|
||||
description: str,
|
||||
res: dict,
|
||||
save_path: pl.Path,
|
||||
):
|
||||
file_paths = download_images(description, res, save_path) # type: ignore
|
||||
return file_paths
|
||||
|
||||
|
||||
def upload_to_s3(bucket_name: str, file_path: str, key: str):
|
||||
s3c = boto3.client("s3")
|
||||
s3c.upload_file(file_path, bucket_name, key)
|
||||
|
||||
|
||||
def download_images(prompt: str, res: dict, save_path: pl.Path) -> list[str]:
|
||||
fns = []
|
||||
for idx, image_dict in enumerate(res["data"]):
|
||||
fn = make_fp_from_prompt(prompt, save_path, index=idx, ext="jpg")
|
||||
download(image_dict["url"], fn)
|
||||
fns.append(fn)
|
||||
return fns
|
||||
|
||||
|
||||
def download(url: str, fp: pl.Path):
|
||||
fp.write_bytes(get(url).content)
|
||||
Reference in New Issue
Block a user