@@ -151,11 +151,11 @@ def generate_noise(shape, seed=None, device="cpu", dtype=torch.float16):
151151 noise = torch .randn (shape , generator = generator , device = device , dtype = dtype )
152152 return noise
153153
154- def encode_image (self , image : torch .Tensor ) -> torch .Tensor :
154+ def encode_image (
155+ self , image : torch .Tensor , tiled : bool = False , tile_size : int = 64 , tile_stride : int = 32
156+ ) -> torch .Tensor :
155157 image = image .to (self .device , self .vae_encoder .dtype )
156- latents = self .vae_encoder (
157- image , tiled = self .vae_tiled , tile_size = self .vae_tile_size , tile_stride = self .vae_tile_stride
158- )
158+ latents = self .vae_encoder (image , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride )
159159 return latents
160160
161161 def decode_image (self , latent : torch .Tensor ) -> torch .Tensor :
@@ -187,7 +187,7 @@ def prepare_latents(
187187 self .load_models_to_device (["vae_encoder" ])
188188 noise = latents
189189 image = self .preprocess_image (input_image ).to (device = self .device , dtype = self .dtype )
190- latents = self .encode_image (image , tiled , tile_size , tile_stride )
190+ latents = self .encode_image (image , tiled = tiled , tile_size = tile_size , tile_stride = tile_stride )
191191 init_latents = latents .clone ()
192192 latents = self .sampler .add_noise (latents , noise , sigma_start )
193193 else :
0 commit comments