Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/deploy-documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install .[docs]
python -m pip install --group docs
- name: Deploy docs
run: mkdocs gh-deploy --force
153 changes: 76 additions & 77 deletions climanet/dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np
from .utils import add_month_day_dims
import xarray as xr
import torch
from torch.utils.data import Dataset
Expand All @@ -17,38 +19,54 @@ def __init__(
patch_size: Tuple[int, int] = (16, 16),
overlap: int = 0,
):
self.daily_da = daily_da
self.monthly_da = monthly_da
self.land_mask = land_mask
self.time_dim = time_dim
self.spatial_dims = spatial_dims
self.patch_size = patch_size
self.overlap = overlap

# Group daily data
# Create "YYYY-MM" string labels
daily_labels = self.daily_da[time_dim].dt.strftime("%Y-%m")
monthly_labels = self.monthly_da[time_dim].dt.strftime("%Y-%m")
# Check that the input data has the expected dimensions
if time_dim not in daily_da.dims or time_dim not in monthly_da.dims:
raise ValueError(f"Time dimension '{time_dim}' not found in input data")
for dim in spatial_dims:
if dim not in daily_da.dims or dim not in monthly_da.dims:
raise ValueError(f"Spatial dimension '{dim}' not found in input data")

# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31)
daily_mt, monthly_m, padded_days_mask = add_month_day_dims(
daily_da, monthly_da, time_dim=time_dim
)

# Convert to numpy once — all __getitem__ calls use these
self.daily_np = daily_mt.to_numpy().copy() # (M, T=31, H, W) float
self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float
self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool

if land_mask is not None:
lm = land_mask.to_numpy().copy()
if lm.ndim == 3:
lm = lm.squeeze(0) # (1, H, W) → (H, W)
self.land_mask_np = lm
else:
self.land_mask_np = None

# Precompute the NaN mask before filling NaNs
# daily_nan_mask: True wherever a daily value is NaN (land is excluded later, in __getitem__)
self.daily_nan_mask = np.isnan(self.daily_np) # (M, T=31, H, W)

# Group daily indices by month label
daily_groups = daily_labels.groupby(daily_labels).groups
# Fill NaNs with 0 in-place
np.nan_to_num(self.daily_np, copy=False, nan=0.0)

self.month_to_days = {}
for month_idx, period in enumerate(monthly_labels.values):
self.month_to_days[month_idx] = daily_groups.get(period, [])
if len(self.month_to_days[month_idx]) == 0:
raise ValueError(f"No daily data found for month index {month_idx}")
# Precompute padded_days_mask as a tensor (same for all patches)
self.padded_days_tensor = torch.from_numpy(self.padded_mask_np).bool()

# Precompute lazy index mapping for patches
dim_y, dim_x = self.spatial_dims
self.stride = self.patch_size[0] - self.overlap
self.n_i = (
self.daily_da.sizes[dim_y] - self.patch_size[0]
) // self.stride + 1 # number of horizontal patches
self.n_j = (
self.daily_da.sizes[dim_x] - self.patch_size[1]
) // self.stride + 1 # number of vertical patches
self.total_len = len(self.monthly_da[time_dim]) * self.n_i * self.n_j
H, W = self.daily_np.shape[2], self.daily_np.shape[3]
self.n_i = (H - self.patch_size[0]) // self.stride + 1
self.n_j = (W - self.patch_size[1]) // self.stride + 1

# Total length is only spatial patches (all months included in each sample)
self.total_len = self.n_i * self.n_j

# Dataset length: reports the sample count stored in self.total_len during construction.
def __len__(self):
return self.total_len
Expand All @@ -58,63 +76,44 @@ def __getitem__(self, idx):
if idx < 0 or idx >= self.total_len:
raise IndexError("Index out of range")

dim_y, dim_x = self.spatial_dims
per_t = self.n_i * self.n_j
t, rem = divmod(idx, per_t)
i_idx, j_idx = divmod(rem, self.n_j)
i_idx, j_idx = divmod(idx, self.n_j)
i = i_idx * self.stride
j = j_idx * self.stride

# Extract spatial patch
y_slice = slice(i, i + self.patch_size[0])
x_slice = slice(j, j + self.patch_size[1])

# Get daily data (all days in month)
# Assuming monthly timestamp corresponds to days in that month
daily_patch = self.daily_da.isel(
{
self.time_dim: self.month_to_days[t],
dim_y: y_slice,
dim_x: x_slice,
}
).to_numpy() # shape: (T, H, W)

# Add channel dim → (C=1, T, H, W)
daily_patch = torch.from_numpy(daily_patch).float().unsqueeze(0)

# Get monthly target
monthly_patch = self.monthly_da.isel(
{
self.time_dim: t,
dim_y: y_slice,
dim_x: x_slice,
}
).to_numpy()
monthly_patch = torch.from_numpy(monthly_patch).float()

if self.land_mask is not None:
land_mask_patch = self.land_mask.isel(
{dim_y: y_slice, dim_x: x_slice}
).to_numpy()
land_mask_patch = torch.from_numpy(land_mask_patch).bool() # (H,W)
ph, pw = self.patch_size

# Extract spatial patch via numpy slicing — faster than xarray indexing
daily_patch = self.daily_np[:, :, i : i + ph, j : j + pw] # (M, T, H, W)
monthly_patch = self.monthly_np[:, i : i + ph, j : j + pw] # (M, H, W)
daily_nan_mask = self.daily_nan_mask[
:, :, i : i + ph, j : j + pw
] # (M, T, H, W)

if self.land_mask_np is not None:
land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W)
land_tensor = torch.from_numpy(land_patch.copy()).bool()
else:
# No land mask → all ocean (False)
land_mask_patch = torch.zeros(
self.patch_size[0], self.patch_size[1], dtype=torch.bool
)

daily_mask_patch = torch.isnan(daily_patch) & (~land_mask_patch)

# Replace NaNs in daily data with zeros (after creating mask)
daily_patch = torch.nan_to_num(daily_patch, nan=0.0)
land_tensor = torch.zeros(ph, pw, dtype=torch.bool)

# Convert to tensors (from_numpy shares memory with the NumPy array; .float() copies only if the dtype is not already float32)
# (1, M, T, H, W)
daily_tensor = torch.from_numpy(daily_patch).float().unsqueeze(0)
# (M, H, W)
monthly_tensor = torch.from_numpy(monthly_patch).float()
# (1, M, T, H, W)
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)

# daily_mask: NaN locations that are NOT land
# Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W)
daily_mask_tensor = daily_nan_mask & (
~land_tensor.unsqueeze(0).unsqueeze(0).unsqueeze(0)
)

# Convert to tensors
sample = {
"daily_patch": daily_patch, # (C=1, T, H, W)
"monthly_patch": monthly_patch, # (H, W)
"daily_mask_patch": daily_mask_patch, # (C=1, T, H, W)
"land_mask_patch": land_mask_patch.squeeze(), # (H,W)
"coords": (t, i, j),
return {
"daily_patch": daily_tensor, # (C=1, M, T=31, H, W)
"monthly_patch": monthly_tensor, # (M, H, W)
"daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, H, W)
"land_mask_patch": land_tensor, # (H,W) True=Land
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
"coords": (i, j),
}

return sample
Loading