Я хочу изложить простой ответ с различными заметками о производительности. np.linalg.norm сделает, возможно, больше, чем вам нужно:
dist = numpy.linalg.norm(a-b)
Во-первых, эта функция предназначена для работы со списком и возврата всех значений, например, для сравнения расстояния pAдо набора точек sP:
sP = set(points)
pA = point
distances = np.linalg.norm(sP - pA, ord=2, axis=1.) # 'distances' is a list
Помните несколько вещей:
- Вызовы функций Python стоят дорого.
- [Обычный] Python не кэширует поиск имен.
Так
def distance(pointA, pointB):
dist = np.linalg.norm(pointA - pointB)
return dist
не так невинно, как кажется.
>>> dis.dis(distance)
2 0 LOAD_GLOBAL 0 (np)
2 LOAD_ATTR 1 (linalg)
4 LOAD_ATTR 2 (norm)
6 LOAD_FAST 0 (pointA)
8 LOAD_FAST 1 (pointB)
10 BINARY_SUBTRACT
12 CALL_FUNCTION 1
14 STORE_FAST 2 (dist)
3 16 LOAD_FAST 2 (dist)
18 RETURN_VALUE
Во-первых - каждый раз, когда мы вызываем его, мы должны выполнить глобальный поиск для «np», поиск в области видимости для «linalg» и поиск в области видимости для «norm» и накладные расходы, связанные с простым вызовом функции могут равняться десяткам Python инструкции.
Наконец, мы потратили две операции, чтобы сохранить результат и перезагрузить его для возврата ...
Первый шаг к улучшению: сделайте поиск быстрее, пропустите магазин
def distance(pointA, pointB, _norm=np.linalg.norm):
return _norm(pointA - pointB)
Мы получаем гораздо более упорядоченный:
>>> dis.dis(distance)
2 0 LOAD_FAST 2 (_norm)
2 LOAD_FAST 0 (pointA)
4 LOAD_FAST 1 (pointB)
6 BINARY_SUBTRACT
8 CALL_FUNCTION 1
10 RETURN_VALUE
Затраты на вызов функции по-прежнему составляют некоторую работу. И вы захотите сделать тесты, чтобы определить, лучше ли вам делать математику самостоятельно:
def distance(pointA, pointB):
return (
((pointA.x - pointB.x) ** 2) +
((pointA.y - pointB.y) ** 2) +
((pointA.z - pointB.z) ** 2)
) ** 0.5 # fast sqrt
На некоторых платформах **0.5это быстрее, чем math.sqrt. Ваш пробег может варьироваться.
**** Расширенные заметки производительности.
Почему вы рассчитываете расстояние? Если единственной целью является его отображение,
print("The target is %.2fm away" % (distance(a, b)))
двигаться вперед. Но если вы сравниваете расстояния, проводите проверки дальности и т. Д., Я хотел бы добавить некоторые полезные наблюдения за производительностью.
Давайте рассмотрим два случая: сортировка по расстоянию или отбор списка по элементам, которые соответствуют ограничению диапазона.
# Ultra naive implementations. Hold onto your hat.
def sort_things_by_distance(origin, things):
return things.sort(key=lambda thing: distance(origin, thing))
def in_range(origin, range, things):
things_in_range = []
for thing in things:
if distance(origin, thing) <= range:
things_in_range.append(thing)
Первое, что нам нужно помнить, это то, что мы используем Pythagoras для вычисления расстояния ( dist = sqrt(x^2 + y^2 + z^2)), поэтому мы делаем много sqrtвызовов. Математика 101:
dist = root ( x^2 + y^2 + z^2 )
:.
dist^2 = x^2 + y^2 + z^2
and
sq(N) < sq(M) iff M > N
and
sq(N) > sq(M) iff N > M
and
sq(N) = sq(M) iff N == M
Короче говоря: пока нам не потребуется расстояние в единице X, а не X ^ 2, мы можем исключить самую сложную часть вычислений.
# Still naive, but much faster.
def distance_sq(left, right):
""" Returns the square of the distance between left and right. """
return (
((left.x - right.x) ** 2) +
((left.y - right.y) ** 2) +
((left.z - right.z) ** 2)
)
def sort_things_by_distance(origin, things):
return things.sort(key=lambda thing: distance_sq(origin, thing))
def in_range(origin, range, things):
things_in_range = []
# Remember that sqrt(N)**2 == N, so if we square
# range, we don't need to root the distances.
range_sq = range**2
for thing in things:
if distance_sq(origin, thing) <= range_sq:
things_in_range.append(thing)
Отлично, обе функции больше не делают дорогих квадратных корней. Это будет намного быстрее. Мы также можем улучшить in_range, преобразовав его в генератор:
def in_range(origin, range, things):
range_sq = range**2
yield from (thing for thing in things
if distance_sq(origin, thing) <= range_sq)
Это особенно полезно, если вы делаете что-то вроде:
if any(in_range(origin, max_dist, things)):
...
Но если следующая вещь, которую вы собираетесь сделать, требует расстояния,
for nearby in in_range(origin, walking_distance, hotdog_stands):
print("%s %.2fm" % (nearby.name, distance(origin, nearby)))
рассмотреть возможность получения кортежей:
def in_range_with_dist_sq(origin, range, things):
range_sq = range**2
for thing in things:
dist_sq = distance_sq(origin, thing)
if dist_sq <= range_sq: yield (thing, dist_sq)
Это может быть особенно полезно, если вы можете связать проверки диапазона («найдите вещи, которые находятся около X и в пределах Nm от Y», так как вам не нужно снова вычислять расстояние).
Но что делать, если мы ищем действительно большой список thingsи ожидаем, что многие из них не заслуживают рассмотрения?
Там на самом деле очень простая оптимизация:
def in_range_all_the_things(origin, range, things):
range_sq = range**2
for thing in things:
dist_sq = (origin.x - thing.x) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.y - thing.y) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.z - thing.z) ** 2
if dist_sq <= range_sq:
yield thing
Будет ли это полезно, будет зависеть от размера «вещей».
def in_range_all_the_things(origin, range, things):
range_sq = range**2
if len(things) >= 4096:
for thing in things:
dist_sq = (origin.x - thing.x) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.y - thing.y) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.z - thing.z) ** 2
if dist_sq <= range_sq:
yield thing
elif len(things) > 32:
for things in things:
dist_sq = (origin.x - thing.x) ** 2
if dist_sq <= range_sq:
dist_sq += (origin.y - thing.y) ** 2 + (origin.z - thing.z) ** 2
if dist_sq <= range_sq:
yield thing
else:
... just calculate distance and range-check it ...
И снова, рассмотрите возможность выдачи dist_sq. Наш пример хот-дога становится:
# Chaining generators
info = in_range_with_dist_sq(origin, walking_distance, hotdog_stands)
info = (stand, dist_sq**0.5 for stand, dist_sq in info)
for stand, dist in info:
print("%s %.2fm" % (stand, dist))