import numpy as np
from mip import Model, xsum, minimize, BINARY
def _rotations(tile):
    """Return the four 90-degree rotations of *tile*, identity first.

    Rotation is counter-clockwise in the plane of the first two axes, the same
    direction as the original ``tile.transpose((1, 0, 2))[::-1]`` step, but it
    also works for 2-D (single-channel) tiles.
    """
    return [np.rot90(tile, k) for k in range(4)]


def _edge_roughness(edge):
    """Mean squared difference between consecutive pixels along *edge*.

    Returns 0.0 for edges shorter than two pixels (the mean of an empty slice
    would otherwise be NaN).
    """
    if len(edge) < 2:
        return 0.0
    return float(np.square(edge[1:] - edge[:-1]).mean())


def _edge_cost(t0, t1, vertical):
    """Mismatch cost of placing *t1* next to *t0*.

    Args:
        t0: the first tile (left or upper).
        t1: the neighbouring tile.
        vertical: True when *t1* sits below *t0* (compare t0's last row with
            t1's first row); False when *t1* sits to the right (compare t0's
            last column with t1's first column).

    Returns:
        2 * (mean squared difference across the seam) divided by
        (roughness of both edges + 0.01). The seam difference is thus
        normalised by how much each edge varies on its own, and the 0.01
        keeps flat edges from dividing by zero. Always >= 0.
    """
    if vertical:
        a, b = t0[-1], t1[0]
    else:
        a, b = t0[:, -1], t1[:, 0]
    # Work in float: subtracting unsigned-integer pixels (e.g. uint8) wraps around.
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    mismatch = float(np.square(a - b).mean())
    texture = _edge_roughness(a) + _edge_roughness(b)
    return 2.0 * mismatch / (texture + 0.01)


def solve(tiles):
    """Choose a rotation for every tile so that neighbouring edges match best.

    Builds a binary program with python-mip:

    * one binary per (tile, rotation), with exactly one rotation per tile;
    * one auxiliary binary per pair of rotation choices on adjacent tiles,
      constrained to be >= (choice0 + choice1 - 1).

    The auxiliary binary is charged that pair's ``_edge_cost``. Because every
    charged cost is positive and the objective is minimised, an auxiliary is 1
    exactly when both of its choices are active.

    Args:
        tiles: grid (sequence of rows) of numpy arrays indexed as
            ``tile[row, col, ...]``. NOTE(review): square tiles are presumably
            assumed, since rotated edges must have matching lengths.

    Returns:
        A grid shaped like ``tiles`` holding the chosen rotated array for each
        tile (the unrotated tile is one of the candidates).

    Raises:
        RuntimeError: if the solver returns no feasible solution.
    """
    if not tiles:
        return []

    model = Model("rotaboxes")
    rotations = [[_rotations(tile) for tile in row] for row in tiles]
    rot_vars = [
        [[model.add_var(var_type=BINARY) for _ in candidates] for candidates in row]
        for row in rotations
    ]
    for row in rot_vars:
        for choice in row:
            model += xsum(choice) == 1

    costs = []
    n_rows = len(rotations)
    for i, row in enumerate(rotations):
        for j in range(len(row)):
            for di, dj in ((0, 1), (1, 0)):
                ni, nj = i + di, j + dj
                if ni >= n_rows or nj >= len(rotations[ni]):
                    continue
                for r0, t0 in enumerate(rotations[i][j]):
                    for r1, t1 in enumerate(rotations[ni][nj]):
                        cost = _edge_cost(t0, t1, vertical=bool(di))
                        if cost <= 0.0:
                            # A zero-cost pair cannot affect the objective.
                            continue
                        both = model.add_var(var_type=BINARY)
                        model += both >= rot_vars[i][j][r0] + rot_vars[ni][nj][r1] - 1
                        costs.append(cost * both)

    model.objective = minimize(xsum(costs))
    model.optimize()
    if model.num_solutions == 0:
        raise RuntimeError("rotaboxes: solver found no feasible rotation assignment")

    # Pick the rotation whose indicator is largest instead of thresholding,
    # so solver tolerances (e.g. 0.98 for a chosen binary) cannot drop a tile.
    return [
        [
            candidates[max(range(len(candidates)), key=lambda r: choice[r].x)]
            for choice, candidates in zip(var_row, rot_row)
        ]
        for var_row, rot_row in zip(rot_vars, rotations)
    ]