'''
Coarse Gaussian Rendering -- RGB-D as init
RGB-D add noise (MV init)
Cycling:
    denoise to x0 and d0 -- optimize Gaussian
    re-rendering RGB-D
    render RGB-D to rectified noise
    noise rectification
    step denoise with rectified noise
-- Finally the Gaussian
'''
import tqdm
import torch
import numpy as np
from copy import deepcopy

from ops.utils import *
from ops.gs.train import *
from ops.trajs import _generate_trajectory
from ops.gs.basic import Frame, Gaussian_Scene
class Refinement_Tool_MCS():
    '''Multi-view-consistent (MCS) refinement of a coarse Gaussian scene.

    Renders of the coarse scene initialize a latent-consistency refiner
    (``refiner``), which is then run for ``refiner.denoise_steps`` steps.
    At every step:
      1. the diffusion is denoised to an x0 estimate for all views;
      2. a trainable copy of the Gaussian scene is optimized against those
         estimates (plus the two original input views);
      3. the optimized scene is re-rendered and used to rectify the
         diffusion noise (weight ``rect_w``) before stepping the sampler.
    Calling the instance returns the refined, gradient-frozen scene.
    '''
    def __init__(self,
                 coarse_GS: Gaussian_Scene,
                 device='cuda',
                 refiner=None,
                 traj_type='spiral',
                 n_view=8,
                 rect_w=0.7,
                 n_gsopt_iters=256) -> None:
        # input coarse GS;
        # here we refine rendered frames rather than gaussian parameters directly
        self.n_view = n_view
        self.rect_w = rect_w
        self.n_gsopt_iters = n_gsopt_iters
        self.coarse_GS = coarse_GS
        self.refine_frames: list[Frame] = []
        # hyperparameters: the full diffusion schedule is 50 steps;
        # only the last N steps are run here
        self.process_res = 512
        self.device = device
        self.traj_type = traj_type
        # models
        self.RGB_LCM = refiner
        self.RGB_LCM.to(self.device)  # fix: honor the device argument (was hard-coded 'cuda')
        self.steps = self.RGB_LCM.denoise_steps
        # prompt for diffusion
        prompt = self.coarse_GS.frames[-1].prompt
        self.rgb_prompt_latent = self.RGB_LCM.model._encode_text_prompt(prompt)
        # loss function
        self.rgb_lossfunc = RGB_Loss(w_ssim=0.2)

    def _pre_process(self):
        """Build the novel-view frames to refine and mark their inpaint masks."""
        # determine the diffusion target shape; keep a 32-px margin on
        # every side so the border of the rasterization can be cropped away
        strict_times = 32
        origin_H = self.coarse_GS.frames[0].H
        origin_W = self.coarse_GS.frames[0].W
        self.target_H, self.target_W = self.process_res, self.process_res
        # rescale the intrinsic to the shared (target) render/denoise shape
        intrinsic = deepcopy(self.coarse_GS.frames[0].intrinsic)
        H_ratio, W_ratio = self.target_H / origin_H, self.target_W / origin_W
        intrinsic[0] *= W_ratio
        intrinsic[1] *= H_ratio
        target_H, target_W = self.target_H + 2 * strict_times, self.target_W + 2 * strict_times
        intrinsic[0, -1] = target_W / 2
        intrinsic[1, -1] = target_H / 2
        # generate a set of cameras along the trajectory, dropping the endpoints
        trajs = _generate_trajectory(None, self.coarse_GS, nframes=self.n_view + 2)[1:-1]
        for pose in trajs:
            fine_frame = Frame()
            fine_frame.H = target_H
            fine_frame.W = target_W
            fine_frame.extrinsic = pose
            fine_frame.intrinsic = deepcopy(intrinsic)
            fine_frame.prompt = self.coarse_GS.frames[-1].prompt
            self.refine_frames.append(fine_frame)
        # determine the inpaint mask from a scene holding only the two input views
        temp_scene = Gaussian_Scene()
        temp_scene._add_trainable_frame(self.coarse_GS.frames[0], require_grad=False)
        temp_scene._add_trainable_frame(self.coarse_GS.frames[1], require_grad=False)
        for frame in self.refine_frames:
            frame = temp_scene._render_for_inpaint(frame)

    def _mv_init(self):
        """Encode coarse-scene renders as the multi-view diffusion initialization."""
        rgbs = []
        # only for inpainted images
        for frame in self.refine_frames:
            # render at the shared shape; depth and alpha are unused here
            render_rgb, _, _ = self.coarse_GS._render_RGBD(frame)
            rgbs.append(render_rgb.permute(2, 0, 1)[None])  # HWC -> 1CHW
        self.rgbs = torch.cat(rgbs, dim=0)
        self.RGB_LCM._encode_mv_init_images(self.rgbs)

    def _to_cuda(self, tensor):
        """Convert a numpy array to a float32 tensor on the working device."""
        # fix: use self.device instead of a hard-coded 'cuda'
        return torch.from_numpy(tensor.astype(np.float32)).to(self.device)

    def _x0_rectification(self, denoise_rgb, iters):
        """Optimize a trainable copy of the coarse scene toward the denoised x0 views.

        denoise_rgb: per-view denoised x0 images, indexed like self.refine_frames.
        iters: number of optimization steps.
        """
        # gaussian initialization: clone the coarse scene and unfreeze it
        CGS = deepcopy(self.coarse_GS)
        for gf in CGS.gaussian_frames:
            gf._require_grad(True)
        self.refine_GS = GS_Train_Tool(CGS)
        # rectification
        for _ in range(iters):  # renamed from `iter` to avoid shadowing the builtin
            loss = 0.
            # supervise on the two input views, weighted by the number of
            # refine views to balance against the multi-view term below
            for i in range(2):
                keep_frame: Frame = self.coarse_GS.frames[i]
                render_rgb, _, _ = self.refine_GS._render(keep_frame)
                loss_rgb = self.rgb_lossfunc(render_rgb,
                                             self._to_cuda(keep_frame.rgb),
                                             valid_mask=keep_frame.inpaint)
                loss += loss_rgb * len(self.refine_frames)
            # then multiview supervision against the denoised x0 images
            for i, frame in enumerate(self.refine_frames):
                render_rgb, _, _ = self.refine_GS._render(frame)
                loss += self.rgb_lossfunc(denoise_rgb[i], render_rgb)
            # optimization
            loss.backward()
            self.refine_GS.optimizer.step()
            self.refine_GS.optimizer.zero_grad()

    def _step_gaussian_optimization(self, step):
        """Denoise to x0 at this diffusion step, then optimize the Gaussians toward it.

        Returns (rgb_t, rgb_noise_pr) for the subsequent rectification step.
        """
        # denoise to x0 and d0
        with torch.no_grad():
            # we left the last 2 steps for stronger guidance
            rgb_t = self.RGB_LCM.timesteps[-self.steps + step]
            rgb_t = torch.tensor([rgb_t]).to(self.device)
            rgb_noise_pr, rgb_denoise = self.RGB_LCM._denoise_to_x0(rgb_t, self.rgb_prompt_latent)
            rgb_denoise = rgb_denoise.permute(0, 2, 3, 1)  # NCHW -> NHWC
        # render each frame and run the weight-able refinement
        self._x0_rectification(rgb_denoise, self.n_gsopt_iters)
        return rgb_t, rgb_noise_pr

    def _step_diffusion_rectification(self, rgb_t, rgb_noise_pr):
        """Rectify the diffusion noise using re-renders of the optimized scene."""
        # re-rendering RGB
        with torch.no_grad():
            x0_rect = []
            for frame in self.refine_frames:
                re_render_rgb, _, _ = self.refine_GS._render(frame)
                # NOTE: rasterization holes here can propagate as block
                # artifacts into later denoising steps
                x0_rect.append(re_render_rgb.permute(2, 0, 1)[None])
            x0_rect = torch.cat(x0_rect, dim=0)
            # rectification
            self.RGB_LCM._step_denoise(rgb_t, rgb_noise_pr, x0_rect, rect_w=self.rect_w)

    def __call__(self):
        """Run the full refinement loop and return the refined, frozen scene."""
        # warmup
        self._pre_process()
        self._mv_init()
        for step in tqdm.tqdm(range(self.steps)):
            rgb_t, rgb_noise_pr = self._step_gaussian_optimization(step)
            self._step_diffusion_rectification(rgb_t, rgb_noise_pr)
        scene = self.refine_GS.GS
        for gf in scene.gaussian_frames:
            gf._require_grad(False)  # freeze gradients before handing the scene back
        return scene