You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

125 lines
3.6 KiB
Python

import cv2
import numpy as np
import torch
from torchvision.transforms.functional import normalize
from scripts.faceswaplab_swapping.parsing import init_parsing_model
from functools import lru_cache
from typing import Union, List
from torch import device as torch_device
@lru_cache
def get_parsing_model(device: torch_device) -> torch.nn.Module:
    """
    Lazily construct the face-parsing model for *device*.

    Thanks to ``lru_cache`` the model is built only once per distinct
    device and the same instance is returned on every later call.

    Args:
        device: Torch device the model should run on.

    Returns:
        The cached parsing model instance.
    """
    model = init_parsing_model(device=device)
    return model
def convert_image_to_tensor(
    images: Union[np.ndarray, List[np.ndarray]],
    convert_bgr_to_rgb: bool = True,
    use_float32: bool = True,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Convert an image (or a list of images) from numpy HWC layout to a
    CHW PyTorch tensor.

    Args:
        images: A single ``numpy.ndarray`` image or a list of them,
            in height-width-channel layout.
        convert_bgr_to_rgb: When True and an image has exactly 3
            channels, swap BGR channel order to RGB before conversion.
        use_float32: When True, cast the resulting tensor(s) to float32.

    Returns:
        A single tensor, or a list of tensors mirroring the input list.
    """

    def _to_tensor(img: np.ndarray) -> torch.Tensor:
        if img.shape[2] == 3 and convert_bgr_to_rgb:
            # cv2.cvtColor does not accept float64 input, so downcast first.
            if img.dtype == "float64":
                img = img.astype("float32")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # HWC -> CHW, sharing memory with the numpy array where possible.
        tensor = torch.from_numpy(img.transpose(2, 0, 1))
        return tensor.float() if use_float32 else tensor

    if isinstance(images, list):
        return [_to_tensor(img) for img in images]
    return _to_tensor(images)
def generate_face_mask(face_image: np.ndarray, device: torch.device) -> np.ndarray:
    """
    Build a 0/255 face-region mask for *face_image* using the cached
    face-parsing model.

    Args:
        face_image: Input face crop as a numpy array in HWC layout.
        device: Torch device on which inference is executed.

    Returns:
        A mask array with the same height/width as *face_image*, where
        kept parsing classes are 255 and excluded classes are 0.
    """
    # Mask value for each parsing label index: 0 drops the class,
    # 255 keeps it. NOTE(review): assumes a 19-class parsing output
    # (label 0 = background) — confirm against the parsing model.
    MASK_COLOR_MAP = [
        0, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 0, 255, 0, 0, 0,
    ]

    # The parsing model expects a 512x512 input.
    model_input = cv2.resize(
        face_image, (512, 512), interpolation=cv2.INTER_LINEAR
    )

    # Scale to [0, 1], convert to a CHW tensor, then normalize to [-1, 1].
    tensor = convert_image_to_tensor(
        model_input.astype("float32") / 255.0,
        convert_bgr_to_rgb=True,
        use_float32=True,
    )
    normalize(tensor, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
    assert isinstance(tensor, torch.Tensor)
    batch = torch.unsqueeze(tensor, 0).to(device)

    # Run inference and reduce per-pixel class scores to label indices.
    with torch.no_grad():
        logits = get_parsing_model(device)(batch)[0]
    labels = logits.argmax(dim=1).squeeze().cpu().numpy()

    # Paint each parsing label with its configured mask value.
    parse_mask = np.zeros(labels.shape)
    for label, value in enumerate(MASK_COLOR_MAP):
        parse_mask[labels == label] = value

    # Scale the mask back to the original image size.
    return cv2.resize(parse_mask, (face_image.shape[1], face_image.shape[0]))