Skip to content

Commit e5b7ae5

Browse files
authored
Merge pull request #17 from aidos-lab/spherical-grid-for-3D-directions
Spherical grid for 3 d directions
2 parents d8ffe60 + fd1139c commit e5b7ae5

2 files changed

Lines changed: 44 additions & 0 deletions

File tree

dect/directions.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,46 @@ def generate_multiview_directions(num_thetas: int, d: int):
109109
multiview_dirs.append(v)
110110

111111
return torch.hstack(multiview_dirs)
112+
113+
114+
def generate_spherical_grid_directions(num_thetas: int, num_phis: int, d: int = 3):
115+
"""
116+
Generates a smooth spherical grid of directions on the unit sphere in 3D using
117+
latitude–longitude (θ, φ) style sampling.
118+
119+
The directions are parameterized by θ (polar angle, [0, π]) and φ (azimuthal angle, [0, 2π)),
120+
and returned as a tensor of shape [3, num_thetas * num_phis], with each column a unit vector.
121+
122+
Parameters
123+
----------
124+
num_thetas: int
125+
Number of θ samples (from 0 to π, inclusive).
126+
num_phis: int
127+
Number of φ samples (from 0 to 2π, exclusive).
128+
d: int
129+
Must be 3, as spherical coordinates are for 3D.
130+
131+
Returns
132+
-------
133+
Tensor of shape [3, num_thetas * num_phis] containing unit vectors on the sphere.
134+
"""
135+
assert d == 3, "Spherical coordinates are only defined for d=3."
136+
137+
# Removes both poles.
138+
theta = torch.linspace(0, torch.pi, num_thetas + 2)[1:-1]
139+
140+
# Removes one endpoint.
141+
phi = torch.linspace(0, 2 * torch.pi, num_phis + 1)[
142+
:-1
143+
] # Induces endpoint=False behavior.
144+
145+
phi_grid, theta_grid = torch.meshgrid(
146+
phi, theta, indexing="ij"
147+
) # shape [num_phis, num_thetas]
148+
sin_theta = torch.sin(theta_grid)
149+
x = sin_theta * torch.cos(phi_grid)
150+
y = sin_theta * torch.sin(phi_grid)
151+
z = torch.cos(theta_grid)
152+
dirs = torch.stack([x, y, z], dim=0) # [3, num_phis, num_thetas]
153+
dirs = dirs.reshape(3, -1) # [3, num_thetas*num_phis]
154+
return dirs

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies = [
1818
"pdoc>=15.0.1,<16",
1919
"torch-geometric>=2.6.1,<3",
2020
"geotorch>=0.3.0,<0.4",
21+
"pyvista>=0.46.3",
2122
]
2223

2324
[dependency-groups]

0 commit comments

Comments
 (0)