Skip to content

Commit 23567eb

Browse files
ruff and black passing
1 parent 24c747b commit 23567eb

1 file changed

Lines changed: 26 additions & 10 deletions

File tree

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
from scipy.optimize import minimize
33

4+
45
def get_initial_conditions(N=5000):
56
G = 1
67
m = 1
@@ -9,32 +10,47 @@ def get_initial_conditions(N=5000):
910
if N <= 0:
1011
raise ValueError("N must be a positive integer.")
1112
if N % 3 != 0:
12-
N += (3 - (N % 3))
13+
N += 3 - (N % 3)
1314
dt = T / N
1415

15-
def shift(n, s): return (np.arange(n) + s) % n
16+
def shift(n, s):
17+
return (np.arange(n) + s) % n
1618

1719
def action(X_flat):
1820
X = X_flat.reshape((N, 2))
1921
body1 = X
20-
body2 = X[shift(N, N//3)]
21-
body3 = X[shift(N, 2*N//3)]
22+
body2 = X[shift(N, N // 3)]
23+
body3 = X[shift(N, 2 * N // 3)]
2224
vel = (X[shift(N, -1)] - X[shift(N, 1)]) / (2 * dt)
2325
K = 0.5 * m * np.sum(vel**2, axis=1) * 3
24-
def dist(a, b): return np.sqrt(np.sum((a - b)**2, axis=1) + 1e-12)
25-
U = -abs(G) * m**2 * (1 / dist(body1, body2) + 1 / dist(body1, body3) + 1 / dist(body2, body3))
26+
27+
def dist(a, b):
28+
return np.sqrt(np.sum((a - b) ** 2, axis=1) + 1e-12)
29+
30+
U = (
31+
-abs(G)
32+
* m**2
33+
* (1 / dist(body1, body2) + 1 / dist(body1, body3) + 1 / dist(body2, body3))
34+
)
2635
return np.sum((K - U) * dt)
2736

2837
a = 0.97000436
2938
b = 0.24308753
3039
theta = np.linspace(0, T, N, endpoint=False)
3140
X0 = np.column_stack([a * np.sin(theta), b * np.sin(2 * theta)])
32-
res = minimize(action, X0.ravel(), method='L-BFGS-B', options={'maxiter': 2000})
41+
res = minimize(action, X0.ravel(), method="L-BFGS-B", options={"maxiter": 2000})
3342
X_opt = res.x.reshape((N, 2))
34-
shift_indices = [0, N//3, 2*N//3]
43+
shift_indices = [0, N // 3, 2 * N // 3]
3544
positions = [X_opt[i] for i in shift_indices]
3645
vel = (X_opt[shift(N, -1)] - X_opt[shift(N, 1)]) / (2 * dt)
3746
velocities = [vel[i] for i in shift_indices]
38-
47+
3948
# Return 6 vectors: pos1, pos2, pos3, vel1, vel2, vel3
40-
return positions[0], positions[1], positions[2], velocities[0], velocities[1], velocities[2]
49+
return (
50+
positions[0],
51+
positions[1],
52+
positions[2],
53+
velocities[0],
54+
velocities[1],
55+
velocities[2],
56+
)

0 commit comments

Comments
 (0)