From 02d88bac918a131994c140590e6198cb6340e764 Mon Sep 17 00:00:00 2001 From: Tran Xen <137925069+glucauze@users.noreply.github.com> Date: Thu, 3 Aug 2023 15:25:44 +0200 Subject: [PATCH] add api for face building, add tests --- client_api/api_utils.py | 2 +- client_api/faceswaplab_api_example.py | 22 +++- client_api/test.safetensors | Bin 2256 -> 2256 bytes scripts/faceswaplab_api/faceswaplab_api.py | 24 ++++ scripts/faceswaplab_swapping/swapper.py | 9 +- scripts/faceswaplab_ui/faceswaplab_tab.py | 4 +- .../face_checkpoints_utils.py | 27 +++-- tests/test_api.py | 108 +++++++++++++++++- 8 files changed, 169 insertions(+), 27 deletions(-) diff --git a/client_api/api_utils.py b/client_api/api_utils.py index 9190b0d..4847acd 100644 --- a/client_api/api_utils.py +++ b/client_api/api_utils.py @@ -187,7 +187,7 @@ class FaceSwapRequest(BaseModel): default=None, ) units: List[FaceSwapUnit] - postprocessing: Optional[PostProcessingOptions] + postprocessing: Optional[PostProcessingOptions] = None class FaceSwapResponse(BaseModel): diff --git a/client_api/faceswaplab_api_example.py b/client_api/faceswaplab_api_example.py index b992f15..9a5a522 100644 --- a/client_api/faceswaplab_api_example.py +++ b/client_api/faceswaplab_api_example.py @@ -1,7 +1,9 @@ +from typing import List import requests from api_utils import ( FaceSwapUnit, InswappperOptions, + base64_to_safetensors, pil_to_base64, PostProcessingOptions, InpaintingWhen, @@ -98,12 +100,30 @@ for img in response.pil_images: img.show() +############################# +# Build checkpoint + +source_images: List[str] = [ + pil_to_base64("../references/man.png"), + pil_to_base64("../references/woman.png"), +] + +result = requests.post( + url=f"{address}/faceswaplab/build", + json=source_images, + headers={"Content-Type": "application/json; charset=utf-8"}, +) + +base64_to_safetensors(result.json(), output_path="test.safetensors") + ############################# # FaceSwap with local safetensors # First face unit : unit1 = FaceSwapUnit( - source_face=safetensors_to_base64("test.safetensors"), + source_face=safetensors_to_base64( + "test.safetensors" + ), # convert the checkpoint to base64 faces_index=(0,), # Replace first face swapping_options=InswappperOptions( face_restorer_name="CodeFormer", diff --git a/client_api/test.safetensors b/client_api/test.safetensors index c74264bdb49b4d996d1791eeac4b4a7b192cb435..a24d1b6a2fe97011273f695d28bf97b6b864875c 100644 GIT binary patch literal 2256 zcmb7?XHb-f7KRrK5Q2(`QiO=BD8h;~*A47 zBzn<^5;UUjcaCs{fCWn=5u?PhT`WnVTP(xxeoF=aiXe-gnNNI>~>vS{l13 zLF%zus#=z*e&0xaU($(KZ4vIbNsk?a$y&gBPD`QYnXA1#}uA4Ds;v@efp6>SjpFp#IdbPx{~$}#i)B+?6Upyk_==+zHuFm&I6 zK40lkSwX8ZSh5oJ`dt#jsuE$~>rU{`uHdNr+aUe9nNpNzc!dSxT{B;PBRdN|nD%aHY;JLn=~#$D2aSQJu6|un%zN)*>w``q#n@^Kf3ClOxEi%c(B6 zNzlEn5gVJ=;!Bwo8W+5#An&IXuxTrYSvIkMem;gKTi`s4a5NFTDPFdf7vDN84vX-` zLeUAIJZ}UsuNYL)KZI>lDxh340uG+JMOpH3%5G~%oR>GAJI6L1E|3FTO0t;*x`WJ}QW|iCcMLh8LNSZ-#_w9Vbn<oZ^;0WyO+fJ?@=hB(p z)qG&cBtep8h-!L7m%0Yg(Dg$Q?-x^qg%h6n*SEB0`cL#(_$2%!dn}&|AB4jr4q(d2 z@1Wa$I9h6(IBin~2jA#|N0YR?=0OuhYFF^6j;kc$39xp;7U-Ov5B1YzsPV8DT0e`S zo6WPZV(SPMX0MhElf>%4M&b_#(mGOiyfmRmyUSUg$G z9u*4Xoa}L)YbaU-xF|yc5%p>(Li;z8b8e+)dEAs5sC}FSZ5{UT)eAlTerPrrFWE>P zjn@GJTCvZ)8PnT~x$DP693N`SJr6tRp2h`NZI}o4QddF$?n0jadj&l)j^jaZ=W*^? z#s{OfL(R2%7_0Tg?Ugw=dz3E)+e9*4(?<`*dXhzcNHe7;!N16kvox`uy2A1NTXq@- z9vDwY(^^FrqgRkDb>xWay>!ZIAH6tXB2<*_;W(eQ{8SdmldtT73%5F9!+kfd`KB8E zl1Jm&pF;70N+Nu*+mFjy*R%12E2uVb0r};H?5vZp`Ct!pu4g!unnTYelTZ_HjKi+f z^PUUmDY#}Xm7F{PrWd@?aM(Lh&r-p^rev~T_hGs>C7va1_h>^`9`TZukl~q!PKrQO z2H5g!D-92bYN3MaL6ANunpcdRh&9&HSaz!i2MjAE_kc5?=o^hcIMkwgUnBUGZiSce zUNpbOpY5HpK%3Mho+zHpX(xkN=RXbIdu}7OCX=Z=hf1UDpwiNpi;94wc56tnZ=m3w z?#u%xlwiiiAkuXd@cDrw;PL4SkawvuF40+NJEBF2;@>3OWy+q>v#2yc4VqK>=o@pA z`j5$@u+%NAd*#TlzBR+uf`Lepja=u zRtm7@agO+@&Bt7qoB~xob)2661z1PU;*u;i=zTqwYAB8C2YAs=?|gpO5scoB2dVK$ zIi+?VhhEhb&N+0E?kw*|v#r;%+q$3?Oo}(wY$MO+Jg{pR4H++7 z(9^UNjs|b1czXl9yw?lY6fA*=v!1M7dK5ym=3E^;moFUafxwwJ#VD5(nDzRFM_KI= za?stP`^)|VwV6riQ00V v|3~Nj6!`6FQ5iCg40qfRC96)IzMtj8Znjlur;6mydfw9P4=X__|Cs*-&(#JB literal 2256 zcmb7^Yc$n)7so?M)C?zdlj@{$IXSK;Bj>kImqAL3D1>khdd8(flNymyIw?{rgpigI zk&xp5+efG5u8c{Yjbtuu!Wu>o*qFyKEa}p;BOZLj+t9HhwUp06pHr!KcE@M zkoOlU3h)vMg}VcN{|Q?f8~r2I)X?aEgAI8cBLkj^TX^{5Z?gWozs%n*-N%L5ErFQ( zLlibX%*W6XA&As6$W7mL=-HYAePyc9ZMF;IKiOeWIpR5w0GR0>4BVey6N{tk!2Ks% zv^@8Yh~q<1^m{2kF~=FV@b2J9%0bN4Tu+*sbgBMhWXdlxa9rgqqaE@`MPP$9Y)h(y zzEm-C59(mA+ykoqXE2DTHqq1_X=Kp#JUnqZi!W}xCD|wY>0tOQ@^yDD7I$gDy`2X^ zTOkKe%2lIqmJXOUH{h1W70|ggm$ZFo!NMaFnxJA&&nzhh6W%efY`=yqr4YhR{6dTJ zn=nB^hRBC`(98{`ME1*ZEU1V;_2v|Gzcq~veisVfz7l?Gt13BgUlUi{d`p>?C1}>< zge+$j@Q_WSRFDS(dl9)Gcpo#GHiPpRVS~%GW%(I)YKJ*+WTdjr< z^RMHqbT^=v{si$rFYzCf(lS9Ead-70Vr5s%Y`=_!Vl{r`QdP;Zp$xj0 z-?Kr{;2G}XsDNfo2(f-HFS!!NfhGRe8Jn6s7|c$@mo^n}!7iC*Wk!LyaV4E#q9Ll~ z9PV5`KwnvmFx>V4@=&+{AR!1l+Vrq+7FImzv%vtz4BUVHF3K;F0hVP0-RgG)aue2p z&7#?Kbf}ZqhbPiw`y=V;t@coQXcKr1eV{JN9%$<{4X-&Z1(D}4IU)a*bZSn)GQ}oZ zpdf0cq;n8Mm(rg4NGvN@BkVsA&y>-gmhWF0q!h=@ZBlf#h^K(rTMgV)HQ-6x%fUyg~_CGOR?E zWLCmR$uRo8@*xJcJlwNE7R}ts$PeWSpdBoOHqO2nBKF3^mu|zAu7~uC=qmC~7ei*G zBRD!pX|hyKFsUQO^;uz5b+Q<-x)IYxH$(5p1ktHj2s)c0A>?HzSvVo)EA;Qdzf!RmdT%0hx?RkQ2ohjdW(}yp1zM{yi66T?t&8- zH9QSIsavB(r9Oy*j8XMPCMk04CLaZ#VNn+Y8WYj{3sE7o?OXtE`!SRRdm0ON8h)Xx zx1NV55>~kgqJDOVu+0kvN~62M^oVOhnl2Z=oGs=)gJ!4R)lfe#)ceFIUiMg$s z40k8IQR|#9Y-qTEJ#u51R6IoOA8L@X+6ibp8;B3yC}ZKb|FpK{R}wkSwiwNPK|kKS z#MewbhTrH}H*0GpY^Nwhh$YTDf7#t~V!35{aFiS;-G%oeR zF(bf}pKI~%r9KRpcNd1@QyJ3{V_dw;8w}&hsra!H?%Tc&7rWmg4oyAymc4<>zc5F+ zUyebZpc$rCD?xADF5KuK#r{X(6@Cy;I+!c){y+_C7jFk{>wf6rX+z=iJai1WifOIc z zl#_*#hnDbVzw6-l)IX(1f13}&h#ahznJE}v83#4vF;JzPNV(RAl4U&x=<&7y<4=`< zX2g29ViyUQ4;DdePz^Rzaj9wIT^3X_pwQkGxaNzXcGsceeRTJy%V%Uhxlhy zoavec&5-JxiWdx0(9*CSvZm`poZuyFTV~HU$=*vJoG_wN{VGVGqYU$O*_6>}r@vmA z2VE;~fY>CJwALG9sCfmhf87nff=)Ak#A_y=FD(KvjW}d{{Z)5Ig0=Q diff --git a/scripts/faceswaplab_api/faceswaplab_api.py b/scripts/faceswaplab_api/faceswaplab_api.py index e4d9c3e..ef6803f 100644 --- a/scripts/faceswaplab_api/faceswaplab_api.py +++ b/scripts/faceswaplab_api/faceswaplab_api.py @@ -1,3 +1,4 @@ +import tempfile from PIL import Image import numpy as np from fastapi import FastAPI @@ -17,6 +18,9 @@ from scripts.faceswaplab_postprocessing.postprocessing_options import ( PostProcessingOptions, ) from client_api import api_utils +from scripts.faceswaplab_utils.face_checkpoints_utils import ( + build_face_checkpoint_and_save, +) def encode_to_base64(image: Union[str, Image.Image, np.ndarray]) -> str: # type: ignore @@ -135,3 +139,23 @@ def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None: result_images = [encode_to_base64(img) for img in faces] response = api_utils.FaceSwapExtractResponse(images=result_images) return response + + @app.post( + "/faceswaplab/build", + tags=["faceswaplab"], + description="Build a face checkpoint using base64 images, return base64 satetensors", + ) + async def build(base64_images: List[str]) -> Optional[str]: + if len(base64_images) > 0: + pil_images = [base64_to_pil(img) for img in base64_images] + with tempfile.NamedTemporaryFile( + delete=True, suffix=".safetensors" + ) as temp_file: + build_face_checkpoint_and_save( + images=pil_images, + name="api_ckpt", + overwrite=True, + path=temp_file.name, + ) + return api_utils.safetensors_to_base64(temp_file.name) + return None diff --git a/scripts/faceswaplab_swapping/swapper.py b/scripts/faceswaplab_swapping/swapper.py index e6fcae9..677e49d 100644 --- a/scripts/faceswaplab_swapping/swapper.py +++ b/scripts/faceswaplab_swapping/swapper.py @@ -468,12 +468,12 @@ def get_or_default(l: List[Any], index: int, default: Any) -> Any: return l[index] if index < len(l) else default -def get_faces_from_img_files(files: List[str]) -> List[Optional[CV2ImgU8]]: +def get_faces_from_img_files(images: List[PILImage]) -> List[Optional[CV2ImgU8]]: """ Extracts faces from a list of image files. Args: - files (list): A list of file objects representing image files. + images (list): A list of PILImage objects representing image files. Returns: list: A list of detected faces. @@ -482,9 +482,8 @@ def get_faces_from_img_files(files: List[str]) -> List[Optional[CV2ImgU8]]: faces = [] - if len(files) > 0: - for file in files: - img = Image.open(file) # Open the image file + if len(images) > 0: + for img in images: face = get_or_default( get_faces(pil_to_cv2(img)), 0, None ) # Extract faces from the image diff --git a/scripts/faceswaplab_ui/faceswaplab_tab.py b/scripts/faceswaplab_ui/faceswaplab_tab.py index 5db3dd7..c2c659f 100644 --- a/scripts/faceswaplab_ui/faceswaplab_tab.py +++ b/scripts/faceswaplab_ui/faceswaplab_tab.py @@ -153,9 +153,9 @@ def build_face_checkpoint_and_save( if not batch_files: logger.error("No face found") return None - filenames = [x.name for x in batch_files] + images = [Image.open(file.name) for file in batch_files] preview_image = face_checkpoints_utils.build_face_checkpoint_and_save( - filenames, name, overwrite=overwrite + images, name, overwrite=overwrite ) except Exception as e: logger.error("Failed to build checkpoint %s", e) diff --git a/scripts/faceswaplab_utils/face_checkpoints_utils.py b/scripts/faceswaplab_utils/face_checkpoints_utils.py index 481718f..bf652f1 100644 --- a/scripts/faceswaplab_utils/face_checkpoints_utils.py +++ b/scripts/faceswaplab_utils/face_checkpoints_utils.py @@ -38,7 +38,7 @@ def sanitize_name(name: str) -> str: def build_face_checkpoint_and_save( - batch_files: List[str], name: str, overwrite: bool = False + images: List[PILImage], name: str, overwrite: bool = False, path: str = None ) -> PILImage: """ Builds a face checkpoint using the provided image files, performs face swapping, @@ -55,9 +55,9 @@ def build_face_checkpoint_and_save( try: name = sanitize_name(name) - batch_files = batch_files or [] - logger.info("Build %s %s", name, [x for x in batch_files]) - faces = swapper.get_faces_from_img_files(batch_files) + images = images or [] + logger.info("Build %s with %s images", name, len(images)) + faces = swapper.get_faces_from_img_files(images) blended_face = swapper.blend_faces(faces) preview_path = os.path.join( scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references" @@ -95,14 +95,17 @@ def build_face_checkpoint_and_save( ) preview_image = result.image - file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors") - if not overwrite: - file_number = 1 - while os.path.exists(file_path): - file_path = os.path.join( - get_checkpoint_path(), f"{name}_{file_number}.safetensors" - ) - file_number += 1 + if path: + file_path = path + else: + file_path = os.path.join(get_checkpoint_path(), f"{name}.safetensors") + if not overwrite: + file_number = 1 + while os.path.exists(file_path): + file_path = os.path.join( + get_checkpoint_path(), f"{name}_{file_number}.safetensors" + ) + file_number += 1 save_face(filename=file_path, face=blended_face) preview_image.save(file_path + ".png") try: diff --git a/tests/test_api.py b/tests/test_api.py index 166ee24..8f20bce 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -2,22 +2,28 @@ from typing import List import pytest import requests import sys +import tempfile +import safetensors sys.path.append(".") +import requests from client_api.api_utils import ( FaceSwapUnit, - FaceSwapResponse, - PostProcessingOptions, - FaceSwapRequest, - base64_to_pil, + InswappperOptions, pil_to_base64, + PostProcessingOptions, InpaintingWhen, - FaceSwapCompareRequest, + InpaintingOptions, + FaceSwapRequest, + FaceSwapResponse, FaceSwapExtractRequest, + FaceSwapCompareRequest, FaceSwapExtractResponse, compare_faces, - InpaintingOptions, + base64_to_pil, + base64_to_safetensors, + safetensors_to_base64, ) from PIL import Image @@ -37,6 +43,13 @@ def face_swap_request() -> FaceSwapRequest: source_img=pil_to_base64("references/woman.png"), # The face you want to use same_gender=True, faces_index=(0,), # Replace first woman since same gender is on + swapping_options=InswappperOptions( + face_restorer_name="CodeFormer", + upscaler_name="LDSR", + improved_mask=True, + sharpen=True, + color_corrections=True, + ), ) # Post-processing config @@ -179,3 +192,86 @@ def test_faceswap_inpainting(face_swap_request: FaceSwapRequest) -> None: data = response.json() assert "images" in data assert "infos" in data + + +def test_faceswap_checkpoint_building() -> None: + source_images: List[str] = [ + pil_to_base64("references/man.png"), + pil_to_base64("references/woman.png"), + ] + + response = requests.post( + url=f"{base_url}/faceswaplab/build", + json=source_images, + headers={"Content-Type": "application/json; charset=utf-8"}, + ) + + assert response.status_code == 200 + + with tempfile.NamedTemporaryFile(delete=True) as temp_file: + base64_to_safetensors(response.json(), output_path=temp_file.name) + with safetensors.safe_open(temp_file.name, framework="pt") as f: + assert "age" in f.keys() + assert "gender" in f.keys() + assert "embedding" in f.keys() + + +def test_faceswap_checkpoint_building_and_using() -> None: + source_images: List[str] = [ + pil_to_base64("references/man.png"), + ] + + response = requests.post( + url=f"{base_url}/faceswaplab/build", + json=source_images, + headers={"Content-Type": "application/json; charset=utf-8"}, + ) + + assert response.status_code == 200 + + with tempfile.NamedTemporaryFile(delete=True) as temp_file: + base64_to_safetensors(response.json(), output_path=temp_file.name) + with safetensors.safe_open(temp_file.name, framework="pt") as f: + assert "age" in f.keys() + assert "gender" in f.keys() + assert "embedding" in f.keys() + + # First face unit : + unit1 = FaceSwapUnit( + source_face=safetensors_to_base64( + temp_file.name + ), # convert the checkpoint to base64 + faces_index=(0,), # Replace first face + swapping_options=InswappperOptions( + face_restorer_name="CodeFormer", + upscaler_name="LDSR", + improved_mask=True, + sharpen=True, + color_corrections=True, + ), + ) + + # Prepare the request + request = FaceSwapRequest( + image=pil_to_base64("tests/test_image.png"), units=[unit1] + ) + + # Face Swap + response = requests.post( + url=f"{base_url}/faceswaplab/swap_face", + data=request.json(), + headers={"Content-Type": "application/json; charset=utf-8"}, + ) + assert response.status_code == 200 + fsr = FaceSwapResponse.parse_obj(response.json()) + data = response.json() + assert "images" in data + assert "infos" in data + + # First face is the man + assert ( + compare_faces( + fsr.pil_images[0], Image.open("references/man.png"), base_url=base_url + ) + > 0.5 + )