Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions DeepLense_Diffusion_Rishi/dataset/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@

import os
import torch
import numpy as np
from torch.utils.data import Dataset
from sklearn.preprocessing import LabelEncoder
from typing import Optional, Callable, List, Tuple

class LensDataset(Dataset):
    """
    Base dataset class for loading lens data stored as individual .npy files.

    Args:
        root_dir (str): Path to the directory containing .npy files.
        transform (callable, optional): Optional transform applied to each sample.
        max_samples (int, optional): Maximum number of samples to load.
            ``None`` (default) loads every file.
    """
    def __init__(self, root_dir: str, transform: Optional[Callable] = None, max_samples: Optional[int] = None):
        self.root_dir = root_dir
        self.transform = transform

        # Sort for a deterministic sample ordering across runs/filesystems.
        self.file_list = sorted(f for f in os.listdir(root_dir) if f.endswith('.npy'))
        # Explicit None check so max_samples=0 is honored instead of silently ignored.
        if max_samples is not None:
            self.file_list = self.file_list[:max_samples]

    def __len__(self) -> int:
        return len(self.file_list)

    def __getitem__(self, idx: int) -> torch.Tensor:
        file_path = os.path.join(self.root_dir, self.file_list[idx])

        try:
            # allow_pickle=True: some legacy files store object arrays.
            data = np.load(file_path, allow_pickle=True)
        except Exception as e:
            # Surface the failing path, then re-raise preserving the traceback.
            print(f"Error loading file {file_path}: {e}")
            raise

        # Normalization and type conversion are delegated to the transform.
        if self.transform:
            data = self.transform(data)

        return data

class WrapperDataset(Dataset):
    """
    Wrapper dataset for conditional generation or multi-class scenarios.
    Automatically iterates through subdirectories as classes.

    Args:
        root_dir (str): Root directory containing class subdirectories.
        transform (callable, optional): Transform to apply to data.
        max_samples_per_class (int, optional): Limit samples per class.

    Raises:
        FileNotFoundError: If ``root_dir`` contains no subdirectories.
    """
    def __init__(self, root_dir: str, transform: Optional[Callable] = None, max_samples_per_class: int = 5000):
        self.root_dir = root_dir
        self.transform = transform

        # Sorted so the class -> label assignment is deterministic.
        self.class_folders = sorted(
            f for f in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, f))
        )

        if not self.class_folders:
            raise FileNotFoundError(f"No class subdirectories found in {root_dir}")

        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(self.class_folders)

        # Precomputed name -> integer label map so __getitem__ avoids a
        # LabelEncoder.transform() call (and array allocation) per sample.
        self._label_map = dict(zip(self.class_folders, self.labels))

        self.data_index = []  # List of (file_path, class_name)

        for class_name in self.class_folders:
            class_path = os.path.join(root_dir, class_name)
            files = sorted(f for f in os.listdir(class_path) if f.endswith('.npy'))

            if max_samples_per_class:
                files = files[:max_samples_per_class]

            self.data_index.extend(
                (os.path.join(class_path, f), class_name) for f in files
            )

        print(f"Found {len(self.class_folders)} classes: {self.class_folders}")
        print(f"Total samples: {len(self.data_index)}")

    def __len__(self) -> int:
        return len(self.data_index)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        file_path, class_name = self.data_index[idx]

        # allow_pickle=True: legacy files may wrap data in object arrays;
        # shape handling ((1, 64, 64) vs (64, 64)) is delegated to transforms.
        data = np.load(file_path, allow_pickle=True)

        if self.transform:
            data = self.transform(data)

        # O(1) dict lookup replaces the per-item LabelEncoder.transform call.
        label = self._label_map[class_name]

        return data, label
66 changes: 66 additions & 0 deletions DeepLense_Diffusion_Rishi/dataset/transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@

import torch
import numpy as np
from torchvision import transforms

class MinMaxNormalize:
    """Rescales input values into the [0, 1] interval via min-max scaling.

    Constant inputs (max == min) are returned unchanged to avoid a
    division by zero.
    """
    def __call__(self, x):
        lo, hi = x.min(), x.max()
        span = hi - lo
        if span > 0:
            return (x - lo) / span
        return x

class SelectChannel:
    """Picks a single first-axis slice (channel) from multi-dimensional input.

    Inputs with two or fewer dimensions (or no ``ndim`` attribute) are
    passed through untouched.
    """
    def __init__(self, index=0):
        self.index = index

    def __call__(self, x):
        # Works for both numpy arrays and torch tensors via the ndim attribute.
        if hasattr(x, "ndim") and x.ndim > 2:
            return x[self.index]
        return x

class ToTensor:
    """Converts a numpy array (or tensor-like object) to a float32 torch tensor."""
    def __call__(self, x):
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        return x.float()

class AddChannel:
    """Inserts a size-1 channel dimension at the given position."""
    def __init__(self, dim=0):
        self.dim = dim

    def __call__(self, x):
        return torch.unsqueeze(x, self.dim)

def get_transforms(config=None):
    """
    Returns a composition of transforms based on configuration.
    For now, returns a default set of transforms if config is not provided.
    """
    # Default pipeline (replicates the legacy CustomDataset logic):
    # min-max normalize -> convert to float tensor -> prepend channel dim.
    return transforms.Compose([
        MinMaxNormalize(),
        ToTensor(),
        AddChannel(dim=0),
    ])

def get_conditional_transforms():
    """Returns transforms specifically for conditional generation.

    Currently identical to the default pipeline; delegates to
    get_transforms() so the two definitions cannot drift apart.
    """
    return get_transforms()
187 changes: 159 additions & 28 deletions DeepLense_Diffusion_Rishi/utils/test.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,164 @@
import torch
import numpy as np
"""
Usage:
python test.py --index 100 --output_dir plots_real --filename lens_100.jpg
"""
import argparse
import os

import numpy as np
import torch
import torchvision.transforms as Transforms
import matplotlib.pyplot as plt
from typing import List, Optional, Tuple

from dataset.transforms import get_transforms

# Absolute directory of this script, so data paths resolve regardless of
# the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))

# Default candidate data directories (relative to the script location);
# the first one that exists is used.
DEFAULT_DATA_PATHS = [
    os.path.join(SCRIPT_DIR, "../Data/cdm_regress_multi_param_model_ii/cdm_regress_multi_param/"),
    os.path.join(SCRIPT_DIR, "../Data/npy_lenses-20240731T044737Z-001/npy_lenses/"),
    os.path.join(SCRIPT_DIR, "../Data/real_lenses_dataset/lenses"),
]

# The former local get_transforms definition was removed; the shared
# implementation is imported from dataset.transforms above.

def find_valid_data_dir(paths: List[str]) -> Optional[str]:
    """Return the first entry in ``paths`` that is an existing directory, else None."""
    # os.path.isdir implies existence, so a single check suffices.
    return next((p for p in paths if os.path.isdir(p)), None)


def load_file_list(data_dir: str) -> List[str]:
    """Returns a sorted list of .npy files in the directory (empty on error)."""
    try:
        entries = os.listdir(data_dir)
    except OSError as e:
        print(f"Error accessing directory {data_dir}: {e}")
        return []
    return sorted(f for f in entries if f.endswith(".npy"))


def load_data(file_path: str) -> Optional[np.ndarray]:
    """Loads numpy data from a file, returning None on any failure."""
    try:
        return np.load(file_path)
    except Exception as e:
        print(f"Error loading data from {file_path}: {e}")
        return None


def normalize_data(data: np.ndarray) -> np.ndarray:
    """Rescales data into the [0, 1] range; constant arrays pass through unchanged."""
    lo = np.min(data)
    span = np.max(data) - lo
    if span > 0:
        return (data - lo) / span
    print("Warning: Data is constant. Skipping normalization.")
    return data


def process_data(data: np.ndarray, transforms: Transforms.Compose) -> torch.Tensor:
    """Converts the array to a torch tensor and applies the transform pipeline."""
    tensor = torch.from_numpy(data)
    # Transforms may assume floating-point input; promote non-float dtypes.
    if tensor.dtype not in (torch.float32, torch.float64):
        tensor = tensor.float()
    return transforms(tensor)


def save_plot(data_torch: torch.Tensor, output_dir: str, filename: str) -> bool:
    """Saves a visualization of the data; returns True on success, False on failure."""
    try:
        # matplotlib expects channels-last, so permute (C, H, W) -> (H, W, C).
        if data_torch.ndim == 3:
            img = data_torch.permute(1, 2, 0).to("cpu").numpy()
        else:
            img = data_torch.to("cpu").numpy()

        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, filename)

        # A fresh figure per call avoids leaking pyplot state between plots.
        plt.figure()
        plt.imshow(img)
        plt.axis("off")  # Optional: remove axes for clean image
        plt.savefig(save_path, bbox_inches="tight")
        plt.close()

        print(f"Saved plot to {save_path}")
        return True
    except Exception as e:
        print(f"Error saving plot: {e}")
        return False


def main():
    """CLI entry point: locate a dataset, load one sample, transform and plot it."""
    parser = argparse.ArgumentParser(description="Test script for DeepLense Diffusion")
    parser.add_argument(
        "--data_dirs",
        nargs="+",
        default=DEFAULT_DATA_PATHS,
        help="List of dataset directories",
    )
    parser.add_argument("--index", type=int, default=50, help="Index of file to process")
    parser.add_argument("--output_dir", type=str, default="plots", help="Output directory")
    parser.add_argument("--filename", type=str, default="ddpm_ssl_actual.jpg", help="Output filename")
    args = parser.parse_args()

    # Pick the first candidate directory that actually exists.
    data_dir = find_valid_data_dir(args.data_dirs)
    if data_dir is None:
        print(f"Error: No valid data directory found in {args.data_dirs}")
        return
    print(f"Using data directory: {data_dir}")

    files = load_file_list(data_dir)
    if not files:
        print("No .npy files found.")
        return

    if not (0 <= args.index < len(files)):
        print(f"Error: Index {args.index} out of bounds ({len(files)} files).")
        return

    full_path = os.path.join(data_dir, files[args.index])
    print(f"Processing: {full_path}")

    data = load_data(full_path)
    if data is None:
        return

    print(f"Original Shape: {data.shape}")
    print(f"Range: [{np.min(data)}, {np.max(data)}]")

    data = normalize_data(data)

    try:
        data_torch = process_data(data, get_transforms())
        print(f"After transforms: {data_torch.shape}, "
              f"range: [{data_torch.min().item():.4f}, {data_torch.max().item():.4f}]")
    except Exception as e:
        print(f"Transformation failed: {e}")
        return

    save_plot(data_torch, args.output_dir, args.filename)


# NOTE(review): the old module-level prototype (hard-coded root_dir,
# ad-hoc min-max normalization, CenterCrop and plotting executed at
# import time) was removed: it duplicated main() and crashed on import
# whenever the hard-coded data path was absent.
if __name__ == "__main__":
    main()
Loading