55import itertools
66
77import torch
8+ from torch import Tensor
89
910
10- def generate_uniform_directions (num_thetas : int , d : int , seed : int , device : str ):
11+ def generate_uniform_directions (
12+ num_thetas : int , d : int , seed : int , device : str
13+ ) -> Tensor :
1114 """
1215 Generate randomly sampled directions from a sphere in d dimensions.
1316
1417 A standard normal is sampled and projected onto the unit sphere to
1518 yield a randomly sampled set of points on the unit spere. Please
1619 note that the generated tensor has shape [d, num_thetas].
1720
18- Parameters
19- ----------
20- num_thetas: int
21- The number of directions to generate.
22- d: int
23- The dimension of the unit sphere. Default is 3 (hence R^3)
21+ Args:
22+ num_thetas:
23+ The number of directions to generate.
24+ d:
25+ The dimension of the unit sphere. Default is 3 (hence R^3)
26+ Returns:
27+ A set of directions.
2428 """
2529 g = torch .Generator (device = device ).manual_seed (seed )
2630 v = torch .randn (size = (d , num_thetas ), device = device , generator = g )
27- v /= v .pow (2 ).sum (axis = 0 ).sqrt ()
31+ v /= v .pow (2 ).sum (dim = 0 ).sqrt ()
2832 return v
2933
3034
31- def generate_2d_directions (num_thetas : int = 64 ):
35+ def generate_2d_directions (num_thetas : int = 64 ) -> Tensor :
3236 """
3337 Provides a structured set of directions in two dimensions. First the
3438 interval [0,2*pi] is devided into a regular grid and the corresponding
3539 angles on the unit circle calculated.
3640
37- Parameters
38- ----------
39- num_thetas: int
40- The number of directions to generate.
41+ Args:
42+ num_thetas:
43+ The number of directions to generate.
4144
42- Returns
43- ----------
44- v: Tensor
45+ Returns:
4546 Tensor of shape [2,num_thetas] containing the directions where each
4647 column is one direction in 2D.
4748 The directions start at $theta=0$ and runs to $theta = 2 * pi$.
@@ -57,7 +58,7 @@ def generate_2d_directions(num_thetas: int = 64):
5758 return v
5859
5960
60- def generate_multiview_directions (num_thetas : int , d : int ):
61+ def generate_multiview_directions (num_thetas : int , d : int ) -> Tensor :
6162 """
6263 Generates multiple sets of structured directions in n dimensions.
6364
@@ -71,12 +72,11 @@ def generate_multiview_directions(num_thetas: int, d: int):
7172 would obtain a 3 channel ect with direction sampled along the xy, xz and yz
7273 planes in three dimensions.
7374
74- Parameters
75- ----------
76- num_thetas: int
77- The number of directions to generate.
78- d: int
79- The dimension of the unit sphere. Default is 3 (hence R^3)
75+ Args:
76+ num_thetas:
77+ The number of directions to generate.
78+ d:
79+ The dimension of the unit sphere. Default is 3 (hence R^3)
8080 """
8181
8282 # We obtain n choose 2 channels.
@@ -111,26 +111,26 @@ def generate_multiview_directions(num_thetas: int, d: int):
111111 return torch .hstack (multiview_dirs )
112112
113113
114- def generate_spherical_grid_directions (num_thetas : int , num_phis : int , d : int = 3 ):
114+ def generate_spherical_grid_directions (
115+ num_thetas : int , num_phis : int , d : int = 3
116+ ) -> Tensor :
115117 """
116118 Generates a smooth spherical grid of directions on the unit sphere in 3D using
117119 latitude–longitude (θ, φ) style sampling.
118120
119121 The directions are parameterized by θ (polar angle, [0, π]) and φ (azimuthal angle, [0, 2π)),
120122 and returned as a tensor of shape [3, num_thetas * num_phis], with each column a unit vector.
121123
122- Parameters
123- ----------
124- num_thetas: int
125- Number of θ samples (from 0 to π, inclusive).
126- num_phis: int
127- Number of φ samples (from 0 to 2π, exclusive).
128- d: int
129- Must be 3, as spherical coordinates are for 3D.
130-
131- Returns
132- -------
133- Tensor of shape [3, num_thetas * num_phis] containing unit vectors on the sphere.
124+ Args:
125+ num_thetas:
126+ Number of θ samples (from 0 to π, inclusive).
127+ num_phis:
128+ Number of φ samples (from 0 to 2π, exclusive).
129+ d:
130+ Must be 3, as spherical coordinates are for 3D.
131+
132+ Returns:
133+ Tensor of shape [3, num_thetas * num_phis] containing unit vectors on the sphere.
134134 """
135135 assert d == 3 , "Spherical coordinates are only defined for d=3."
136136
0 commit comments