import cv2
import numpy as np
import torch
from torchvision.transforms.functional import normalize
from scripts.faceswaplab_swapping.parsing import init_parsing_model
from functools import lru_cache
from typing import Union, List
from torch import device as torch_device


@lru_cache
def get_parsing_model(device: torch_device) -> torch.nn.Module:
    """
    Create the parsing model for ``device``, memoising the result.

    Because the function is wrapped in ``lru_cache``, a repeated call with an
    equal ``device`` returns the previously created object instead of calling
    ``init_parsing_model`` again.

    Args:
        device: The torch device forwarded to ``init_parsing_model``.

    Returns:
        The model object returned by ``init_parsing_model``.
    """
    return init_parsing_model(device=device)  # type: ignore


def convert_image_to_tensor(
    images: Union[np.ndarray, List[np.ndarray]],
    convert_bgr_to_rgb: bool = True,
    use_float32: bool = True,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """
    Turn one image, or a list of images, into channel-first PyTorch tensors.

    Args:
        images: A single numpy image or a ``list`` of numpy images. Each image
            is indexed as ``shape[2]`` and transposed with ``(2, 0, 1)``, so it
            must have three dimensions.
        convert_bgr_to_rgb: When true, images whose third dimension has size 3
            are passed through ``cv2.COLOR_BGR2RGB`` first.
        use_float32: When true, ``.float()`` is applied to the resulting tensor.

    Returns:
        A tensor for a single input image, or a list of tensors (one per
        element, in order) when ``images`` is a ``list``.
    """

    def _as_tensor(frame: np.ndarray) -> torch.Tensor:
        # The channel-count test comes first on purpose: it mirrors the
        # evaluation order callers already get for malformed (e.g. 2-D) input.
        if frame.shape[2] == 3 and convert_bgr_to_rgb:
            if frame.dtype == np.float64:
                # float64 input is narrowed to float32 before the colour
                # conversion.
                frame = frame.astype(np.float32)
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Move the third axis to the front: (d0, d1, d2) -> (d2, d0, d1).
        tensor = torch.from_numpy(frame.transpose(2, 0, 1))
        return tensor.float() if use_float32 else tensor

    if isinstance(images, list):
        return [_as_tensor(frame) for frame in images]
    return _as_tensor(images)


def generate_face_mask(face_image: np.ndarray, device: torch.device) -> np.ndarray:
    """
    Produce a mask for a face image using the cached parsing model.

    The image is resized to 512x512, divided by 255, converted to an RGB
    float tensor, normalised with mean 0.5 and std 0.5 per channel, and fed to
    the model returned by ``get_parsing_model``. Each pixel's arg-max label is
    then mapped to 0 or 255 through a fixed 19-entry table, and the result is
    resized back to the input's width and height.

    Args:
        face_image: The face image as a numpy array; the ``/ 255.0`` scaling
            and BGR->RGB conversion suggest an 8-bit-range BGR image
            (assumption, not enforced here).
        device: The torch device used for the model and its input.

    Returns:
        A float numpy array with the same height and width as ``face_image``.
    """
    # Bring the image to the fixed 512x512 resolution fed to the model.
    model_sized = cv2.resize(face_image, (512, 512), interpolation=cv2.INTER_LINEAR)
    scaled = model_sized.astype("float32") / 255.0

    face_input = convert_image_to_tensor(
        scaled, convert_bgr_to_rgb=True, use_float32=True
    )
    # In-place (x - 0.5) / 0.5 on every channel.
    normalize(face_input, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)  # type: ignore
    assert isinstance(face_input, torch.Tensor)
    batch = face_input.unsqueeze(0).to(device)

    # Inference only: no gradients are needed to read off the label map.
    with torch.no_grad():
        scores = get_parsing_model(device)(batch)[0]
        labels = scores.argmax(dim=1).squeeze().cpu().numpy()

    # Mask value for each label index 0..18.
    label_to_value = (
        0,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        0,
        255,
        0, 0, 0,
    )
    parse_mask = np.zeros(labels.shape)
    for label, value in enumerate(label_to_value):
        # The mask starts out as zeros, so zero-valued labels need no write.
        if value:
            parse_mask[labels == label] = value

    # Scale the mask back to the caller's resolution; cv2 takes (width, height).
    target_size = (face_image.shape[1], face_image.shape[0])
    return cv2.resize(parse_mask, target_size)