import io
import base64
from collections import Counter
from math import isqrt, ceil
from typing import Dict, List, Optional, Union

import cv2
import numpy as np
import torch
from PIL import Image

from ifnude import detect
from modules import processing

from scripts.faceswaplab_globals import NSFW_SCORE_THRESHOLD
from scripts.faceswaplab_utils.typing import BoxCoords, CV2ImgU8, PILImage


def check_against_nsfw(img: PILImage) -> bool:
    """
    Check whether any detected region of an image exceeds the Not Safe for Work
    (NSFW) score threshold.

    Parameters:
        img (PILImage): The image to be checked.

    Returns:
        bool: True if any part of the image is considered NSFW, False otherwise.
    """
    shapes: List[bool] = []
    chunks: List[Dict[str, Union[int, float]]] = detect(img)

    for chunk in chunks:
        shapes.append(chunk["score"] > NSFW_SCORE_THRESHOLD)

    return any(shapes)

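# Usage sketch (illustrative only, not called anywhere in this module):
# "input.png" is a placeholder path; the threshold itself comes from
# scripts.faceswaplab_globals.NSFW_SCORE_THRESHOLD.
#
#   candidate = Image.open("input.png")
#   if check_against_nsfw(candidate):
#       raise ValueError("Refusing to process an NSFW image")
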
def pil_to_cv2(pil_img: PILImage) -> CV2ImgU8:  # type: ignore
    """
    Convert a PIL Image into an OpenCV image (cv2).

    Args:
        pil_img (PILImage): An image in PIL format.

    Returns:
        CV2ImgU8: The input image converted to OpenCV format (BGR).
    """
    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

def cv2_to_pil(cv2_img: CV2ImgU8) -> PILImage:  # type: ignore
    """
    Convert an OpenCV image (cv2) into a PIL Image.

    Args:
        cv2_img (CV2ImgU8): An image in OpenCV format (BGR).

    Returns:
        PILImage: The input image converted to PIL format (RGB).
    """
    return Image.fromarray(cv2.cvtColor(cv2_img, cv2.COLOR_BGR2RGB))

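# Usage sketch (illustrative only): round-trip between PIL and OpenCV.
# The sample image is synthetic, so no file on disk is assumed.
#
#   sample = Image.new("RGB", (64, 64), color=(255, 0, 0))
#   bgr = pil_to_cv2(sample)    # numpy array, shape (64, 64, 3), BGR order
#   restored = cv2_to_pil(bgr)  # back to a PIL image in RGB order
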
def torch_to_pil(tensor: torch.Tensor) -> List[PILImage]:
    """
    Converts a tensor image or a batch of tensor images to a list of PIL images.

    Parameters
    ----------
    tensor : torch.Tensor
        A tensor representing an image (C, H, W) or a batch of images
        (N, C, H, W), with values in the [0, 1] range.

    Returns
    -------
    list
        A list of PIL images.
    """
    # Add the batch dimension first so a single (C, H, W) image is handled
    # exactly like a batch of one.
    if tensor.ndim == 3:
        tensor = tensor[None, ...]
    images: CV2ImgU8 = tensor.cpu().permute(0, 2, 3, 1).numpy()
    images = (images * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images

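# Usage sketch (illustrative only): both a single CHW tensor and a batch of
# tensors in the [0, 1] range come back as a list of PIL images.
#
#   torch_to_pil(torch.rand(3, 64, 64))     # list with 1 PIL image
#   torch_to_pil(torch.rand(4, 3, 64, 64))  # list with 4 PIL images
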
def pil_to_torch(pil_images: Union[PILImage, List[PILImage]]) -> torch.Tensor:
    """
    Converts a PIL image or a list of PIL images to a torch tensor or a batch of torch tensors.

    Parameters
    ----------
    pil_images : Union[PILImage, List[PILImage]]
        A PIL image or a list of PIL images.

    Returns
    -------
    torch.Tensor
        A tensor representing an image or a batch of images.
    """
    if isinstance(pil_images, list):
        numpy_images = [np.array(image) for image in pil_images]
        torch_images = torch.from_numpy(np.stack(numpy_images)).permute(0, 3, 1, 2)
        return torch_images

    numpy_image = np.array(pil_images)
    torch_image = torch.from_numpy(numpy_image).permute(2, 0, 1)
    return torch_image

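# Usage sketch (illustrative only): pil_to_torch keeps the uint8 range
# [0, 255], so divide by 255 before feeding the result back to torch_to_pil.
#
#   images = [Image.new("RGB", (32, 32)) for _ in range(2)]
#   batch = pil_to_torch(images)              # shape (2, 3, 32, 32), dtype uint8
#   back = torch_to_pil(batch.float() / 255)  # list of 2 PIL images
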
def create_square_image(image_list: List[PILImage]) -> Optional[PILImage]:
    """
    Creates a square image by combining multiple images in a grid pattern.

    Args:
        image_list (list): List of PIL Image objects to be combined.

    Returns:
        PIL Image object: The resulting square image.
        None: If the image_list is empty or contains only one image.
    """
    # Return None early if there are no images to combine
    if not image_list:
        return None

    # Count the occurrences of each image size in the image_list
    size_counter = Counter(image.size for image in image_list)

    # Get the most common image size (size with the highest count)
    common_size = size_counter.most_common(1)[0][0]

    # Filter the image_list to include only images with the common size
    image_list = [image for image in image_list if image.size == common_size]

    # Get the dimensions (width and height) of the common size
    size = common_size

    # If there is more than one image in the image_list
    if len(image_list) > 1:
        num_images = len(image_list)

        # Calculate the number of rows and columns for the grid
        rows = isqrt(num_images)
        cols = ceil(num_images / rows)

        # Calculate the size of the square image
        square_size = (cols * size[0], rows * size[1])

        # Create a new RGB image with the square size
        square_image = Image.new("RGB", square_size)

        # Paste each image onto the square image at the appropriate position
        for i, image in enumerate(image_list):
            row = i // cols
            col = i % cols

            square_image.paste(image, (col * size[0], row * size[1]))

        # Return the resulting square image
        return square_image

    # Return None if only one image is left in the image_list
    return None

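# Usage sketch (illustrative only): four same-sized thumbnails become a 2x2
# grid; a single image (or an empty list) returns None by design. "grid.png"
# is a placeholder output path.
#
#   thumbs = [Image.new("RGB", (128, 128), color=(i * 60, 0, 0)) for i in range(4)]
#   grid = create_square_image(thumbs)
#   if grid is not None:
#       grid.save("grid.png")
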
# def create_mask(image : PILImage, box_coords : Tuple[int, int, int, int]) -> PILImage:
#     width, height = image.size
#     mask = Image.new("L", (width, height), 255)
#     x1, y1, x2, y2 = box_coords
#     for x in range(width):
#         for y in range(height):
#             if x1 <= x <= x2 and y1 <= y <= y2:
#                 mask.putpixel((x, y), 255)
#             else:
#                 mask.putpixel((x, y), 0)
#     return mask

def create_mask(
    image: PILImage,
    box_coords: BoxCoords,
) -> PILImage:
    """
    Create a binary mask for a given image and bounding box coordinates.

    Args:
        image (PILImage): The input image.
        box_coords (BoxCoords): A tuple of 4 integers defining the bounding box.
            It follows the pattern (x1, y1, x2, y2), where (x1, y1) is the
            top-left corner of the box and (x2, y2) is the bottom-right corner.

    Returns:
        PILImage: A binary mask of the same size as the input image, where pixels
        within the bounding box are white (255) and pixels outside it are black (0).
    """
    width, height = image.size
    # Start from an all-black mask and paint only the box region white
    mask = Image.new("L", (width, height), 0)
    x1, y1, x2, y2 = box_coords
    for x in range(x1, x2 + 1):
        for y in range(y1, y2 + 1):
            mask.putpixel((x, y), 255)
    return mask

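# Usage sketch (illustrative only): the box below stands in for a hypothetical
# face bounding box; coordinates must stay within the image bounds.
#
#   base = Image.new("RGB", (256, 256))
#   face_mask = create_mask(base, (64, 64, 191, 191))
#   face_mask.getpixel((100, 100))  # 255 (inside the box)
#   face_mask.getpixel((10, 10))    # 0 (outside the box)
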
def apply_mask(
    img: PILImage, p: processing.StableDiffusionProcessing, batch_index: int
) -> PILImage:
    """
    Apply mask overlay and color correction to an image if enabled.

    Args:
        img: PIL Image object.
        p: The processing object.
        batch_index: The batch index.

    Returns:
        PIL Image object
    """
    if isinstance(p, processing.StableDiffusionProcessingImg2Img):
        if p.inpaint_full_res:
            overlays = p.overlay_images
            if overlays is None or batch_index >= len(overlays):
                return img
            overlay: PILImage = overlays[batch_index]
            overlay = overlay.resize((img.size), resample=Image.Resampling.LANCZOS)
            img = img.copy()
            img.paste(overlay, (0, 0), overlay)
            return img

        img = processing.apply_overlay(img, p.paste_to, batch_index, p.overlay_images)
        if p.color_corrections is not None and batch_index < len(p.color_corrections):
            img = processing.apply_color_correction(
                p.color_corrections[batch_index], img
            )
    return img

def prepare_mask(mask: PILImage, p: processing.StableDiffusionProcessing) -> PILImage:
    """
    Prepare an image mask for the inpainting process. (This comes from ControlNet.)

    This function takes as input a PIL Image object and an instance of the
    StableDiffusionProcessing class, and prepares the mask as follows:

    1. Convert the mask to grayscale (mode "L").
    2. (Currently disabled, see the FIXME below.) If the 'mask_blur' attribute of
       the processing instance is greater than 0, apply a Gaussian blur to the
       mask with a radius equal to 'mask_blur'.

    Args:
        mask (PILImage): The input mask as a PIL Image object.
        p (processing.StableDiffusionProcessing): An instance of the
            StableDiffusionProcessing class containing the processing parameters.

    Returns:
        mask (PILImage): The prepared mask as a PIL Image object.
    """
    mask = mask.convert("L")
    # FIXME : Properly fix blur
    # if getattr(p, "mask_blur", 0) > 0:
    #     mask = mask.filter(ImageFilter.GaussianBlur(p.mask_blur))
    return mask

def base64_to_pil(base64str: Optional[str]) -> Optional[PILImage]:
    """
    Converts a base64 string to a PIL Image object.

    Parameters:
        base64str (Optional[str]): The base64 string to convert. This string may
            contain a data URL scheme (i.e., 'data:image/jpeg;base64,') or just be
            the raw base64 encoded data. If None, the function will return None.

    Returns:
        Optional[PILImage]: A PIL Image object created from the base64 string.
        If the input is None, the function returns None.

    Raises:
        binascii.Error: If the base64 string is not properly formatted or encoded.
        PIL.UnidentifiedImageError: If the image format cannot be identified.
    """
    if base64str is None:
        return None

    # Check if the base64 string has a data URL scheme
    if "base64," in base64str:
        base64_data = base64str.split("base64,")[-1]
        img_bytes = base64.b64decode(base64_data)
    else:
        # If no data URL scheme, just decode
        img_bytes = base64.b64decode(base64str)

    return Image.open(io.BytesIO(img_bytes))
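
# Usage sketch (illustrative only): encode a synthetic image and decode it back;
# both the raw base64 payload and a data-URL prefixed string are accepted.
#
#   buffer = io.BytesIO()
#   Image.new("RGB", (16, 16)).save(buffer, format="PNG")
#   payload = base64.b64encode(buffer.getvalue()).decode()
#   base64_to_pil(payload)                             # raw base64
#   base64_to_pil("data:image/png;base64," + payload)  # data URL scheme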