-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathupscale.py
More file actions
103 lines (89 loc) · 4.17 KB
/
upscale.py
File metadata and controls
103 lines (89 loc) · 4.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from PIL import Image
import os
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
import warnings
from io import BytesIO
from eeeee import UpscaleException
from error_code import UpscaleErrorCode
import grpc
from dotenv import load_dotenv
load_dotenv()
'''
Calls the Stability API over gRPC.
Each call with the esrgan-v1-x2plus engine costs 0.2 tokens ($0.002).
Accepted parameters:
- init_image=
- width= or height= (set only ONE of them; width and height cannot both be set at the same time)
If neither width= nor height= is provided,
the image is upscaled to 2x or 4x its dimensions by default, depending on the engine in use.
'''
class UpscaleManager:
    """Upscale images stored in a per-user folder via the Stability AI gRPC API.

    Attributes:
        width: output width (pixels) requested from the upscale endpoint.
        user_folder_path: folder holding the source images; upscaled images
            are written into the same folder with an ``upscaled_`` prefix.
        token: Stability API key used when no per-call key is supplied.
    """

    def __init__(self, user_folder_path: str,
                 token: "str | None" = None,
                 width: int = 1024) -> None:
        """
        Params:
            - user_folder_path: folder containing the source images.
            - token: Stability API key. If omitted/None, it is read from the
              STABILITY_KEY environment variable when the manager is created
              (not frozen at import time, as a default-argument lookup would be).
            - width: desired output width sent with every upscale request.
        """
        self.width = width
        self.user_folder_path = user_folder_path
        self.token = token if token is not None else os.environ.get('STABILITY_KEY')

    # ------------------
    @staticmethod
    def _is_solid_black_or_white(extrema) -> bool:
        """Return True if per-band ``(min, max)`` extrema describe an image whose
        pixels are all black or all white.

        Every pixel is (0, 0, 0) exactly when each band's maximum is 0, and
        every pixel is (255, 255, 255) exactly when each band's minimum is 255.
        """
        all_black = all(band_max == 0 for _, band_max in extrema)
        all_white = all(band_min == 255 for band_min, _ in extrema)
        return all_black or all_white

    def get_image_from_path(self, origin_img_name: str):
        """
        Load the source image to be upscaled.

        Param:
            - origin_img_name: file name of the original image (e.g. example.png),
              resolved relative to ``user_folder_path``.
        Return:
            - Pillow Image object converted to RGB.
        Raises:
            - UpscaleException(FileNotFoundError) if the path is not an existing file.
            - UpscaleException(WrongImageError) if the image is entirely black or
              entirely white.
        """
        origin_img_path = os.path.abspath(os.path.join(self.user_folder_path, origin_img_name))
        if not os.path.isfile(origin_img_path):
            raise UpscaleException(**UpscaleErrorCode.FileNotFoundError.value)

        # convert() loads the pixel data, so the file handle can be closed right away.
        with Image.open(origin_img_path) as src:
            image = src.convert("RGB")

        # getextrema() scans the image in C instead of materializing every pixel
        # into a Python list (which the old check did twice).
        if self._is_solid_black_or_white(image.getextrema()):
            raise UpscaleException(**UpscaleErrorCode.WrongImageError.value)
        return image

    @staticmethod
    def _rpc_error_to_upscale_exception(e):
        """Map a gRPC error to the matching UpscaleException.

        RESOURCE_EXHAUSTED (8) -> NonTokenError (presumably: account out of
        credits — confirm), UNAUTHENTICATED (16) -> WrongApiKeyError, anything
        else (including errors without a status code) -> APIError.
        """
        code_fn = getattr(e, "code", None)
        code = code_fn() if callable(code_fn) else None
        if code == grpc.StatusCode.RESOURCE_EXHAUSTED:
            error_code = UpscaleErrorCode.NonTokenError
        elif code == grpc.StatusCode.UNAUTHENTICATED:
            error_code = UpscaleErrorCode.WrongApiKeyError
        else:
            error_code = UpscaleErrorCode.APIError
        return UpscaleException(**error_code.value, error=e)

    def grpc_upscale_call(self, image: "Image.Image", origin_img_name: str, api_key=None):
        """
        Upscale ``image`` with the Stability AI gRPC API and save the result.

        Params:
            - image: Pillow Image to upscale.
            - origin_img_name: original file name; the result is saved as
              ``upscaled_<origin_img_name>`` inside ``user_folder_path``.
            - api_key: optional key overriding ``self.token`` for this call
              (added to allow testing with different tokens).
        Return:
            - absolute path of the saved upscaled image. NOTE: the path is
              returned even if the response contained no image artifact.
        Raises:
            - UpscaleException mapped from any grpc.RpcError (see
              ``_rpc_error_to_upscale_exception``).
        """
        if api_key is None:
            api_key = self.token
        upscaled_img_path = os.path.abspath(
            os.path.join(self.user_folder_path, f"upscaled_{origin_img_name}"))

        stability_api = client.StabilityInference(
            key=api_key,                         # API key reference.
            upscale_engine="esrgan-v1-x2plus",  # Name of the upscaling model.
            verbose=True,                        # Print debug messages.
        )

        try:
            # The request is streamed, so gRPC errors may surface while iterating.
            answers = stability_api.upscale(
                init_image=image,  # Image to upscale.
                width=self.width,  # Desired output width (height follows aspect ratio).
            )
            for resp in answers:
                for artifact in resp.artifacts:
                    if artifact.finish_reason == generation.FILTER:
                        # The API's own safety filter was triggered.
                        warnings.warn(
                            "Your request activated the API's safety filters and could not be processed. "
                            "Please submit a different image and try again.")
                    if artifact.type == generation.ARTIFACT_IMAGE:
                        os.makedirs(self.user_folder_path, exist_ok=True)
                        with Image.open(BytesIO(artifact.binary)) as out_img:
                            out_img.save(upscaled_img_path)
        except grpc.RpcError as e:
            raise self._rpc_error_to_upscale_exception(e) from e
        return upscaled_img_path