import functools
import random
import time
import typing
from dataclasses import dataclass
from pathlib import Path
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import dotenv
import icechunk
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import odc.geo.xr # noqa
import pydantic
import pydantic_settings
import xarray as xr
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from pydantic_extra_types.semantic_version import SemanticVersion
from shapely.geometry import Polygon, box
from upath import UPath
from ocr import catalog
from ocr.console import console
from ocr.types import Environment
[docs]
class CoiledConfig(pydantic_settings.BaseSettings):
tag: dict[str, str] = pydantic.Field({'Project': 'OCR'})
forward_aws_credentials: bool = pydantic.Field(
False, description='Whether to forward AWS credentials to the worker nodes'
)
spot_policy: typing.Literal['on-demand', 'spot', 'spot_with_fallback'] = pydantic.Field(
'spot_with_fallback',
description='Spot instance policy for Coiled cluster. See Coiled docs for details.',
)
region: str = pydantic.Field('us-west-2', description='AWS region to use for the worker nodes')
ntasks: pydantic.PositiveInt = pydantic.Field(
1, description='Number of tasks to run in parallel'
)
vm_type: str = pydantic.Field('m8g.2xlarge', description='VM type to use for the worker nodes')
scheduler_vm_type: str = pydantic.Field(
'm8g.2xlarge', description='VM type to use for the scheduler node'
)
model_config = {
'env_prefix': 'ocr_coiled_',
'case_sensitive': False,
}
[docs]
class ChunkingConfig(pydantic_settings.BaseSettings):
chunks: dict | None = pydantic.Field(None, description='Chunk sizes for longitude and latitude')
debug: bool = pydantic.Field(False, description='Enable debugging mode')
model_config = {
'env_prefix': 'ocr_chunking_',
'case_sensitive': False,
}
[docs]
def model_post_init(self, __context):
self.chunks = self.chunks or dict(zip(self.ds['CRPS'].dims, self.ds['CRPS'].data.chunksize))
def __repr__(self):
return self.extent.__repr__()
[docs]
@functools.cached_property
def extent(self):
from shapely.geometry import box
return box(
minx=float(self.ds.longitude.min()),
maxx=float(self.ds.longitude.max()),
miny=float(self.ds.latitude.min()),
maxy=float(self.ds.latitude.max()),
)
[docs]
@functools.cached_property
def extent_as_tuple(self):
bounds = self.extent.bounds
return (bounds[0], bounds[2], bounds[1], bounds[3])
[docs]
@functools.cached_property
def extent_as_tuple_5070(self):
"""Get extent in EPSG:5070 projection as tuple (xmin, xmax, ymin, ymax)"""
from pyproj import Transformer
bounds = self.extent.bounds
transformer = Transformer.from_crs('EPSG:4326', 'EPSG:5070', always_xy=True)
# Transform corner points
xmin_5070, ymin_5070 = transformer.transform(bounds[0], bounds[1])
xmax_5070, ymax_5070 = transformer.transform(bounds[2], bounds[3])
return (xmin_5070, xmax_5070, ymin_5070, ymax_5070)
[docs]
@functools.cached_property
def ds(self):
dataset = (
catalog.get_dataset('scott-et-al-2024-30m-4326').to_xarray().astype('float32')[['CRPS']]
)
dataset = dataset.odc.assign_crs('epsg:4326')
return dataset
[docs]
@functools.cached_property
def chunk_info(self) -> dict:
"""Get information about the dataset's chunks"""
y_chunks, x_chunks = self.ds['CRPS'].data.chunks
y_starts = np.cumsum([0] + list(y_chunks[:-1]))
x_starts = np.cumsum([0] + list(x_chunks[:-1]))
return {
'y_chunks': y_chunks,
'x_chunks': x_chunks,
'y_starts': y_starts,
'x_starts': x_starts,
}
[docs]
@functools.cached_property
def valid_region_ids(self) -> list:
"""Generate valid region IDs by checking which regions contain non-null data.
Returns
-------
list
List of valid region IDs (e.g., 'y1_x3', 'y2_x4', etc.)
"""
import json
# Use cache file in the package directory
cache_file = Path(__file__).parent / 'data' / 'valid_region_ids.json'
# Try to load from cache first
if cache_file.exists():
try:
with open(cache_file) as f:
cached_data = json.load(f)
if self.debug:
console.log(
f'Loaded {len(cached_data)} valid region IDs from cache: {cache_file}'
)
return cached_data
except Exception as e:
if self.debug:
console.log(f'Failed to load cache file: {e}. Regenerating...')
# If cache doesn't exist or failed to load, compute valid region IDs
chunk_info = self.chunk_info
y_starts = chunk_info['y_starts']
x_starts = chunk_info['x_starts']
if self.debug:
console.log('Computing valid region IDs (this may take a while)...')
valid_region_ids = []
for iy, _ in enumerate(y_starts):
for ix, _ in enumerate(x_starts):
region_id = f'y{iy}_x{ix}'
y_slice, x_slice = self.region_id_to_latlon_slices(region_id=region_id)
subds = self.ds.sel(latitude=y_slice, longitude=x_slice)
all_null = bool(subds.CRPS.isnull().all().values)
if not all_null:
valid_region_ids.append(region_id)
# Save to cache for future use
cache_file.parent.mkdir(parents=True, exist_ok=True)
with open(cache_file, 'w') as f:
json.dump(valid_region_ids, f, indent=2)
if self.debug:
console.log(
f'Computed and cached {len(valid_region_ids)} valid region IDs to {cache_file}'
)
return valid_region_ids
[docs]
def index_to_coords(self, x_idx: int, y_idx: int) -> tuple[float, float]:
"""Convert array indices to EPSG:4326 coordinates
Parameters
----------
x_idx : int
Index along the x-dimension (longitude)
y_idx : int
Index along the y-dimension (latitude)
Returns
-------
x, y : tuple[float, float]
Corresponding EPSG:4326 coordinates (longitude, latitude)
"""
x, y = self.transform * (x_idx, y_idx)
return x, y
[docs]
def chunks_to_slices(self, chunks: dict) -> dict:
"""Create a dict of chunk_ids and slices from input chunk dict
Parameters
----------
chunks : dict
Dictionary with chunk sizes for 'longitude' and 'latitude'
Returns
-------
dict
Dictionary with chunk IDs as keys and corresponding slices as values
"""
return {key: self.chunk_id_to_slice(value) for key, value in chunks.items()}
[docs]
def region_id_chunk_lookup(self, region_id: str) -> tuple:
"""given a region_id, ex: 'y5_x14, returns the corresponding chunk (5, 14)
Parameters
----------
region_id : str
The region_id for chunk_id lookup.
Returns
-------
index : tuple[int, int]
The corresponding chunk (iy, ix) for the given region_id.
"""
return self.get_chunk_mapping()[region_id]
[docs]
def region_id_slice_lookup(self, region_id: str) -> tuple:
"""given a region_id, ex: 'y5_x14, returns the corresponding x,y slices. ex:
(slice(np.int64(30000), np.int64(36000), None),
slice(np.int64(85500), np.int64(90000), None))
Parameters
----------
region_id : str
The region_id for chunk_id lookup.
Returns
-------
indexer : tuple[slice]
The corresponding slices (y_slice, x_slice) for the given region_id.
"""
return self.chunk_id_to_slice(self.region_id_chunk_lookup(region_id))
[docs]
def chunk_id_to_slice(self, chunk_id: tuple) -> tuple:
"""
Convert a chunk ID (iy, ix) to corresponding array slices
Parameters
----------
chunk_id : tuple
The chunk identifier as a tuple (iy, ix) where:
- iy is the index along y-dimension
- ix is the index along x-dimension
Returns
-------
chunk_slices : tuple[slice]
A tuple of slices (y_slice, x_slice) to extract data for this chunk
"""
iy, ix = chunk_id
# Get chunk info
chunk_info = self.chunk_info
y_chunks = chunk_info['y_chunks']
x_chunks = chunk_info['x_chunks']
y_starts = chunk_info['y_starts']
x_starts = chunk_info['x_starts']
# Validate chunk indices
if iy < 0 or iy >= len(y_chunks) or ix < 0 or ix >= len(x_chunks):
raise ValueError(f'Invalid chunk ID: {chunk_id}. Out of bounds.')
# Get start positions for this chunk
y_start = y_starts[iy]
x_start = x_starts[ix]
# Get sizes for this chunk
y_size = y_chunks[iy]
x_size = x_chunks[ix]
# Create and return the slices
y_slice = slice(y_start, y_start + y_size)
x_slice = slice(x_start, x_start + x_size)
return (y_slice, x_slice)
[docs]
def region_id_to_latlon_slices(self, region_id: str) -> tuple:
"""
Get latitude and longitude slices from region_id
Returns (lat_slice, lon_slice) where lat_slice.start < lat_slice.stop
and lon_slice.start < lon_slice.stop (lower-left origin, lat ascending).
"""
chunk_id = self.region_id_chunk_lookup(region_id)
# Get array index slices for this chunk (y, x)
y_slice_idx, x_slice_idx = self.chunk_id_to_slice(chunk_id)
# Convert corners to coordinates
x_min, y_max = self.index_to_coords(x_slice_idx.start, y_slice_idx.start) # upper-left
x_max, y_min = self.index_to_coords(x_slice_idx.stop, y_slice_idx.stop) # lower-right
# Ensure slices are ascending (lower-left origin)
lat_start, lat_stop = sorted((y_min, y_max))
lon_start, lon_stop = sorted((x_min, x_max))
lat_slice = slice(lat_start, lat_stop)
lon_slice = slice(lon_start, lon_stop)
return (lat_slice, lon_slice)
[docs]
def get_chunk_mapping(self) -> dict[str, tuple[int, int]]:
"""Returns a dict of region_ids and their corresponding chunk_indexes.
Returns
-------
chunk_mapping : dict
Dictionary with region IDs as keys and corresponding chunk indexes (iy, ix) as values
"""
chunk_info = self.chunk_info
y_starts = chunk_info['y_starts']
x_starts = chunk_info['x_starts']
chunk_mapping = {}
for iy, y0 in enumerate(y_starts):
for ix, x0 in enumerate(x_starts):
chunk_mapping[f'y{iy}_x{ix}'] = (iy, ix)
return chunk_mapping
[docs]
def plot_all_chunks(self, color_by_size: bool = False) -> None:
"""
Plot all data chunks across the entire CONUS with their indices as labels
Parameters
----------
color_by_size : bool, default False
If True, color chunks based on their size (useful to identify irregularities)
"""
# Create figure
fig, ax = plt.subplots(figsize=(24, 16), subplot_kw={'projection': ccrs.PlateCarree()})
# Set extent to show CONUS
ax.set_extent(self.extent_as_tuple, crs=ccrs.PlateCarree())
# Get chunk information
chunk_info = self.chunk_info
y_chunks = chunk_info['y_chunks']
x_chunks = chunk_info['x_chunks']
y_starts = chunk_info['y_starts']
x_starts = chunk_info['x_starts']
# Track chunk sizes for coloring if needed
norm = None
cmap = None
if color_by_size:
sizes = [
y_chunks[iy] * x_chunks[ix]
for iy in range(len(y_chunks))
for ix in range(len(x_chunks))
]
min_size = min(sizes)
max_size = max(sizes)
norm = mcolors.Normalize(vmin=min_size, vmax=max_size)
cmap = cm.viridis
# Draw each chunk with label
for iy, y0 in enumerate(y_starts):
h = y_chunks[iy]
for ix, x0 in enumerate(x_starts):
w = x_chunks[ix]
# Get chunk boundaries in geographic coordinates
xx0, yy0 = self.index_to_coords(x0, y0)
xx1, yy1 = self.index_to_coords(x0 + w, y0 + h)
# Choose color based on size or use default cycle
if color_by_size and cmap is not None and norm is not None:
size = h * w
color = cmap(norm(size))
else:
# Use a simple coloring scheme based on indices
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
color = colors[(iy * len(x_starts) + ix) % len(colors)]
# Draw rectangle around the chunk
rect = Rectangle(
(xx0, yy1), # lower left (x, y)
xx1 - xx0, # width
yy0 - yy1, # height
transform=ccrs.PlateCarree(),
fill=True,
facecolor=color,
alpha=0.3,
edgecolor=color,
linewidth=1.5,
zorder=10,
)
ax.add_patch(rect)
center_x = (xx0 + xx1) / 2
center_y = (yy0 + yy1) / 2
region_id = f'y{iy}_x{ix}'
ax.text(
center_x,
center_y,
region_id,
transform=ccrs.PlateCarree(),
ha='center',
va='center',
fontsize=6,
fontweight='bold',
color='black',
bbox=dict(boxstyle='round,pad=0.2', facecolor='white', alpha=0.7),
zorder=20,
)
# Add geographic features
ax.coastlines(resolution='10m')
ax.add_feature(cfeature.BORDERS, linewidth=0.8)
ax.add_feature(cfeature.STATES, linewidth=0.5, edgecolor='gray')
# Add a colorbar if coloring by size
if color_by_size:
sm = cm.ScalarMappable(norm=norm, cmap=cmap)
cbar = plt.colorbar(sm, ax=ax, shrink=0.6, pad=0.01)
cbar.set_label('Chunk Size (pixels)')
# Set title
ax.set_title(
f'All Chunks ({len(y_chunks)}×{len(x_chunks)} = {len(y_chunks) * len(x_chunks)})'
)
plt.tight_layout()
plt.show()
[docs]
def bbox_from_wgs84(self, xmin: float, ymin: float, xmax: float, ymax: float):
"https://observablehq.com/@rdmurphy/u-s-state-bounding-boxes"
# Create and return bounding box in EPSG:4326 (WGS84)
# This matches the coordinate system of the data
bbox = box(xmin, ymin, xmax, ymax)
return bbox
[docs]
def get_chunks_for_bbox(self, bbox: Polygon | tuple) -> list[tuple[int, int]]:
"""
Find all chunks that intersect with the given bounding box
Parameters
----------
bbox : BoundingBox or tuple
Bounding box to check for intersection. If tuple, format is (minx, miny, maxx, maxy)
Returns
-------
list of tuples
List of (iy, ix) tuples identifying the intersecting chunks
"""
# Convert tuple to BoundingBox if needed
if isinstance(bbox, tuple):
if len(bbox) == 4:
bbox = box(minx=bbox[0], miny=bbox[1], maxx=bbox[2], maxy=bbox[3])
else:
raise ValueError('Bounding box tuple must have 4 elements (minx, miny, maxx, maxy)')
# Get chunk info
chunk_info = self.chunk_info
y_chunks = chunk_info['y_chunks']
x_chunks = chunk_info['x_chunks']
y_starts = chunk_info['y_starts']
x_starts = chunk_info['x_starts']
# Find intersecting chunks
intersecting_chunks = []
for iy, y0 in enumerate(y_starts):
h = y_chunks[iy]
for ix, x0 in enumerate(x_starts):
w = x_chunks[ix]
# Get chunk boundaries in geographic coordinates
xx0, yy0 = self.index_to_coords(x0, y0)
xx1, yy1 = self.index_to_coords(x0 + w, y0 + h)
# Create a box for this chunk (note Y axis flip)
chunk_box = box(xx0, yy1, xx1, yy0)
# Check for intersection
if bbox.intersects(chunk_box):
intersecting_chunks.append((iy, ix))
return intersecting_chunks
[docs]
def visualize_chunks_on_conus(
self,
chunks: list[tuple[int, int]] | None = None,
color_by_size: bool = False,
highlight_chunks: list[tuple[int, int]] | None = None,
include_all_chunks: bool = False,
) -> None:
"""
Visualize specified chunks on CONUS map
Parameters
----------
chunks : list of tuples, optional
List of (iy, ix) tuples specifying chunks to visualize
If None, will show all chunks
color_by_size : bool, default False
If True, color chunks based on their size
highlight_chunks : list of tuples, optional
List of (iy, ix) tuples specifying chunks to highlight
include_all_chunks : bool, default False
If True, show all chunks in background with low opacity
"""
# Create figure
fig, ax = plt.subplots(figsize=(16, 12), subplot_kw={'projection': ccrs.PlateCarree()})
# Set extent - either full CONUS or custom extent
ax.set_extent(self.extent_as_tuple, crs=ccrs.PlateCarree())
# Get chunk information
chunk_info = self.chunk_info
y_chunks = chunk_info['y_chunks']
x_chunks = chunk_info['x_chunks']
y_starts = chunk_info['y_starts']
x_starts = chunk_info['x_starts']
# Set up colors
norm = None
cmap = None
if color_by_size:
sizes = [
y_chunks[iy] * x_chunks[ix]
for iy in range(len(y_chunks))
for ix in range(len(x_chunks))
]
min_size = min(sizes)
max_size = max(sizes)
norm = mcolors.Normalize(vmin=min_size, vmax=max_size)
cmap = cm.viridis
# Default to all chunks if none specified
if chunks is None:
chunks = [(iy, ix) for iy in range(len(y_chunks)) for ix in range(len(x_chunks))]
# Draw background chunks if requested
if include_all_chunks and chunks != [
(iy, ix) for iy in range(len(y_chunks)) for ix in range(len(x_chunks))
]:
for iy, y0 in enumerate(y_starts):
h = y_chunks[iy]
for ix, x0 in enumerate(x_starts):
# Skip chunks that are in the main visualization
if (iy, ix) in chunks:
continue
w = x_chunks[ix]
xx0, yy0 = self.index_to_coords(x0, y0)
xx1, yy1 = self.index_to_coords(x0 + w, y0 + h)
rect = Rectangle(
(xx0, yy1),
xx1 - xx0,
yy0 - yy1,
transform=ccrs.PlateCarree(),
fill=True,
facecolor='lightgray',
alpha=0.2,
edgecolor='gray',
linewidth=0.5,
zorder=5,
)
ax.add_patch(rect)
# Draw the specified chunks with proper styling
for iy, ix in chunks:
y0 = y_starts[iy]
h = y_chunks[iy]
x0 = x_starts[ix]
w = x_chunks[ix]
# Get chunk boundaries in geographic coordinates
xx0, yy0 = self.index_to_coords(x0, y0)
xx1, yy1 = self.index_to_coords(x0 + w, y0 + h)
# Determine styling
is_highlighted = highlight_chunks is not None and (iy, ix) in highlight_chunks
# Choose color based on size or use default cycle
if is_highlighted:
color = 'red'
fill_alpha = 0.4
linewidth = 2.0
zorder = 15
elif color_by_size and cmap is not None and norm is not None:
size = h * w
color = cmap(norm(size))
# edge_alpha = 0.8
fill_alpha = 0.3
linewidth = 1.5
zorder = 10
else:
# Use a simple coloring scheme based on indices
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b']
color = colors[(iy * len(x_starts) + ix) % len(colors)]
fill_alpha = 0.3
linewidth = 1.5
zorder = 10
# Draw rectangle around the chunk
rect = Rectangle(
(xx0, yy1), # lower left (x, y)
xx1 - xx0, # width
yy0 - yy1, # height
transform=ccrs.PlateCarree(),
fill=True,
facecolor=color,
alpha=fill_alpha,
edgecolor=color,
linewidth=linewidth,
zorder=zorder,
)
ax.add_patch(rect)
# Add geographic features
ax.coastlines(resolution='10m')
ax.add_feature(cfeature.BORDERS, linewidth=0.8)
ax.add_feature(cfeature.STATES, linewidth=0.5, edgecolor='gray')
# Add a colorbar if coloring by size
if color_by_size:
sm = cm.ScalarMappable(norm=norm, cmap=cmap)
cbar = plt.colorbar(sm, ax=ax, shrink=0.6, pad=0.01)
cbar.set_label('Chunk Size (pixels)')
# Set title
if len(chunks) == len(y_chunks) * len(x_chunks):
ax.set_title(f'All Chunks ({len(y_chunks)}×{len(x_chunks)} = {len(chunks)})')
else:
ax.set_title(
f'Selected Chunks ({len(chunks)} of {len(y_chunks)}×{len(x_chunks)} total)'
)
# Add a legend
legend_elements = [Line2D([0], [0], color='blue', lw=2, label='Selected Chunks')]
if highlight_chunks:
legend_elements.append(Line2D([0], [0], color='red', lw=2, label='Highlighted Chunks'))
if include_all_chunks:
legend_elements.append(Line2D([0], [0], color='gray', lw=1, label='Other Chunks'))
ax.legend(handles=legend_elements, loc='lower right')
plt.tight_layout()
plt.show()
[docs]
class PyramidConfig(pydantic_settings.BaseSettings):
"""Configuration for visualization pyramid / multiscales"""
environment: Environment = pydantic.Field(
default=Environment.QA, description='Environment for pyramid'
)
version: SemanticVersion | None = pydantic.Field(
default=None, description='Version of the pyramid processing pipeline'
)
storage_root: str = pydantic.Field(
..., description='Root storage path for pyramid. can be a bucket name or local path'
)
output_prefix: str | None = pydantic.Field(
None, description='Sub-path within the storage root for pipeline output products'
)
debug: bool = pydantic.Field(default=False, description='Enable debugging mode')
model_config = {'env_prefix': 'ocr_vector_', 'case_sensitive': False}
[docs]
def model_post_init(self, __context):
"""Post-initialization to set up prefixes and URIs based on environment."""
common_part = f'fire-risk/pyramid/{self.environment.value}'
if self.output_prefix is None:
if self.version:
self.output_prefix = f'output/{common_part}/v{self.version}/pyramid.zarr'
else:
self.output_prefix = f'output/{common_part}/pyramid.zarr'
if self.output_prefix and self.version:
if f'v{self.version}' not in self.output_prefix:
# insert version right before the last part of the prefix
parts = self.output_prefix.rsplit('/', 1)
if len(parts) == 2:
self.output_prefix = (
f'{parts[0]}/{self.environment.value}/v{self.version}/{parts[1]}'
)
else:
self.output_prefix = (
f'{self.environment.value}/v{self.version}/{self.output_prefix}'
)
@property
def pyramid_uri(self) -> UPath:
path = UPath(f'{self.storage_root}/{self.output_prefix}')
return path
[docs]
def wipe(self):
"""Wipe the pyramid data storage."""
if self.debug:
console.log(f'Wiping pyramid data:\n- {self.pyramid_uri.parent}\n')
self.upath_delete(self.pyramid_uri.parent)
[docs]
class VectorConfig(pydantic_settings.BaseSettings):
"""Configuration for vector data processing."""
environment: Environment = pydantic.Field(
default=Environment.QA, description='Environment for vector processing'
)
version: SemanticVersion | None = pydantic.Field(
default=None, description='Version of the vector processing pipeline'
)
storage_root: str = pydantic.Field(
..., description='Root storage path for vector data, can be a bucket name or local path'
)
prefix: str | None = pydantic.Field(None, description='Sub-path within the storage root')
output_prefix: str | None = pydantic.Field(
None, description='Sub-path within the storage root for pipeline output products'
)
debug: bool = pydantic.Field(default=False, description='Enable debugging mode')
model_config = {'env_prefix': 'ocr_vector_', 'case_sensitive': False}
metadata: dict[str, str] | None = pydantic.Field(
None, description='metadata to add to datasets'
)
@property
def metadata_dict(self) -> dict[str, str]:
"""Get metadata dict with ODBL license for vector data."""
return {
'version': str(self.version) if self.version else 'unversioned',
'provider': 'CarbonPlan',
'terms_of_access': 'https://docs.carbonplan.org/ocr/en/latest/terms-of-data-access.html',
'data_sources': 'https://docs.carbonplan.org/ocr/en/latest/reference/data-sources.html',
'license_name': 'ODBL',
'license_url': 'https://opendatacommons.org/licenses/odbl/',
'notice': 'Contains information from the Overture Maps Foundation database, which is made available here under the Open Database License (ODbL), a copy of which is available at https://opendatacommons.org/licenses/odbl/1-0/.',
}
[docs]
def model_post_init(self, __context):
"""Post-initialization to set up prefixes and URIs based on environment."""
common_part = f'fire-risk/vector/{self.environment.value}'
if self.prefix is None:
if self.version:
self.prefix = f'intermediate/{common_part}/v{self.version}'
else:
self.prefix = f'intermediate/{common_part}'
if self.output_prefix is None:
if self.version:
self.output_prefix = f'output/{common_part}/v{self.version}'
else:
self.output_prefix = f'output/{common_part}'
if self.prefix and self.version:
if f'v{self.version}' not in self.prefix:
# insert version right before the last part of the prefix
parts = self.prefix.rsplit('/', 1)
if len(parts) == 2:
self.prefix = f'{parts[0]}/{self.environment.value}/v{self.version}/{parts[1]}'
else:
self.prefix = f'{self.environment.value}/v{self.version}/{self.prefix}'
if self.output_prefix and self.version:
if f'v{self.version}' not in self.output_prefix:
# insert version right before the last part of the prefix
parts = self.output_prefix.rsplit('/', 1)
if len(parts) == 2:
self.output_prefix = (
f'{parts[0]}/{self.environment.value}/v{self.version}/{parts[1]}'
)
else:
self.output_prefix = (
f'{self.environment.value}/v{self.version}/{self.output_prefix}'
)
[docs]
def wipe(self):
"""Wipe the vector data storage."""
if self.debug:
console.log(
f'Wiping vector data storage at these locations:\n'
f'- {self.building_geoparquet_uri.parent}\n'
f'- {self.buildings_pmtiles_uri.parent}\n'
f'- {self.building_centroids_pmtiles_uri.parent}\n'
f'- {self.region_geoparquet_uri}\n'
f'- {self.aggregated_region_analysis_uri}\n'
f'- {self.tracts_summary_stats_uri.parent}\n'
)
self.upath_delete(self.building_geoparquet_uri.parent)
self.upath_delete(self.buildings_pmtiles_uri.parent)
self.upath_delete(self.building_centroids_pmtiles_uri.parent)
self.upath_delete(self.region_geoparquet_uri)
self.upath_delete(self.aggregated_region_analysis_uri)
self.upath_delete(self.tracts_summary_stats_uri.parent)
# ----------------------------
# output pmtiles
# ----------------------------
[docs]
@functools.cached_property
def pmtiles_prefix(self) -> str:
return f'{self.output_prefix}/pmtiles'
[docs]
@functools.cached_property
def buildings_pmtiles_uri(self) -> UPath:
path = UPath(f'{self.storage_root}/{self.pmtiles_prefix}/buildings.pmtiles')
path.parent.mkdir(parents=True, exist_ok=True)
return path
[docs]
@functools.cached_property
def building_centroids_pmtiles_uri(self) -> UPath:
path = UPath(f'{self.storage_root}/{self.pmtiles_prefix}/building_centroids.pmtiles')
path.parent.mkdir(parents=True, exist_ok=True)
return path
[docs]
@functools.cached_property
def region_pmtiles_uri(self) -> UPath:
path = UPath(f'{self.storage_root}/{self.pmtiles_prefix}/regions.pmtiles')
path.parent.mkdir(parents=True, exist_ok=True)
return path
# ----------------------------
# geoparquet
# ----------------------------
[docs]
@functools.cached_property
def region_geoparquet_prefix(self) -> str:
return f'{self.prefix}/geoparquet-regions'
[docs]
@functools.cached_property
def geoparquet_prefix(self) -> str:
return f'{self.output_prefix}/geoparquet'
[docs]
@functools.cached_property
def region_geoparquet_uri(self) -> UPath:
path = UPath(f'{self.storage_root}/{self.region_geoparquet_prefix}')
path.mkdir(parents=True, exist_ok=True)
return path
[docs]
@functools.cached_property
def aggregated_region_analysis_prefix(self) -> str:
return f'{self.output_prefix}/region-analysis'
[docs]
@functools.cached_property
def aggregated_region_analysis_uri(self) -> UPath:
path = UPath(f'{self.storage_root}/{self.aggregated_region_analysis_prefix}')
path.mkdir(parents=True, exist_ok=True)
return path
@property
def building_geoparquet_uri(self) -> UPath:
path = UPath(f'{self.storage_root}/{self.geoparquet_prefix}/buildings.parquet')
path.mkdir(parents=True, exist_ok=True)
return path
[docs]
@functools.cached_property
def region_summary_stats_prefix(self) -> UPath:
path = UPath(f'{self.storage_root}/{self.prefix}/region-summary-stats/')
path.mkdir(parents=True, exist_ok=True)
return path
[docs]
@functools.cached_property
def block_summary_stats_uri(self) -> UPath:
"""URI for the block summary statistics file."""
geo_table_name = 'blocks'
return self.region_summary_stats_prefix / f'{geo_table_name}_summary_stats.parquet'
[docs]
@functools.cached_property
def tracts_summary_stats_uri(self) -> UPath:
"""URI for the tracts summary statistics file."""
geo_table_name = 'tracts'
return self.region_summary_stats_prefix / f'{geo_table_name}_summary_stats.parquet'
[docs]
@functools.cached_property
def counties_summary_stats_uri(self) -> UPath:
"""URI for the counties summary statistics file."""
geo_table_name = 'counties'
return self.region_summary_stats_prefix / f'{geo_table_name}_summary_stats.parquet'
[docs]
@functools.cached_property
def states_summary_stats_uri(self) -> UPath:
"""URI for the states summary statistics file."""
geo_table_name = 'states'
return self.region_summary_stats_prefix / f'{geo_table_name}_summary_stats.parquet'
[docs]
@functools.cached_property
def nation_summary_stats_uri(self) -> UPath:
"""URI for the nation summary statistics file."""
geo_table_name = 'nation'
return self.region_summary_stats_prefix / f'{geo_table_name}_summary_stats.parquet'
[docs]
def upath_delete(self, path: UPath) -> None:
"""Use UPath to handle deletion in a cloud-agnostic way"""
if not path.exists():
if self.debug:
console.log('No files found to delete.')
return
protocol = path.protocol
# For S3, use fsspec's rm method which supports recursive deletion
if protocol == 's3':
if self.debug:
console.log(f'Deleting S3 path: {path}')
# Use the underlying filesystem's rm method for efficient batch deletion
fs = path.fs
fs.rm(path.path, recursive=True)
else:
# For local filesystems, use standard recursive deletion
if self.debug:
console.log(f'Deleting local path: {path}')
import shutil
if path.is_dir():
shutil.rmtree(path)
else:
path.unlink()
[docs]
def pretty_paths(self) -> None:
"""Pretty print key VectorConfig paths and URIs.
This method intentionally touches cached properties that create
directories (e.g., via mkdir) so you can verify real locations.
"""
from rich.panel import Panel
from rich.table import Table
def nv(name: str, value: str | None):
return name, (str(value) if value not in (None, '') else '—')
rows: list[tuple[str, str]] = []
# high-level
rows.append(nv('Environment', getattr(self.environment, 'value', str(self.environment))))
rows.append(nv('Version', (str(self.version) if self.version else '—')))
rows.append(nv('Storage root', self.storage_root))
# prefixes (touch real properties)
rows.append(nv('Intermediate prefix', self.prefix))
rows.append(nv('Output prefix', self.output_prefix))
rows.append(nv('Geoparquet prefix', self.geoparquet_prefix))
rows.append(nv('Region Geoparquet prefix', self.region_geoparquet_prefix))
rows.append(nv('PMTiles prefix', self.pmtiles_prefix))
# derived URIs (touch cached properties that mkdir/prepare parents)
rows.extend(
[
nv('Region Geoparquet URI', str(self.region_geoparquet_uri)),
nv('Buildings Geoparquet', str(self.building_geoparquet_uri)),
nv('Region summary stats dir', str(self.region_summary_stats_prefix)),
nv('Block summary stats', str(self.block_summary_stats_uri)),
nv('Tracts summary stats', str(self.tracts_summary_stats_uri)),
nv('Counties summary stats', str(self.counties_summary_stats_uri)),
nv('States summary stats', str(self.states_summary_stats_uri)),
nv('Nation summary stats', str(self.nation_summary_stats_uri)),
nv('Buildings PMTiles', str(self.buildings_pmtiles_uri)),
nv('Building centroids PMTiles', str(self.building_centroids_pmtiles_uri)),
nv('Region PMTiles', str(self.region_pmtiles_uri)),
]
)
table = Table(title=None, show_header=True, header_style='bold magenta')
table.add_column('Vector setting', style='bold cyan', no_wrap=True)
table.add_column('Value', style='green')
for k, v in rows:
table.add_row(k, v)
console.print(Panel(table, title='VectorConfig paths', title_align='left'))
[docs]
class IcechunkConfig(pydantic_settings.BaseSettings):
"""Configuration for icechunk processing."""
environment: Environment = pydantic.Field(
default=Environment.QA, description='Environment for icechunk processing'
)
version: SemanticVersion | None = pydantic.Field(
None, description='Version of the icechunk processing pipeline'
)
storage_root: str = pydantic.Field(
..., description='Root storage path for icechunk data, can be a bucket name or local path'
)
prefix: str | None = pydantic.Field(None, description='Sub-path within the storage root')
debug: bool = pydantic.Field(default=False, description='Enable debugging mode')
metadata: dict[str, str] | None = pydantic.Field(
None, description='metadata to add to datasets'
)
@property
def metadata_dict(self) -> dict[str, str]:
"""Get metadata dict with CC-BY-4.0 license for icechunk data."""
return {
'version': str(self.version) if self.version else 'unversioned',
'provider': 'CarbonPlan',
'terms_of_access': 'https://docs.carbonplan.org/ocr/en/latest/terms-of-data-access.html',
'data_sources': 'https://docs.carbonplan.org/ocr/en/latest/reference/data-sources.html',
'license_name': 'CC-BY-4.0',
'license_url': 'https://creativecommons.org/licenses/by/4.0/',
}
[docs]
def model_post_init(self, __context):
"""Post-initialization to set up prefixes and URIs based on environment."""
common_part = f'fire-risk/tensor/{self.environment.value}'
if self.prefix is None:
name = 'ocr.icechunk' if self.version is None else f'v{self.version}/ocr.icechunk'
prefix = f'output/{common_part}/{name}'
self.prefix = prefix
if self.prefix and self.version:
if f'v{self.version}' not in self.prefix:
# insert version right before the last part of the prefix
parts = self.prefix.rsplit('/', 1)
if len(parts) == 2:
self.prefix = f'{parts[0]}/{self.environment.value}/v{self.version}/{parts[1]}'
else:
self.prefix = f'{self.environment.value}/v{self.version}/{self.prefix}'
[docs]
def wipe(self):
"""Wipe the icechunk repository."""
self.delete()
self.init_repo()
[docs]
@functools.cached_property
def uri(self) -> UPath:
"""Return the URI for the icechunk repository."""
if self.prefix is None:
raise ValueError('Prefix must be set before initializing the icechunk repo.')
return UPath(f'{self.storage_root}/{self.prefix}')
[docs]
@functools.cached_property
def storage(self) -> icechunk.Storage:
if self.uri is None:
raise ValueError('URI must be set before initializing the icechunk repo.')
protocol = self.uri.protocol
if protocol == 's3':
parts = self.uri.parts
bucket = parts[0].strip('/')
prefix = '/'.join(parts[1:])
storage = icechunk.s3_storage(bucket=bucket, prefix=prefix, from_env=True)
elif protocol in {'file', 'local'} or protocol == '':
storage = icechunk.local_filesystem_storage(path=str(self.uri.path))
else:
raise ValueError(
f'Unsupported protocol: {protocol}. Supported protocols are: [s3, file, local]'
)
return storage
[docs]
def init_repo(self):
"""Creates an icechunk repo or opens if does not exist"""
icechunk.Repository.open_or_create(self.storage)
if self.debug:
console.log('Initialized/Opened icechunk repository')
commits = self.commit_messages_ancestry()
if 'initialize store with template' not in commits:
if self.debug:
console.log('No template found in icechunk store. Creating a new template dataset.')
self.create_template()
[docs]
def repo_and_session(self, readonly: bool = False, branch: str = 'main') -> dict:
"""Open an icechunk repository and return the session."""
storage = self.storage
repo = icechunk.Repository.open(storage)
if readonly:
session = repo.readonly_session(branch=branch)
else:
session = repo.writable_session(branch=branch)
if self.debug:
console.log(
f'Opened icechunk repository at {self.uri} with branch {branch} in {"readonly" if readonly else "writable"} mode.'
)
return {'repo': repo, 'session': session}
[docs]
def delete(self):
"""Delete the icechunk repository."""
if self.uri is None:
raise ValueError('URI must be set before deleting the icechunk repo.')
if self.uri.protocol == 's3':
if self.uri.exists():
# Use the underlying filesystem's rm method for efficient batch deletion
fs = self.uri.fs
console.log(f'Deleting icechunk repository at {self.uri}')
fs.rm(self.uri.path, recursive=True)
else:
if self.debug:
console.log('No files found to delete.')
elif self.uri.protocol in {'file', 'local'} or self.uri.protocol == '':
path = self.uri.path
import shutil
if UPath(path).exists():
console.log(f'Deleting icechunk repository at {self.uri}')
shutil.rmtree(path)
else:
if self.debug:
console.log('No files found to delete.')
if self.debug:
console.log('Deleted icechunk repository')
[docs]
def create_template(self):
"""Create a template dataset for icechunk store"""
import dask
import dask.array
import numpy as np
import xarray as xr
repo_and_session = self.repo_and_session()
# NOTE: This is hardcoded as using the USFS 30m chunking scheme!
config = ChunkingConfig()
ds = config.ds
ds['CRPS'].encoding = {}
template = xr.Dataset(ds.coords).drop_vars('spatial_ref')
template.attrs.update(self.metadata_dict)
var_encoding_dict = {
'chunks': ((config.chunks['latitude'], config.chunks['longitude'])),
'fill_value': np.nan,
}
template_data_array = xr.DataArray(
dask.array.empty(
(config.ds.sizes['latitude'], config.ds.sizes['longitude']),
dtype='float32',
chunks=-1,
),
dims=('latitude', 'longitude'),
)
variables = [
'rps_scott',
'rps_2011',
'rps_2047',
'bp_2011',
'bp_2047',
'crps_scott',
'bp_2011_riley',
'bp_2047_riley',
]
template_encoding_dict = {}
for variable in variables:
template[variable] = template_data_array
template_encoding_dict[variable] = var_encoding_dict
template.to_zarr(
repo_and_session['session'].store,
compute=False,
mode='w',
encoding=template_encoding_dict,
consolidated=False,
)
repo_and_session['session'].commit('initialize store with template')
if self.debug:
console.log('Created icechunk template')
[docs]
def commit_messages_ancestry(self, branch: str = 'main') -> list[str]:
"""Get the commit messages ancestry for the icechunk repository."""
repo_and_session = self.repo_and_session(readonly=True)
repo = repo_and_session['repo']
commit_messages = [commit.message for commit in list(repo.ancestry(branch=branch))]
# separate commits by ',' and handle case of single length ancestry commit history
split_commits = [
msg
for message in commit_messages
for msg in (message.split(',') if ',' in message else [message])
]
return split_commits
[docs]
def region_id_exists(self, region_id: str, *, branch: str = 'main') -> bool:
region_ids_in_ancestry = self.commit_messages_ancestry(branch=branch)
if region_id in region_ids_in_ancestry:
return True
return False
[docs]
def processed_regions(self, *, branch: str = 'main') -> list[str]:
"""Get a list of region IDs that have already been processed."""
region_ids = set()
for message in self.commit_messages_ancestry(branch=branch):
if message.startswith('wrote region_id'):
region_ids.add(message.split('(')[1].split(')')[0])
result = sorted(region_ids)
if self.debug:
console.log(f'Found processed {len(result)} region IDs: {result}')
return result
[docs]
def insert_region_uncooperative(
self, subset_ds: xr.Dataset, *, region_id: str, branch: str = 'main'
):
"""Insert region into Icechunk store
Parameters
----------
subset_ds : xr.Dataset
The subset dataset to insert into the Icechunk store.
region_id : str
The region ID corresponding to the subset dataset.
branch : str, optional
The branch to use in the Icechunk repository, by default 'main'.
"""
if self.debug:
console.log(f'Inserting region: {region_id} into Icechunk store: ')
while True:
try:
session = self.repo_and_session(readonly=False, branch=branch)['session']
subset_ds.to_zarr(
session.store,
region='auto',
consolidated=False,
)
# Trying out the rebase strategy described here: https://github.com/earth-mover/icechunk/discussions/802#discussioncomment-13064039
# We should be in the same position, where we don't have real conflicts, just write timing conflicts.
session.commit(
f'wrote region_id ({region_id})', rebase_with=icechunk.ConflictDetector()
)
if self.debug:
console.log(f'Wrote dataset: {subset_ds} to region: {region_id}')
break
except Exception as exc:
delay = random.uniform(3.0, 10.0)
if self.debug:
console.log(f'Conflict detected while writing region {region_id}: {exc}')
console.log(f'retrying to write region_id: {region_id} in {delay:.2f}s')
time.sleep(delay)
pass
[docs]
def pretty_paths(self) -> None:
"""Pretty print key IcechunkConfig paths and URIs.
This version touches cached properties (e.g., uri, storage) to
surface real configuration and types.
"""
from rich.panel import Panel
from rich.table import Table
def nv(name: str, value: str | None):
return name, (str(value) if value not in (None, '') else '—')
rows: list[tuple[str, str]] = []
rows.append(nv('Environment', getattr(self.environment, 'value', str(self.environment))))
rows.append(nv('Version', (str(self.version) if self.version else '—')))
rows.append(nv('Storage root', self.storage_root))
rows.append(nv('Prefix', self.prefix))
# Touch real cached properties
uri = self.uri
rows.append(nv('Repository URI', str(uri)))
rows.append(nv('Protocol', uri.protocol or 'file'))
table = Table(title=None, show_header=True, header_style='bold magenta')
table.add_column('Icechunk setting', style='bold cyan', no_wrap=True)
table.add_column('Value', style='green')
for k, v in rows:
table.add_row(k, v)
console.print(Panel(table, title='IcechunkConfig paths', title_align='left'))
[docs]
@dataclass
class RegionIDStatus:
provided_region_ids: set[str]
valid_region_ids: set[str]
invalid_region_ids: set[str]
processed_region_ids: set[str]
previously_processed_ids: set[str]
unprocessed_valid_region_ids: set[str]
[docs]
class OCRConfig(pydantic_settings.BaseSettings):
"""Configuration settings for OCR processing."""
environment: Environment = pydantic.Field(
default=Environment.QA, description='Environment for OCR processing'
)
version: SemanticVersion | None = pydantic.Field(
default=None,
description=(
'Optional semantic version (e.g., 1.2.3 or v1.2.3). When provided, appended to '
'intermediate and output prefixes for versioned storage.'
),
)
storage_root: str = pydantic.Field(
..., description='Root storage path for OCR data, can be a bucket name or local path'
)
vector: VectorConfig | None = pydantic.Field(None, description='Vector configuration')
icechunk: IcechunkConfig | None = pydantic.Field(None, description='Icechunk configuration')
pyramid: PyramidConfig | None = pydantic.Field(None, description='Pyramid configuration')
chunking: ChunkingConfig | None = pydantic.Field(
None, description='Chunking configuration for OCR processing'
)
coiled: CoiledConfig | None = pydantic.Field(None, description='Coiled configuration')
debug: bool = pydantic.Field(False, description='Enable debugging mode')
model_config = {'env_prefix': 'ocr_', 'case_sensitive': False}
[docs]
def model_post_init(self, __context):
if self.vector is None:
object.__setattr__(
self,
'vector',
VectorConfig(
storage_root=self.storage_root,
environment=self.environment,
debug=self.debug,
version=self.version,
),
)
if self.icechunk is None:
object.__setattr__(
self,
'icechunk',
IcechunkConfig(
storage_root=self.storage_root,
environment=self.environment,
debug=self.debug,
version=self.version,
),
)
if self.pyramid is None:
object.__setattr__(
self,
'pyramid',
PyramidConfig(
storage_root=self.storage_root,
environment=self.environment,
debug=self.debug,
version=self.version,
),
)
if self.chunking is None:
object.__setattr__(
self,
'chunking',
ChunkingConfig(debug=self.debug),
)
if self.coiled is None:
object.__setattr__(
self,
'coiled',
CoiledConfig(),
)
[docs]
def pretty_paths(self) -> None:
"""Pretty print key OCRConfig paths and URIs.
This method intentionally touches cached properties that create
directories (e.g., via mkdir) so you can verify real locations.
"""
from rich.panel import Panel
from rich.table import Table
def nv(name: str, value: str | None):
return name, (str(value) if value not in (None, '') else '—')
rows: list[tuple[str, str]] = []
# high-level
rows.append(nv('Environment', getattr(self.environment, 'value', str(self.environment))))
rows.append(nv('Version', (str(self.version) if self.version else '—')))
rows.append(nv('Storage root', self.storage_root))
table = Table(title=None, show_header=True, header_style='bold magenta')
table.add_column('OCR setting', style='bold cyan', no_wrap=True)
table.add_column('Value', style='green')
for k, v in rows:
table.add_row(k, v)
console.print(Panel(table, title='OCRConfig paths', title_align='left'))
if self.vector:
self.vector.pretty_paths()
if self.icechunk:
self.icechunk.pretty_paths()
# ------------------------------------------------------------------
# Region ID selection / validation helpers (used by CLI pipeline)
# ------------------------------------------------------------------
def _compose_region_id_error(self, status: 'RegionIDStatus') -> str:
"""Compose a detailed error message mirroring previous CLI behavior.
Parameters
----------
status : RegionIDStatus
Computed status object.
"""
error_message = 'No valid region IDs to process. All provided region IDs were rejected for the following reasons:\n'
# Ensure required sub-config present (defensive; model_post_init guarantees this)
assert self.chunking is not None, 'Chunking configuration not initialized'
if status.invalid_region_ids:
error_message += (
f'- Invalid region IDs: {", ".join(sorted(status.invalid_region_ids))}\n'
)
# include (truncated) list of valid ids for reference
error_message += (
' Valid region IDs: '
f'{", ".join(sorted(list(self.chunking.valid_region_ids)))}...\n'
)
if status.previously_processed_ids:
error_message += (
'- Already processed region IDs: '
f'{", ".join(sorted(status.previously_processed_ids))}\n'
)
error_message += "\nPlease provide valid region IDs that haven't been processed yet."
return error_message
[docs]
def resolve_region_ids(
self, provided_region_ids: set[str], *, allow_all_processed: bool = False
) -> 'RegionIDStatus':
"""Validate provided region IDs against valid + processed sets.
Parameters
----------
provided_region_ids : set[str]
The set of region IDs to validate.
allow_all_processed : bool, optional
If True, don't raise an error when all regions are already processed.
This is useful for production reruns where you want to regenerate
vector outputs even if icechunk regions are complete. Default is False.
Returns
-------
RegionIDStatus
Status object with validation results.
Raises
------
ValueError
If no valid unprocessed region IDs remain and allow_all_processed is False.
"""
assert self.chunking is not None, 'Chunking configuration not initialized'
assert self.icechunk is not None, 'Icechunk configuration not initialized'
all_valid = set(self.chunking.valid_region_ids)
valid_region_ids = provided_region_ids.intersection(all_valid)
processed_region_ids = set(self.icechunk.processed_regions())
unprocessed_valid_region_ids = valid_region_ids.difference(processed_region_ids)
invalid_region_ids = provided_region_ids.difference(all_valid)
previously_processed_ids = provided_region_ids.intersection(processed_region_ids)
status = RegionIDStatus(
provided_region_ids=provided_region_ids,
valid_region_ids=valid_region_ids,
invalid_region_ids=invalid_region_ids,
processed_region_ids=processed_region_ids,
previously_processed_ids=previously_processed_ids,
unprocessed_valid_region_ids=unprocessed_valid_region_ids,
)
if len(unprocessed_valid_region_ids) == 0 and not allow_all_processed:
raise ValueError(self._compose_region_id_error(status))
return status
[docs]
def select_region_ids(
self,
region_ids: list[str] | None,
*,
all_region_ids: bool = False,
allow_all_processed: bool = False,
) -> 'RegionIDStatus':
"""Helper to pick the effective set of region IDs (all or user-provided) and
return the validated status object.
Parameters
----------
region_ids : list[str] | None
User-provided region IDs to process.
all_region_ids : bool, optional
If True, use all valid region IDs instead of user-provided ones. Default is False.
allow_all_processed : bool, optional
If True, don't raise an error when all regions are already processed.
Passed through to resolve_region_ids. Default is False.
Returns
-------
RegionIDStatus
Status object with validation results.
"""
assert self.chunking is not None, 'Chunking configuration not initialized'
provided = set(self.chunking.valid_region_ids) if all_region_ids else set(region_ids or [])
return self.resolve_region_ids(provided, allow_all_processed=allow_all_processed)
[docs]
def load_config(file_path: Path | None) -> OCRConfig:
"""Load OCR configuration from an env file (dotenv) or current environment."""
if file_path is None:
return OCRConfig()
dotenv.load_dotenv(file_path)
return OCRConfig()